Skip to content

Commit

Permalink
checked_cast_to_or_throw
Browse files Browse the repository at this point in the history
  • Loading branch information
FAlbertDev committed Sep 28, 2023
1 parent 6668d74 commit 2312605
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 9 deletions.
6 changes: 4 additions & 2 deletions src/lib/pubkey/hss_lms/hss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ HSS_LMS_Params::HSS_LMS_Params(std::string_view algo_params) {
SCAN_Name scan_layer(scan.arg(i));
BOTAN_ARG_CHECK(scan_layer.algo_name() == "HW", "Invalid name for layer parameters");
BOTAN_ARG_CHECK(scan_layer.arg_count() == 2, "Invalid number of layer parameters");
const auto h = checked_cast_to<uint8_t>(scan_layer.arg_as_integer(0));
const auto w = checked_cast_to<uint8_t>(scan_layer.arg_as_integer(1));
const auto h =
checked_cast_to_or_throw<uint8_t, Invalid_Argument>(scan_layer.arg_as_integer(0), "Invalid parameter");
const auto w =
checked_cast_to_or_throw<uint8_t, Invalid_Argument>(scan_layer.arg_as_integer(1), "Invalid parameter");
m_lms_lmots_params.push_back({LMS_Params::create_or_throw(hash, h), LMOTS_Params::create_or_throw(hash, w)});
}
m_max_sig_count = calc_max_sig_count();
Expand Down
2 changes: 1 addition & 1 deletion src/lib/pubkey/hss_lms/lm_ots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ std::vector<uint8_t> gen_Q_with_cksm(const LMOTS_Params& params,
} // namespace

LMOTS_Params LMOTS_Params::create_or_throw(LMOTS_Algorithm_Type type) {
uint8_t type_value = checked_cast_to<uint8_t>(type);
uint8_t type_value = checked_cast_to_or_throw<uint8_t, Decoding_Error>(type, "Unsupported LM-OTS algorithm type");

if(type >= LMOTS_Algorithm_Type::SHA256_N32_W1 && type <= LMOTS_Algorithm_Type::SHA256_N32_W8) {
uint8_t w = 1 << (type_value - checked_cast_to<uint8_t>(LMOTS_Algorithm_Type::SHA256_N32_W1));
Expand Down
2 changes: 1 addition & 1 deletion src/lib/pubkey/hss_lms/lms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void lms_treehash(StrongSpan<LMS_Tree_Node> out_root,
} // namespace

LMS_Params LMS_Params::create_or_throw(LMS_Algorithm_Type type) {
uint8_t type_value = checked_cast_to<uint8_t>(type);
uint8_t type_value = checked_cast_to_or_throw<uint8_t, Decoding_Error>(type, "Unsupported LMS algorithm type");

if(type >= LMS_Algorithm_Type::SHA256_M32_H5 && type <= LMS_Algorithm_Type::SHA256_M32_H25) {
uint8_t h = 5 * (type_value - checked_cast_to<uint8_t>(LMS_Algorithm_Type::SHA256_M32_H5) + 1);
Expand Down
11 changes: 8 additions & 3 deletions src/lib/utils/safeint.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,20 @@ inline std::optional<size_t> checked_mul(size_t x, size_t y) {
return z;
}

template <typename RT, typename AT>
RT checked_cast_to(AT i) {
template <typename RT, typename ExceptionType, typename AT>
constexpr RT checked_cast_to_or_throw(AT i, std::string_view error_msg_on_fail) {
RT c = static_cast<RT>(i);
if(i != static_cast<AT>(c)) {
throw Internal_Error("Error during integer conversion");
throw ExceptionType(error_msg_on_fail);
}
return c;
}

template <typename RT, typename AT>
constexpr RT checked_cast_to(AT i) {
return checked_cast_to_or_throw<RT, Internal_Error>(i, "Error during integer conversion");
}

#define BOTAN_CHECKED_ADD(x, y) checked_add(x, y, __FILE__, __LINE__)
#define BOTAN_CHECKED_MUL(x, y) checked_mul(x, y)

Expand Down
6 changes: 4 additions & 2 deletions src/tests/test_hss_lms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ class HSS_LMS_Too_Short_Test final : public Test {
for(size_t n = 0; n < sk_bytes.size(); ++n) {
result.test_throws<Botan::Decoding_Error>("Partial private key invalid", [&]() {
std::span<const uint8_t> partial_key = {sk_bytes.data(), n};
std::make_unique<Botan::HSS_LMS_PrivateKey>(partial_key);
Botan::HSS_LMS_PrivateKey key(partial_key);
BOTAN_UNUSED(key);
});
}
return result;
Expand All @@ -163,7 +164,8 @@ class HSS_LMS_Too_Short_Test final : public Test {
for(size_t n = 0; n < sk_bytes.size(); ++n) {
result.test_throws<Botan::Decoding_Error>("Partial public key invalid", [&]() {
std::span<const uint8_t> partial_key = {sk_bytes.data(), n};
std::make_unique<Botan::HSS_LMS_PublicKey>(partial_key);
Botan::HSS_LMS_PublicKey key(partial_key);
BOTAN_UNUSED(key);
});
}
return result;
Expand Down

0 comments on commit 2312605

Please sign in to comment.