Skip to content

Commit

Permalink
Add Length-Delimited Encode and Decode functions to upb.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621510731
  • Loading branch information
protobuf-github-bot authored and copybara-github committed Apr 3, 2024
1 parent d7f032a commit c6f6a32
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 14 deletions.
22 changes: 18 additions & 4 deletions upb/message/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ TEST(MessageTest, Freeze) {
//
// static void DecodeEncodeArbitrarySchemaAndPayload(
// const upb::fuzz::MiniTableFuzzInput& input, std::string_view proto_payload,
// int decode_options, int encode_options) {
// int decode_options, int encode_options, bool length_delimited = false) {
// // Lexan does not have setenv
// #ifndef _MSC_VER
// setenv("FUZZTEST_STACK_LIMIT", "262144", 1);
Expand All @@ -605,11 +605,25 @@ TEST(MessageTest, Freeze) {
// upb::fuzz::BuildMiniTable(input, &exts, arena.ptr());
// if (!mini_table) return;
// upb_Message* msg = upb_Message_New(mini_table, arena.ptr());
// upb_Decode(proto_payload.data(), proto_payload.size(), msg, mini_table, exts,
// decode_options, arena.ptr());
// if (length_delimited) {
// size_t num_bytes_read = 0;
// upb_DecodeStatus status = upb_DecodeLengthDelimited(
// proto_payload.data(), proto_payload.size(), msg, &num_bytes_read,
// mini_table, exts, decode_options, arena.ptr());
// ASSERT_TRUE(status != kUpb_DecodeStatus_Ok ||
// num_bytes_read <= proto_payload.size());
// } else {
// upb_Decode(proto_payload.data(), proto_payload.size(), msg, mini_table,
// exts, decode_options, arena.ptr());
// }
// char* ptr;
// size_t size;
// upb_Encode(msg, mini_table, encode_options, arena.ptr(), &ptr, &size);
// if (length_delimited) {
// upb_EncodeLengthDelimited(msg, mini_table, encode_options, arena.ptr(),
// &ptr, &size);
// } else {
// upb_Encode(msg, mini_table, encode_options, arena.ptr(), &ptr, &size);
// }
// }
// FUZZ_TEST(FuzzTest, DecodeEncodeArbitrarySchemaAndPayload);
//
Expand Down
19 changes: 19 additions & 0 deletions upb/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,25 @@ cc_test(
],
)

cc_test(
name = "length_delimited_test",
srcs = ["length_delimited_test.cc"],
copts = UPB_DEFAULT_CPPOPTS,
deps = [
":test_messages_proto2_upb_minitable",
":test_messages_proto2_upb_proto",
"//upb:base",
"//upb:mem",
"//upb:message",
"//upb:message_compare",
"//upb:mini_table",
"//upb:wire",
"//upb/mem:internal",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
],
)

cc_test(
name = "test_cpp",
srcs = ["test_cpp.cc"],
Expand Down
86 changes: 86 additions & 0 deletions upb/test/length_delimited_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@

#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "google/protobuf/test_messages_proto2.upb.h"
#include "google/protobuf/test_messages_proto2.upb_minitable.h"
#include "upb/base/string_view.h"
#include "upb/base/upcast.h"
#include "upb/mem/arena.h"
#include "upb/message/compare.h"
#include "upb/mini_table/message.h"
#include "upb/wire/decode.h"
#include "upb/wire/encode.h"

namespace {

static const upb_MiniTable* kTestMiniTable =
&protobuf_0test_0messages__proto2__TestAllTypesProto2_msg_init;

static void TestEncodeDecodeRoundTrip(
upb_Arena* arena,
std::vector<protobuf_test_messages_proto2_TestAllTypesProto2*> msgs) {
// Encode all of the messages and put their serializations contiguously.
std::string s;
for (auto msg : msgs) {
char* buf;
size_t size;
ASSERT_TRUE(upb_EncodeLengthDelimited(UPB_UPCAST(msg), kTestMiniTable, 0,
arena, &buf,
&size) == kUpb_EncodeStatus_Ok);
ASSERT_GT(size, 0); // Even empty messages are 1 byte in this encoding.
s.append(std::string(buf, size));
}

// Now decode all of the messages contained in the contiguous block.
std::vector<protobuf_test_messages_proto2_TestAllTypesProto2*> decoded;
while (!s.empty()) {
protobuf_test_messages_proto2_TestAllTypesProto2* msg =
protobuf_test_messages_proto2_TestAllTypesProto2_new(arena);
size_t num_bytes_read;
ASSERT_TRUE(upb_DecodeLengthDelimited(
s.data(), s.length(), UPB_UPCAST(msg), &num_bytes_read,
kTestMiniTable, nullptr, 0, arena) == kUpb_DecodeStatus_Ok);
ASSERT_GT(num_bytes_read, 0);
decoded.push_back(msg);
s = s.substr(num_bytes_read);
}

// Make sure that the values round tripped correctly.
ASSERT_EQ(msgs.size(), decoded.size());
for (size_t i = 0; i < msgs.size(); ++i) {
ASSERT_TRUE(upb_Message_IsEqual(UPB_UPCAST(msgs[i]), UPB_UPCAST(decoded[i]),
kTestMiniTable, 0));
}
}

TEST(LengthDelimitedTest, OneEmptyMessage) {
upb_Arena* arena = upb_Arena_New();
protobuf_test_messages_proto2_TestAllTypesProto2* msg =
protobuf_test_messages_proto2_TestAllTypesProto2_new(arena);
TestEncodeDecodeRoundTrip(arena, {msg});
upb_Arena_Free(arena);
}

TEST(LengthDelimitedTest, AFewMessages) {
upb_Arena* arena = upb_Arena_New();
protobuf_test_messages_proto2_TestAllTypesProto2* a =
protobuf_test_messages_proto2_TestAllTypesProto2_new(arena);
protobuf_test_messages_proto2_TestAllTypesProto2* b =
protobuf_test_messages_proto2_TestAllTypesProto2_new(arena);
protobuf_test_messages_proto2_TestAllTypesProto2* c =
protobuf_test_messages_proto2_TestAllTypesProto2_new(arena);

protobuf_test_messages_proto2_TestAllTypesProto2_set_optional_bool(a, true);
protobuf_test_messages_proto2_TestAllTypesProto2_set_optional_int32(b, 1);
protobuf_test_messages_proto2_TestAllTypesProto2_set_oneof_string(
c, upb_StringView_FromString("string"));

TestEncodeDecodeRoundTrip(arena, {a, b, c});
upb_Arena_Free(arena);
}

} // namespace
39 changes: 37 additions & 2 deletions upb/wire/decode.c
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@ static upb_DecodeStatus upb_Decoder_Decode(upb_Decoder* const decoder,
}

upb_DecodeStatus upb_Decode(const char* buf, size_t size, upb_Message* msg,
const upb_MiniTable* m,
const upb_MiniTable* mt,
const upb_ExtensionRegistry* extreg, int options,
upb_Arena* arena) {
UPB_ASSERT(!upb_Message_IsFrozen(msg));
Expand All @@ -1391,7 +1391,42 @@ upb_DecodeStatus upb_Decode(const char* buf, size_t size, upb_Message* msg,
// (particularly parent_or_count).
UPB_PRIVATE(_upb_Arena_SwapIn)(&decoder.arena, arena);

return upb_Decoder_Decode(&decoder, buf, msg, m, arena);
return upb_Decoder_Decode(&decoder, buf, msg, mt, arena);
}

upb_DecodeStatus upb_DecodeLengthDelimited(const char* buf, size_t size,
upb_Message* msg,
size_t* num_bytes_read,
const upb_MiniTable* mt,
const upb_ExtensionRegistry* extreg,
int options, upb_Arena* arena) {
// To avoid needing to make a Decoder just to decode the initial length,
// hand-decode the leading varint for the message length here.
uint64_t msg_len = 0;
for (size_t i = 0;; ++i) {
if (i >= size || i > 9) {
return kUpb_DecodeStatus_Malformed;
}
uint64_t b = *buf;
buf++;
msg_len += (b & 0x7f) << (i * 7);
if ((b & 0x80) == 0) {
*num_bytes_read = i + 1 + msg_len;
break;
}
}

// If the total number of bytes we would read (= the bytes from the varint
// plus however many bytes that varint says we should read) is larger then the
// input buffer then error as malformed.
if (*num_bytes_read > size) {
return kUpb_DecodeStatus_Malformed;
}
if (msg_len > INT32_MAX) {
return kUpb_DecodeStatus_Malformed;
}

return upb_Decode(buf, msg_len, msg, mt, extreg, options, arena);
}

#undef OP_FIXPCK_LG2
Expand Down
10 changes: 9 additions & 1 deletion upb/wire/decode.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,18 @@ typedef enum {
} upb_DecodeStatus;

UPB_API upb_DecodeStatus upb_Decode(const char* buf, size_t size,
upb_Message* msg, const upb_MiniTable* l,
upb_Message* msg, const upb_MiniTable* mt,
const upb_ExtensionRegistry* extreg,
int options, upb_Arena* arena);

// Same as upb_Decode but with a varint-encoded length prepended.
// On success 'num_bytes_read' will be set to the how many bytes were read,
// on failure the contents of num_bytes_read is undefined.
UPB_API upb_DecodeStatus upb_DecodeLengthDelimited(
const char* buf, size_t size, upb_Message* msg, size_t* num_bytes_read,
const upb_MiniTable* mt, const upb_ExtensionRegistry* extreg, int options,
upb_Arena* arena);

#ifdef __cplusplus
} /* extern "C" */
#endif
Expand Down
32 changes: 25 additions & 7 deletions upb/wire/encode.c
Original file line number Diff line number Diff line change
Expand Up @@ -607,14 +607,18 @@ static void encode_message(upb_encstate* e, const upb_Message* msg,
static upb_EncodeStatus upb_Encoder_Encode(upb_encstate* const encoder,
const upb_Message* const msg,
const upb_MiniTable* const l,
char** const buf,
size_t* const size) {
char** const buf, size_t* const size,
bool prepend_len) {
// Unfortunately we must continue to perform hackery here because there are
// code paths which blindly copy the returned pointer without bothering to
// check for errors until much later (b/235839510). So we still set *buf to
// NULL on error and we still set it to non-NULL on a successful empty result.
if (UPB_SETJMP(encoder->err) == 0) {
encode_message(encoder, msg, l, size);
size_t encoded_msg_size;
encode_message(encoder, msg, l, &encoded_msg_size);
if (prepend_len) {
encode_varint(encoder, encoded_msg_size);
}
*size = encoder->limit - encoder->ptr;
if (*size == 0) {
static char ch;
Expand All @@ -633,9 +637,10 @@ static upb_EncodeStatus upb_Encoder_Encode(upb_encstate* const encoder,
return encoder->status;
}

upb_EncodeStatus upb_Encode(const upb_Message* msg, const upb_MiniTable* l,
int options, upb_Arena* arena, char** buf,
size_t* size) {
static upb_EncodeStatus _upb_Encode(const upb_Message* msg,
const upb_MiniTable* l, int options,
upb_Arena* arena, char** buf, size_t* size,
bool prepend_len) {
upb_encstate e;
unsigned depth = (unsigned)options >> 16;

Expand All @@ -648,5 +653,18 @@ upb_EncodeStatus upb_Encode(const upb_Message* msg, const upb_MiniTable* l,
e.options = options;
_upb_mapsorter_init(&e.sorter);

return upb_Encoder_Encode(&e, msg, l, buf, size);
return upb_Encoder_Encode(&e, msg, l, buf, size, prepend_len);
}

upb_EncodeStatus upb_Encode(const upb_Message* msg, const upb_MiniTable* l,
int options, upb_Arena* arena, char** buf,
size_t* size) {
return _upb_Encode(msg, l, options, arena, buf, size, false);
}

upb_EncodeStatus upb_EncodeLengthDelimited(const upb_Message* msg,
const upb_MiniTable* l, int options,
upb_Arena* arena, char** buf,
size_t* size) {
return _upb_Encode(msg, l, options, arena, buf, size, true);
}
7 changes: 7 additions & 0 deletions upb/wire/encode.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ UPB_API upb_EncodeStatus upb_Encode(const upb_Message* msg,
const upb_MiniTable* l, int options,
upb_Arena* arena, char** buf, size_t* size);

// Encodes the message prepended by a varint of the serialized length.
UPB_API upb_EncodeStatus upb_EncodeLengthDelimited(const upb_Message* msg,
const upb_MiniTable* l,
int options,
upb_Arena* arena, char** buf,
size_t* size);

#ifdef __cplusplus
} /* extern "C" */
#endif
Expand Down

0 comments on commit c6f6a32

Please sign in to comment.