diff --git a/src/lib/pubkey/hss_lms/hss.cpp b/src/lib/pubkey/hss_lms/hss.cpp index 9c0de73382c..dfbe8a79eaa 100644 --- a/src/lib/pubkey/hss_lms/hss.cpp +++ b/src/lib/pubkey/hss_lms/hss.cpp @@ -90,8 +90,10 @@ HSS_LMS_Params::HSS_LMS_Params(std::string_view algo_params) { SCAN_Name scan_layer(scan.arg(i)); BOTAN_ARG_CHECK(scan_layer.algo_name() == "HW", "Invalid name for layer parameters"); BOTAN_ARG_CHECK(scan_layer.arg_count() == 2, "Invalid number of layer parameters"); - const auto h = checked_cast_to(scan_layer.arg_as_integer(0)); - const auto w = checked_cast_to(scan_layer.arg_as_integer(1)); + const auto h = + checked_cast_to_or_throw(scan_layer.arg_as_integer(0), "Invalid parameter"); + const auto w = + checked_cast_to_or_throw(scan_layer.arg_as_integer(1), "Invalid parameter"); m_lms_lmots_params.push_back({LMS_Params::create_or_throw(hash, h), LMOTS_Params::create_or_throw(hash, w)}); } m_max_sig_count = calc_max_sig_count(); diff --git a/src/lib/pubkey/hss_lms/lm_ots.cpp b/src/lib/pubkey/hss_lms/lm_ots.cpp index f0486c8258f..b9c467b528f 100644 --- a/src/lib/pubkey/hss_lms/lm_ots.cpp +++ b/src/lib/pubkey/hss_lms/lm_ots.cpp @@ -110,7 +110,7 @@ std::vector gen_Q_with_cksm(const LMOTS_Params& params, } // namespace LMOTS_Params LMOTS_Params::create_or_throw(LMOTS_Algorithm_Type type) { - uint8_t type_value = checked_cast_to(type); + uint8_t type_value = checked_cast_to_or_throw(type, "Unsupported LM-OTS algorithm type"); if(type >= LMOTS_Algorithm_Type::SHA256_N32_W1 && type <= LMOTS_Algorithm_Type::SHA256_N32_W8) { uint8_t w = 1 << (type_value - checked_cast_to(LMOTS_Algorithm_Type::SHA256_N32_W1)); diff --git a/src/lib/pubkey/hss_lms/lms.cpp b/src/lib/pubkey/hss_lms/lms.cpp index 8d18a5e431a..91e0bd3b0fa 100644 --- a/src/lib/pubkey/hss_lms/lms.cpp +++ b/src/lib/pubkey/hss_lms/lms.cpp @@ -120,7 +120,7 @@ void lms_treehash(StrongSpan out_root, } // namespace LMS_Params LMS_Params::create_or_throw(LMS_Algorithm_Type type) { - uint8_t type_value = checked_cast_to(type); + uint8_t type_value = checked_cast_to_or_throw(type, "Unsupported LMS algorithm type"); if(type >= LMS_Algorithm_Type::SHA256_M32_H5 && type <= LMS_Algorithm_Type::SHA256_M32_H25) { uint8_t h = 5 * (type_value - checked_cast_to(LMS_Algorithm_Type::SHA256_M32_H5) + 1); diff --git a/src/lib/utils/safeint.h b/src/lib/utils/safeint.h index 3fa7547cc29..a765e1dbbb1 100644 --- a/src/lib/utils/safeint.h +++ b/src/lib/utils/safeint.h @@ -61,15 +61,20 @@ inline std::optional checked_mul(size_t x, size_t y) { return z; } -template -RT checked_cast_to(AT i) { +template +constexpr RT checked_cast_to_or_throw(AT i, std::string_view error_msg_on_fail) { RT c = static_cast(i); if(i != static_cast(c)) { - throw Internal_Error("Error during integer conversion"); + throw ExceptionType(error_msg_on_fail); } return c; } +template +constexpr RT checked_cast_to(AT i) { + return checked_cast_to_or_throw(i, "Error during integer conversion"); +} + #define BOTAN_CHECKED_ADD(x, y) checked_add(x, y, __FILE__, __LINE__) #define BOTAN_CHECKED_MUL(x, y) checked_mul(x, y) diff --git a/src/tests/test_hss_lms.cpp b/src/tests/test_hss_lms.cpp index ab3cfc61fa8..58730efd583 100644 --- a/src/tests/test_hss_lms.cpp +++ b/src/tests/test_hss_lms.cpp @@ -145,7 +145,8 @@ class HSS_LMS_Too_Short_Test final : public Test { for(size_t n = 0; n < sk_bytes.size(); ++n) { result.test_throws("Partial private key invalid", [&]() { std::span partial_key = {sk_bytes.data(), n}; - std::make_unique(partial_key); + Botan::HSS_LMS_PrivateKey key(partial_key); + BOTAN_UNUSED(key); }); } return result; @@ -163,7 +164,8 @@ class HSS_LMS_Too_Short_Test final : public Test { for(size_t n = 0; n < sk_bytes.size(); ++n) { result.test_throws("Partial public key invalid", [&]() { std::span partial_key = {sk_bytes.data(), n}; - std::make_unique(partial_key); + Botan::HSS_LMS_PublicKey key(partial_key); + BOTAN_UNUSED(key); }); } return result;