Skip to content

Commit

Permalink
allow the composite size to be specified in keywords
Browse files Browse the repository at this point in the history
  • Loading branch information
sumiya11 committed Dec 11, 2024
1 parent 26668d4 commit 4f73f9d
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 59 deletions.
88 changes: 44 additions & 44 deletions src/groebner/groebner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,43 +257,42 @@ function _groebner_learn_and_apply(

# At this point, either the reconstruction or the correctness check failed.
# Continue to compute Groebner bases modulo different primes in batches.
B = params.composite
primes_used = 1
batchsize = 4
batchsize = B
batchsize_scaling = 0.10

witness_set = modular_witness_set(state.gb_coeffs_zz, params)

# Initialize a tracer that computes the bases in batches of 4
trace_4x = trace_copy(trace, CompositeNumber{4, Int32}, false)
# Initialize a tracer that computes the bases in batches
trace_Bx = trace_copy(trace, CompositeNumber{B, Int32}, false)

iters = 0
while !correct_basis
for j in 1:4:batchsize
prime_4x = ntuple(i -> Int32(modular_next_prime!(state)), 4)
@invariant iszero(batchsize % B)

# Perform reduction modulo primes and store result in basis_ff_4x
ring_ff_4x, basis_ff_4x = modular_reduce_mod_p_in_batch!(ring, basis_zz, prime_4x)
params_zp_4x = params_mod_p(
for j in 1:B:batchsize
prime_Bx = ntuple(i -> Int32(modular_next_prime!(state)), B)

# Perform reduction modulo several primes
ring_ff_Bx, basis_ff_Bx = modular_reduce_mod_p_in_batch!(ring, basis_zz, prime_Bx)
params_zp_Bx = params_mod_p(
params,
CompositeNumber{4, Int32}(prime_4x),
CompositeNumber{B, Int32}(prime_Bx),
using_wide_type_for_coeffs=false
)
trace_4x.buf_basis = basis_ff_4x
trace_4x.ring = ring_ff_4x
trace_Bx.buf_basis = basis_ff_Bx
trace_Bx.ring = ring_ff_Bx

flag = f4_apply!(trace_4x, ring_ff_4x, trace_4x.buf_basis, params_zp_4x)
flag = f4_apply!(trace_Bx, ring_ff_Bx, trace_Bx.buf_basis, params_zp_Bx)
!flag && continue

gb_coeffs_1, gb_coeffs_2, gb_coeffs_3, gb_coeffs_4 =
ir_unpack_composite_coefficients(trace_4x.gb_basis.coeffs)
gb_coeffs_unpacked = ir_unpack_composite_coefficients(trace_Bx.gb_basis.coeffs)

# TODO: This causes unnecessary conversions of arrays.
append!(state.used_primes, prime_4x)
push!(state.gb_coeffs_ff_all, gb_coeffs_1)
push!(state.gb_coeffs_ff_all, gb_coeffs_2)
push!(state.gb_coeffs_ff_all, gb_coeffs_3)
push!(state.gb_coeffs_ff_all, gb_coeffs_4)
primes_used += 4
append!(state.used_primes, prime_Bx)
append!(state.gb_coeffs_ff_all, gb_coeffs_unpacked)
primes_used += B
end

crt_vec_partial!(
Expand Down Expand Up @@ -431,18 +430,19 @@ function _groebner_learn_and_apply_threaded(

# At this point, either the reconstruction or the correctness check failed.
# Continue to compute Groebner bases modulo different primes in batches.
B = params.composite
primes_used = 1
batchsize = align_up(min(32, 4 * nthreads()), 4)
batchsize = align_up(min(32, B * nthreads()), B)
batchsize_scaling = 0.10

# CRT and rational reconstrction settings
witness_set = modular_witness_set(state.gb_coeffs_zz, params)

# Initialize a tracer that computes the bases in batches of 4
trace_4x = trace_copy(trace, CompositeNumber{4, Int32}, false)
# Initialize a tracer that computes the bases in batches
trace_Bx = trace_copy(trace, CompositeNumber{B, Int32}, false)

# Thread buffers
threadbuf_trace_4x = map(_ -> trace_deepcopy(trace_4x), 1:nthreads())
threadbuf_trace_Bx = map(_ -> trace_deepcopy(trace_Bx), 1:nthreads())
threadbuf_gb_coeffs = Vector{Vector{Tuple{Int32, Vector{Vector{Int32}}}}}(undef, nthreads())
for i in 1:nthreads()
threadbuf_gb_coeffs[i] = Vector{Tuple{Int, Vector{Vector{Int32}}}}()
Expand All @@ -451,44 +451,44 @@ function _groebner_learn_and_apply_threaded(

iters = 0
while !correct_basis
@invariant iszero(batchsize % 4)
@invariant iszero(batchsize % B)

threadbuf_primes = ntuple(_ -> Int32(modular_next_prime!(state)), batchsize)
for i in 1:nthreads()
empty!(threadbuf_gb_coeffs[i])
end

Base.Threads.@threads :static for j in 1:4:batchsize
Base.Threads.@threads :static for j in 1:B:batchsize
t_id = threadid()
threadlocal_trace_4x = threadbuf_trace_4x[t_id]
threadlocal_prime_4x = ntuple(k -> threadbuf_primes[j + k - 1], 4)
threadlocal_trace_Bx = threadbuf_trace_Bx[t_id]
threadlocal_prime_Bx = ntuple(k -> threadbuf_primes[j + k - 1], B)
threadlocal_params = threadbuf_params[t_id]

ring_ff_4x, basis_ff_4x =
modular_reduce_mod_p_in_batch!(ring, basis_zz, threadlocal_prime_4x)
threadlocal_params_zp_4x = params_mod_p(
ring_ff_Bx, basis_ff_Bx =
modular_reduce_mod_p_in_batch!(ring, basis_zz, threadlocal_prime_Bx)
threadlocal_params_zp_Bx = params_mod_p(
threadlocal_params, # can be mutated later
CompositeNumber{4, Int32}(threadlocal_prime_4x),
CompositeNumber{B, Int32}(threadlocal_prime_Bx),
using_wide_type_for_coeffs=false
)
threadlocal_trace_4x.buf_basis = basis_ff_4x
threadlocal_trace_4x.ring = ring_ff_4x
threadlocal_trace_Bx.buf_basis = basis_ff_Bx
threadlocal_trace_Bx.ring = ring_ff_Bx

flag = f4_apply!(
threadlocal_trace_4x,
ring_ff_4x,
threadlocal_trace_4x.buf_basis,
threadlocal_params_zp_4x
threadlocal_trace_Bx,
ring_ff_Bx,
threadlocal_trace_Bx.buf_basis,
threadlocal_params_zp_Bx
)
!flag && continue

gb_coeffs_1, gb_coeffs_2, gb_coeffs_3, gb_coeffs_4 =
ir_unpack_composite_coefficients(threadlocal_trace_4x.gb_basis.coeffs)
gb_coeffs_unpacked =
ir_unpack_composite_coefficients(threadlocal_trace_Bx.gb_basis.coeffs)

push!(threadbuf_gb_coeffs[t_id], (threadlocal_prime_4x[1], gb_coeffs_1))
push!(threadbuf_gb_coeffs[t_id], (threadlocal_prime_4x[2], gb_coeffs_2))
push!(threadbuf_gb_coeffs[t_id], (threadlocal_prime_4x[3], gb_coeffs_3))
push!(threadbuf_gb_coeffs[t_id], (threadlocal_prime_4x[4], gb_coeffs_4))
append!(
threadbuf_gb_coeffs[t_id],
collect(zip(threadlocal_prime_Bx, gb_coeffs_unpacked))
)
end

primes_used += batchsize
Expand Down
6 changes: 3 additions & 3 deletions src/groebner/modular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ function modular_reduce_mod_p_in_batch!(
coeffs_ff_xn[i][j] = CompositeNumber{N, T}(data)
end
end
ring_ff_4x = PolyRing(ring.nvars, ring.ord, CompositeNumber{N, T}(prime_xn), :zp)
basis_ff_4x = basis_deep_copy_with_new_coeffs(basis, coeffs_ff_xn)
ring_ff_Nx = PolyRing(ring.nvars, ring.ord, CompositeNumber{N, T}(prime_xn), :zp)
basis_ff_Nx = basis_deep_copy_with_new_coeffs(basis, coeffs_ff_xn)

ring_ff_4x, basis_ff_4x
ring_ff_Nx, basis_ff_Nx
end

function modular_prepare!(state::ModularState)
Expand Down
10 changes: 5 additions & 5 deletions src/groebner/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ mutable struct AlgorithmParameters{Arithmetic <: AbstractArithmetic}
# - :learn_and_apply
modular_strategy::Symbol

# If learn & apply strategy can use apply in batches
batched::Bool
# The width of composite numbers in learn & apply
composite::Int

# Multi-threading
threaded_f4::Symbol
Expand Down Expand Up @@ -326,7 +326,7 @@ function AlgorithmParameters(ring::PolyRing, kwargs::KeywordArguments; hint=:non
# falling back to classic multi-modular algorithm.
modular_strategy = :classic_modular
end
batched = kwargs.batched
composite = kwargs._composite

seed = kwargs.seed
rng = Random.Xoshiro(seed)
Expand Down Expand Up @@ -360,7 +360,7 @@ function AlgorithmParameters(ring::PolyRing, kwargs::KeywordArguments; hint=:non
representation,
reduced,
modular_strategy,
batched,
composite,
threaded_f4,
threaded_multimodular,
rng,
Expand Down Expand Up @@ -396,7 +396,7 @@ function params_mod_p(
representation,
params.reduced,
params.modular_strategy,
params.batched,
params.composite,
params.threaded_f4,
params.threaded_multimodular,
params.rng,
Expand Down
2 changes: 1 addition & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ trace, gb_0 = groebner_learn(kat_0);
# 72.813 ms (23722 allocations: 59.44 MiB)
```
Observe the better amortized performance of the batched `groebner_apply!`.
Observe the better amortized performance of the composite `groebner_apply!`.
## Notes
Expand Down
12 changes: 6 additions & 6 deletions src/utils/keywords.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ const _supported_kw_args = (
modular = :auto,
threaded = :auto,
homogenize = :auto,
batched = true,
changematrix = false,
_composite = 4,
_generic = false
),
normalform = (
Expand Down Expand Up @@ -64,8 +64,8 @@ const _supported_kw_args = (
modular = :auto,
threaded = :auto,
homogenize = :auto,
batched = true,
changematrix = true
changematrix = true,
_composite = 4,
),
)
#! format: on
Expand Down Expand Up @@ -93,7 +93,7 @@ mutable struct KeywordArguments
seed::Int
selection::Symbol
modular::Symbol
batched::Bool
_composite::Int
check::Bool
homogenize::Symbol
changematrix::Bool
Expand Down Expand Up @@ -154,7 +154,7 @@ mutable struct KeywordArguments
Possible choices for keyword "modular" are:
`:auto`, `:classic_modular`, `:learn_and_apply`"""

batched = get(kws, :batched, get(default_kw_args, :batched, true))
_composite = get(kws, :_composite, get(default_kw_args, :_composite, 4))

selection = get(kws, :selection, get(default_kw_args, :selection, :auto))
@assert selection in (:auto, :normal, :sugar, :be_divided_and_perish)
Expand Down Expand Up @@ -182,7 +182,7 @@ mutable struct KeywordArguments
seed,
selection,
modular,
batched,
_composite,
check,
homogenize,
changematrix,
Expand Down
7 changes: 7 additions & 0 deletions test/groebner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,13 @@ end
gb = Groebner.groebner(noon, modular=modular)
@test Groebner.isgroebner(gb)

chan = Groebner.Examples.chandran(6)
gb_truth = Groebner.groebner(chan)
for _composite in (1, 2, 4, 8, 16)
gb = groebner(chan, _composite=_composite)
@test gb == gb_truth
end

# Test a number of cases directly
R, (x, y, z) = QQ["x", "y", "z"]
xs = gens(R)
Expand Down

0 comments on commit 4f73f9d

Please sign in to comment.