diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium_algos.cpp b/src/lib/pubkey/dilithium/dilithium_common/dilithium_algos.cpp index a0604a9cf6..8d4fd3f45e 100644 --- a/src/lib/pubkey/dilithium/dilithium_common/dilithium_algos.cpp +++ b/src/lib/pubkey/dilithium/dilithium_common/dilithium_algos.cpp @@ -523,24 +523,14 @@ DilithiumSerializedCommitment encode_commitment(const DilithiumPolyVec& w1, cons * NIST FIPS 204, Algorithm 29 (SampleInBall) */ DilithiumPoly sample_in_ball(StrongSpan seed, const DilithiumConstants& mode) { - auto xof = mode.symmetric_primitives().H(seed); - // This generator resembles the while loop in the spec. - auto next_byte_lower_than = [&xof, n = uint16_t(0)](size_t i) mutable -> uint8_t { - while(n < DilithiumConstants::SAMPLE_IN_BALL_XOF_BOUND) { - ++n; - if(const uint8_t b = xof.output_next_byte(); b <= i) { - return b; - } - } - - throw Internal_Error("SampleInBall did not terminate"); - }; + auto& xof = mode.symmetric_primitives().H(seed); + auto bounded_xof = Bounded_XOF(xof); DilithiumPoly c; - uint64_t signs = load_le(xof.output<8>()); + uint64_t signs = load_le(bounded_xof.next<8>()); for(size_t i = c.size() - mode.tau(); i < c.size(); ++i) { - const auto j = next_byte_lower_than(i); + const auto j = bounded_xof.next_byte([i](uint8_t byte) { return byte <= i; }); c[i] = c[j]; c[j] = 1 - 2 * (signs & 1); signs >>= 1; @@ -565,25 +555,13 @@ void sample_ntt_uniform(StrongSpan rho, * A generator that returns the next coefficient sampled from the XOF, * according to: NIST FIPS 204, Algorithm 14 (CoeffFromThreeBytes). */ - auto next_coeff = [n = uint16_t(0)](Botan::XOF& xof) mutable -> uint32_t { - std::array bytes = {0}; - std::span sampling_sink_in_bytes = std::span{bytes}.first<3>(); - - while(n < DilithiumConstants::SAMPLE_NTT_POLY_FROM_XOF_BOUND) { - n += static_cast(sampling_sink_in_bytes.size()); - xof.output(sampling_sink_in_bytes); - const auto z = load_le(bytes) & 0x7FFFFF; - if(z < DilithiumConstants::Q) { - return z; - } - } - - throw Internal_Error("RejNTTPoly did not terminate"); - }; - auto& xof = mode.symmetric_primitives().H(rho, nonce); + auto bounded_xof = Bounded_XOF(xof); + for(auto& coeff : p) { - coeff = next_coeff(xof); + coeff = + bounded_xof.next<3>([](const auto bytes) { return make_uint32(0, bytes[2], bytes[1], bytes[0]) & 0x7FFFFF; }, + [](const uint32_t z) { return z < DilithiumConstants::Q; }); } BOTAN_DEBUG_ASSERT(p.ct_validate_value_range(0, DilithiumConstants::Q - 1)); @@ -617,15 +595,15 @@ void sample_uniform_eta(DilithiumPoly& p, Botan::XOF& xof) { // A generator that returns the next coefficient sampled from the XOF. As the // sampling uses half-bytes, this keeps track of the additionally sampled // coefficient as needed. - auto next_coeff = [&xof, stashed_coeff = std::optional{}, n = uint16_t(0)]() mutable -> int32_t { + auto next_coeff = [bounded_xof = Bounded_XOF(xof), + stashed_coeff = std::optional{}]() mutable -> int32_t { if(auto stashed = std::exchange(stashed_coeff, std::nullopt)) { return *stashed; } BOTAN_DEBUG_ASSERT(!stashed_coeff.has_value()); - while(n < DilithiumConstants::SAMPLE_POLY_FROM_XOF_BOUND) { - ++n; - const auto b = xof.output_next_byte(); + while(true) { + const auto b = bounded_xof.next_byte(); const auto z0 = coeff_from_halfbyte(b & 0x0F); const auto z1 = coeff_from_halfbyte(b >> 4); @@ -637,8 +615,6 @@ void sample_uniform_eta(DilithiumPoly& p, Botan::XOF& xof) { return *z1; } } - - throw Internal_Error("RejBoundedPoly did not terminate"); }; for(auto& coeff : p) {