Skip to content

Commit

Permalink
use checked_cast_to
Browse files Browse the repository at this point in the history
  • Loading branch information
FAlbertDev committed Sep 28, 2023
1 parent 19159f5 commit 9d07e2a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 17 deletions.
5 changes: 3 additions & 2 deletions src/lib/pubkey/hss_lms/hss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <botan/internal/hss.h>
#include <botan/internal/hss_lms_utils.h>
#include <botan/internal/loadstor.h>
#include <botan/internal/safeint.h>
#include <botan/internal/scan_name.h>
#include <botan/internal/stl_util.h>

Expand Down Expand Up @@ -89,8 +90,8 @@ HSS_LMS_Params::HSS_LMS_Params(std::string_view algo_params) {
SCAN_Name scan_layer(scan.arg(i));
BOTAN_ARG_CHECK(scan_layer.algo_name() == "HW", "Invalid name for layer parameters");
BOTAN_ARG_CHECK(scan_layer.arg_count() == 2, "Invalid number of layer parameters");
const auto h = scan_layer.arg_as_integer(0);
const auto w = static_cast<uint8_t>(scan_layer.arg_as_integer(1));
const auto h = checked_cast_to<uint8_t>(scan_layer.arg_as_integer(0));
const auto w = checked_cast_to<uint8_t>(scan_layer.arg_as_integer(1));
m_lms_lmots_params.push_back({LMS_Params::create_or_throw(hash, h), LMOTS_Params::create_or_throw(hash, w)});
}
m_max_sig_count = calc_max_sig_count();
Expand Down
3 changes: 2 additions & 1 deletion src/lib/pubkey/hss_lms/hss.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <botan/asn1_obj.h>
#include <botan/internal/lm_ots.h>
#include <botan/internal/lms.h>
#include <botan/internal/safeint.h>

