diff --git a/src/net.cpp b/src/net.cpp index bd98e3e9a..11b5e103f 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -694,7 +694,7 @@ bool CNode::ReceiveMsgBytes(Span msg_bytes, bool& complete) // decompose a transport agnostic CNetMessage from the deserializer bool reject_message{false}; bool disconnect{false}; - CNetMessage msg = m_deserializer->GetMessage(time, reject_message, disconnect); + CNetMessage msg = m_deserializer->GetMessage(time, reject_message, disconnect, {}); if (disconnect) { // v2 p2p incorrect MAC tag. Disconnect from peer. @@ -792,7 +792,10 @@ const uint256& V1TransportDeserializer::GetMessageHash() const return data_hash; } -CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, bool& reject_message, bool& disconnect) +CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, + bool& reject_message, + bool& disconnect, + Span aad) { // Initialize out parameter reject_message = false; @@ -897,7 +900,10 @@ int V2TransportDeserializer::readData(Span pkt_bytes) return copy_bytes; } -CNetMessage V2TransportDeserializer::GetMessage(const std::chrono::microseconds time, bool& reject_message, bool& disconnect) +CNetMessage V2TransportDeserializer::GetMessage(const std::chrono::microseconds time, + bool& reject_message, + bool& disconnect, + Span aad) { const size_t min_contents_size = 1; // BIP324 1-byte message type id is the minimum contents @@ -916,7 +922,7 @@ CNetMessage V2TransportDeserializer::GetMessage(const std::chrono::microseconds BIP324HeaderFlags flags; size_t msg_type_size = 1; // at least one byte needed for message type - if (m_cipher_suite->Crypt({}, + if (m_cipher_suite->Crypt(aad, Span{reinterpret_cast(vRecv.data() + BIP324_LENGTH_FIELD_LEN), BIP324_HEADER_LEN + m_contents_size + RFC8439_EXPANSION}, Span{reinterpret_cast(vRecv.data()), m_contents_size}, flags, false)) { // MAC check was successful @@ -1009,7 +1015,7 @@ bool V2TransportSerializer::prepareForTransport(CSerializedNetMsg& msg, std::vec BIP324HeaderFlags flags{BIP324_NONE}; // encrypt the payload, this should always succeed (controlled buffers, don't check the MAC during encrypting) - auto success = m_cipher_suite->Crypt({}, + auto success = m_cipher_suite->Crypt(msg.aad, Span{reinterpret_cast(msg.data.data()), contents_size}, Span{reinterpret_cast(msg.data.data()), encrypted_pkt_size}, flags, true); diff --git a/src/net.h b/src/net.h index 34cbf5fad..60ba47ecd 100644 --- a/src/net.h +++ b/src/net.h @@ -127,6 +127,7 @@ struct CSerializedNetMsg { } std::vector data; + std::vector aad; // associated authenticated data for encrypted BIP324 (v2) transport std::string m_type; }; @@ -270,7 +271,10 @@ class TransportDeserializer { /** read and deserialize data, advances msg_bytes data pointer */ virtual int Read(Span& msg_bytes) = 0; // decomposes a message from the context - virtual CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message, bool& disconnect) = 0; + virtual CNetMessage GetMessage(std::chrono::microseconds time, + bool& reject_message, + bool& disconnect, + Span aad) = 0; virtual ~TransportDeserializer() {} }; @@ -334,7 +338,10 @@ class V1TransportDeserializer final : public TransportDeserializer } return ret; } - CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message, bool& disconnect) override; + CNetMessage GetMessage(std::chrono::microseconds time, + bool& reject_message, + bool& disconnect, + Span aad) override; }; /** V2TransportDeserializer is a transport deserializer after BIP324 */ @@ -392,7 +399,10 @@ class V2TransportDeserializer final : public TransportDeserializer } return ret; } - CNetMessage GetMessage(const std::chrono::microseconds time, bool& reject_message, bool& disconnect) override; + CNetMessage GetMessage(const std::chrono::microseconds time, + bool& reject_message, + bool& disconnect, + Span aad) override; }; /** The TransportSerializer prepares messages for the network transport diff --git a/src/test/fuzz/p2p_transport_serialization.cpp b/src/test/fuzz/p2p_transport_serialization.cpp index 1ba35fd13..8b23c0110 100644 --- a/src/test/fuzz/p2p_transport_serialization.cpp +++ b/src/test/fuzz/p2p_transport_serialization.cpp @@ -70,7 +70,7 @@ FUZZ_TARGET_INIT(p2p_transport_serialization, initialize_p2p_transport_serializa const std::chrono::microseconds m_time{std::numeric_limits::max()}; bool reject_message{false}; bool disconnect{false}; - CNetMessage msg = deserializer.GetMessage(m_time, reject_message, disconnect); + CNetMessage msg = deserializer.GetMessage(m_time, reject_message, disconnect, {}); assert(msg.m_type.size() <= CMessageHeader::COMMAND_SIZE); assert(msg.m_raw_message_size <= mutable_msg_bytes.size()); assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size); diff --git a/src/test/fuzz/p2p_v2_transport_serialization.cpp b/src/test/fuzz/p2p_v2_transport_serialization.cpp index 7c6375f8d..5620d4fe9 100644 --- a/src/test/fuzz/p2p_v2_transport_serialization.cpp +++ b/src/test/fuzz/p2p_v2_transport_serialization.cpp @@ -40,6 +40,7 @@ FUZZ_TARGET(p2p_v2_transport_serialization) // There is no sense in providing a mac assist if the length is incorrect. bool mac_assist = length_assist && fdp.ConsumeBool(); + auto aad = fdp.ConsumeBytes(fdp.ConsumeIntegralInRange(0, 1024)); auto encrypted_packet = fdp.ConsumeRemainingBytes(); bool is_decoy_packet{false}; @@ -53,7 +54,7 @@ FUZZ_TARGET(p2p_v2_transport_serialization) if (mac_assist) { std::array tag; - ComputeRFC8439Tag(GetPoly1305Key(c20), {}, + ComputeRFC8439Tag(GetPoly1305Key(c20), aad, {reinterpret_cast(encrypted_packet.data()) + BIP324_LENGTH_FIELD_LEN, encrypted_packet.size() - BIP324_LENGTH_FIELD_LEN - RFC8439_EXPANSION}, tag); @@ -61,7 +62,7 @@ FUZZ_TARGET(p2p_v2_transport_serialization) std::vector dec_header_and_contents( encrypted_packet.size() - BIP324_LENGTH_FIELD_LEN - RFC8439_EXPANSION); - RFC8439Decrypt({}, key_P, nonce, + RFC8439Decrypt(aad, key_P, nonce, {reinterpret_cast(encrypted_packet.data() + BIP324_LENGTH_FIELD_LEN), encrypted_packet.size() - BIP324_LENGTH_FIELD_LEN}, dec_header_and_contents); @@ -81,7 +82,7 @@ FUZZ_TARGET(p2p_v2_transport_serialization) const std::chrono::microseconds m_time{std::numeric_limits::max()}; bool reject_message{true}; bool disconnect{true}; - CNetMessage result{deserializer.GetMessage(m_time, reject_message, disconnect)}; + CNetMessage result{deserializer.GetMessage(m_time, reject_message, disconnect, aad)}; if (mac_assist) { assert(!disconnect); @@ -102,6 +103,7 @@ FUZZ_TARGET(p2p_v2_transport_serialization) std::vector header; auto msg = CNetMsgMaker{result.m_recv.GetVersion()}.Make(result.m_type, MakeUCharSpan(result.m_recv)); + msg.aad = aad; // if decryption succeeds, encryption must succeed assert(serializer.prepareForTransport(msg, header)); } diff --git a/src/test/net_tests.cpp b/src/test/net_tests.cpp index 3f6a4a4de..9507b77b6 100644 --- a/src/test/net_tests.cpp +++ b/src/test/net_tests.cpp @@ -953,7 +953,7 @@ void message_serialize_deserialize_test(bool v2, const std::vectorGetMessage(GetTime(), reject_message, disconnect)}; + CNetMessage result{deserializer->GetMessage(GetTime(), reject_message, disconnect, {})}; BOOST_CHECK(!reject_message); BOOST_CHECK(!disconnect); BOOST_CHECK_EQUAL(result.m_type, msg_orig.m_type);