#include <cstdint>
#include <memory>
Expand Down Expand Up @@ -99,7 +100,7 @@ class BOTAN_TEST_API HSS_LMS_Params final {
/**
* @brief Returns the number of layers the HSS tree has.
*/
HSS_Level L() const { return HSS_Level(static_cast<uint32_t>(m_lms_lmots_params.size())); }
HSS_Level L() const { return HSS_Level(checked_cast_to<uint32_t>(m_lms_lmots_params.size())); }

/**
* @brief The maximal number of signatures allowed for these HSS parameters
Expand Down
16 changes: 8 additions & 8 deletions src/lib/pubkey/hss_lms/lm_ots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,22 @@ std::vector<uint8_t> gen_Q_with_cksm(const LMOTS_Params& params,
} // namespace

LMOTS_Params LMOTS_Params::create_or_throw(LMOTS_Algorithm_Type type) {
uint8_t type_value = static_cast<uint8_t>(type);
uint8_t type_value = checked_cast_to<uint8_t>(type);

if(type >= LMOTS_Algorithm_Type::SHA256_N32_W1 && type <= LMOTS_Algorithm_Type::SHA256_N32_W8) {
uint8_t w = 1 << (type_value - static_cast<uint8_t>(LMOTS_Algorithm_Type::SHA256_N32_W1));
uint8_t w = 1 << (type_value - checked_cast_to<uint8_t>(LMOTS_Algorithm_Type::SHA256_N32_W1));
return LMOTS_Params(type, "SHA-256", w);
}
if(type >= LMOTS_Algorithm_Type::SHA256_N24_W1 && type <= LMOTS_Algorithm_Type::SHA256_N24_W8) {
uint8_t w = 1 << (type_value - static_cast<uint8_t>(LMOTS_Algorithm_Type::SHA256_N24_W1));
uint8_t w = 1 << (type_value - checked_cast_to<uint8_t>(LMOTS_Algorithm_Type::SHA256_N24_W1));
return LMOTS_Params(type, "Truncated(SHA-256,192)", w);
}
if(type >= LMOTS_Algorithm_Type::SHAKE_N32_W1 && type <= LMOTS_Algorithm_Type::SHAKE_N32_W8) {
uint8_t w = 1 << (type_value - static_cast<uint8_t>(LMOTS_Algorithm_Type::SHAKE_N32_W1));
uint8_t w = 1 << (type_value - checked_cast_to<uint8_t>(LMOTS_Algorithm_Type::SHAKE_N32_W1));
return LMOTS_Params(type, "SHAKE-256(256)", w);
}
if(type >= LMOTS_Algorithm_Type::SHAKE_N24_W1 && type <= LMOTS_Algorithm_Type::SHAKE_N24_W8) {
uint8_t w = 1 << (type_value - static_cast<uint8_t>(LMOTS_Algorithm_Type::SHAKE_N24_W1));
uint8_t w = 1 << (type_value - checked_cast_to<uint8_t>(LMOTS_Algorithm_Type::SHAKE_N24_W1));
return LMOTS_Params(type, "SHAKE-256(192)", w);
}

Expand All @@ -148,7 +148,7 @@ LMOTS_Params LMOTS_Params::create_or_throw(std::string_view hash_name, uint8_t w
} else {
throw Decoding_Error("Unsupported hash function");
}
auto type = static_cast<LMOTS_Algorithm_Type>(static_cast<uint8_t>(base_type) + type_offset);
auto type = checked_cast_to<LMOTS_Algorithm_Type>(checked_cast_to<uint8_t>(base_type) + type_offset);
return LMOTS_Params(type, hash_name, w);
}

Expand All @@ -159,8 +159,8 @@ LMOTS_Params::LMOTS_Params(LMOTS_Algorithm_Type algorithm_type, std::string_view
// RFC 8553 Appendix B - Parameter Computation
auto u = ceil_division<size_t>(8 * m_n, m_w); // ceil(8*n/w)
auto v = ceil_division<size_t>(high_bit(((1 << m_w) - 1) * u), m_w); // ceil((floor(lg[(2^w - 1) * u]) + 1) / w)
m_ls = static_cast<uint8_t>(16 - (v * w));
m_p = static_cast<uint16_t>(u + v);
m_ls = checked_cast_to<uint8_t>(16 - (v * w));
m_p = checked_cast_to<uint16_t>(u + v);
}

LMOTS_Signature::LMOTS_Signature(LMOTS_Algorithm_Type lmots_type,
Expand Down
13 changes: 7 additions & 6 deletions src/lib/pubkey/hss_lms/lms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <botan/internal/lms.h>

#include <botan/internal/hss_lms_utils.h>
#include <botan/internal/safeint.h>

namespace Botan {
namespace {
Expand Down Expand Up @@ -119,22 +120,22 @@ void lms_treehash(StrongSpan<LMS_Tree_Node> out_root,
} // namespace

LMS_Params LMS_Params::create_or_throw(LMS_Algorithm_Type type) {
uint8_t type_value = static_cast<uint8_t>(type);
uint8_t type_value = checked_cast_to<uint8_t>(type);

if(type >= LMS_Algorithm_Type::SHA256_M32_H5 && type <= LMS_Algorithm_Type::SHA256_M32_H25) {
uint8_t h = 5 * (type_value - static_cast<uint8_t>(LMS_Algorithm_Type::SHA256_M32_H5) + 1);
uint8_t h = 5 * (type_value - checked_cast_to<uint8_t>(LMS_Algorithm_Type::SHA256_M32_H5) + 1);
return LMS_Params(type, "SHA-256", h);
}
if(type >= LMS_Algorithm_Type::SHA256_M24_H5 && type <= LMS_Algorithm_Type::SHA256_M24_H25) {
uint8_t h = 5 * (type_value - static_cast<uint8_t>(LMS_Algorithm_Type::SHA256_M24_H5) + 1);
uint8_t h = 5 * (type_value - checked_cast_to<uint8_t>(LMS_Algorithm_Type::SHA256_M24_H5) + 1);
return LMS_Params(type, "Truncated(SHA-256,192)", h);
}
if(type >= LMS_Algorithm_Type::SHAKE_M32_H5 && type <= LMS_Algorithm_Type::SHAKE_M32_H25) {
uint8_t h = 5 * (type_value - static_cast<uint8_t>(LMS_Algorithm_Type::SHAKE_M32_H5) + 1);
uint8_t h = 5 * (type_value - checked_cast_to<uint8_t>(LMS_Algorithm_Type::SHAKE_M32_H5) + 1);
return LMS_Params(type, "SHAKE-256(256)", h);
}
if(type >= LMS_Algorithm_Type::SHAKE_M24_H5 && type <= LMS_Algorithm_Type::SHAKE_M24_H25) {
uint8_t h = 5 * (type_value - static_cast<uint8_t>(LMS_Algorithm_Type::SHAKE_M24_H5) + 1);
uint8_t h = 5 * (type_value - checked_cast_to<uint8_t>(LMS_Algorithm_Type::SHAKE_M24_H5) + 1);
return LMS_Params(type, "SHAKE-256(192)", h);
}

Expand All @@ -157,7 +158,7 @@ LMS_Params LMS_Params::create_or_throw(std::string_view hash_name, size_t h) {
} else {
throw Decoding_Error("Unsupported hash function");
}
auto type = static_cast<LMS_Algorithm_Type>(static_cast<uint8_t>(base_type) + type_offset);
auto type = checked_cast_to<LMS_Algorithm_Type>(checked_cast_to<uint8_t>(base_type) + type_offset);
return LMS_Params(type, hash_name, h);
}

Expand Down

0 comments on commit 9d07e2a

Please sign in to comment.