From 884b36c6af70079b08f922e90052cbf4de3d13cb Mon Sep 17 00:00:00 2001 From: Ronald Holshausen Date: Sat, 20 Jan 2024 04:33:05 +1100 Subject: [PATCH] fix: Repeated enum fields must be encoded as packed varints #27 --- .gitignore | 2 +- src/message_builder.rs | 67 +++++++++++++++++++++++++++++++++++--- src/message_decoder/mod.rs | 66 +++++++++++++++++++++++++++++++++---- src/utils.rs | 2 +- 4 files changed, 123 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index 37a4abb..d732df6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ # will have compiled files and executables -/target/ +target/ # These are backup files generated by rustfmt **/*.rs.bk diff --git a/src/message_builder.rs b/src/message_builder.rs index 087acfb..35238c9 100644 --- a/src/message_builder.rs +++ b/src/message_builder.rs @@ -13,7 +13,7 @@ use prost_types::{DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, Fi use prost_types::field_descriptor_proto::Type; use tracing::{trace, warn}; -use crate::utils::{last_name, should_be_packed_type}; +use crate::utils::{last_name, should_be_packed_type, display_bytes}; /// Enum to set what type of field the value is for #[derive(Clone, Copy, Debug, PartialEq)] @@ -133,9 +133,9 @@ impl MessageBuilder { } } - trace!("encode_message: {} bytes", buffer.len()); - - Ok(buffer.freeze()) + let bytes = buffer.freeze(); + trace!("encode_message: {} bytes {}", bytes.len(), display_bytes(&bytes)); + Ok(bytes) } fn encode_single_field(&self, mut buffer: &mut BytesMut, field_data: &FieldValueInner, value: Option) -> anyhow::Result<()> { @@ -329,6 +329,7 @@ impl MessageBuilder { buffer: &mut BytesMut, field_value: &FieldValueInner ) -> anyhow::Result<()> { + trace!(">> encode_packed_field({:?})", field_value); if let Some(tag) = field_value.descriptor.number { match field_value.proto_type { Type::Double => { @@ -366,6 +367,16 @@ impl MessageBuilder { prost::encoding::int32::encode_packed(tag as u32, &values, buffer); Ok(()) } + Type::Enum => { + let values = field_value.values.iter() + .map(|v| match &v.rtype { + RType::Enum(i, _) => *i, + _ => v.rtype.as_i32().unwrap_or_default() + }) + .collect::>(); + prost::encoding::int32::encode_packed(tag as u32, &values, buffer); + Ok(()) + } Type::Fixed64 => { let values = field_value.values.iter() .map(|v| v.rtype.as_u64().unwrap_or_default()) @@ -708,7 +719,7 @@ impl MessageFieldValue { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use base64::Engine; use base64::engine::general_purpose::STANDARD as BASE64; use bytes::{Bytes, BytesMut}; @@ -743,6 +754,7 @@ mod tests { use crate::message_builder::MessageFieldValueType::Repeated; use crate::message_decoder::{decode_message, ProtobufFieldData}; use crate::protobuf::tests::DESCRIPTOR_WITH_ENUM_BYTES; + use crate::utils::find_enum_by_name_in_message; const ENCODED_MESSAGE: &str = "CuIFChxnb29nbGUvcHJvdG9idWYvc3RydWN0LnByb3RvEg9nb29nbGUucHJv\ dG9idWYimAEKBlN0cnVjdBI7CgZmaWVsZHMYASADKAsyIy5nb29nbGUucHJvdG9idWYuU3RydWN0LkZpZWxkc0VudHJ5\ @@ -1732,6 +1744,51 @@ mod tests { expect!(result.to_vec()).to(be_equal_to(expected)); } + pub(crate) const REPEATED_ENUM_DESCRIPTORS: &str = "Cv4EChNyZXBlYXRlZF9lbnVtLnByb3RvEglwYWN0aXNzdWUieQoTQn\ + Jva2VuU2FtcGxlUmVxdWVzdBI3CgR0eXBlGAEgAygOMiMucGFjdGlzc3VlLkJyb2tlblNhbXBsZVJlcXVlc3QuVHlwZVIEd\ + HlwZSIpCgRUeXBlEgsKB1VOS05PV04QABIJCgVUWVBFMRABEgkKBVRZUEUyEAIiJgoUQnJva2VuU2FtcGxlUmVzcG9uc2US\ + DgoCb2sYASABKAhSAm9rInsKFFdvcmtpbmdTYW1wbGVSZXF1ZXN0EjgKBHR5cGUYASABKA4yJC5wYWN0aXNzdWUuV29ya2l\ + uZ1NhbXBsZVJlcXVlc3QuVHlwZVIEdHlwZSIpCgRUeXBlEgsKB1VOS05PV04QABIJCgVUWVBFMRABEgkKBVRZUEUyEAIiJw\ + oVV29ya2luZ1NhbXBsZVJlc3BvbnNlEg4KAm9rGAEgASgIUgJvazJlChNCcm9rZW5TYW1wbGVTZXJ2aWNlEk4KCUdldFNhb\ + XBsZRIeLnBhY3Rpc3N1ZS5Ccm9rZW5TYW1wbGVSZXF1ZXN0Gh8ucGFjdGlzc3VlLkJyb2tlblNhbXBsZVJlc3BvbnNlIgAya\ + AoUV29ya2luZ1NhbXBsZVNlcnZpY2USUAoJR2V0U2FtcGxlEh8ucGFjdGlzc3VlLldvcmtpbmdTYW1wbGVSZXF1ZXN0GiAuc\ + GFjdGlzc3VlLldvcmtpbmdTYW1wbGVSZXNwb25zZSIAQjpaOGdpdGh1Yi5jb20vc3Rhbi1pcy1oYXRlL3BhY3QtcHJvdG8ta\ + XNzdWUtZGVtby87cGFjdGlzc3VlYgZwcm90bzM="; + + #[test_log::test] + fn repeated_enum_fields_must_be_packed() { + let file_descriptor = get_file_descriptor("repeated_enum.proto", REPEATED_ENUM_DESCRIPTORS).unwrap(); + let request_descriptor = file_descriptor.message_type.iter() + .find(|desc| desc.name.clone().unwrap_or_default() == "BrokenSampleRequest") + .unwrap(); + let values_field_descriptor = request_descriptor.field.iter() + .find(|desc| desc.name.clone().unwrap_or_default() == "type") + .unwrap(); + let mut builder = MessageBuilder::new(request_descriptor, "BrokenSampleRequest", &file_descriptor); + let enum_proto = find_enum_by_name_in_message(&request_descriptor.enum_type, "Type").unwrap(); + let message_field_value = MessageFieldValue { + name: "type".to_string(), + raw_value: Some("Type2".to_string()), + rtype: RType::Enum(2, enum_proto.clone()) + }; + let message_field_value2 = MessageFieldValue { + name: "type".to_string(), + raw_value: Some("Type1".to_string()), + rtype: RType::Enum(1, enum_proto.clone()) + }; + builder.add_repeated_field_value(values_field_descriptor, "type", message_field_value); + builder.add_repeated_field_value(values_field_descriptor, "type", message_field_value2); + + let expected = vec![ + 10, // Field 1, VARINT + 2, // 2 bytes + 2, // Enum 2 (Type2) + 1 // Enum 1 (Type1) + ]; + let result = builder.encode_message().unwrap(); + expect!(result.to_vec()).to(be_equal_to(expected)); + } + #[test_log::test] fn test_field_with_global_enum() { let bytes: &[u8] = &DESCRIPTOR_WITH_ENUM_BYTES; diff --git a/src/message_decoder/mod.rs b/src/message_decoder/mod.rs index 1d753e4..d1cf835 100644 --- a/src/message_decoder/mod.rs +++ b/src/message_decoder/mod.rs @@ -281,11 +281,7 @@ pub fn decode_message( Type::Bool => vec![ (ProtobufFieldData::Boolean(varint > 0), wire_type) ], Type::Uint32 => vec![ (ProtobufFieldData::UInteger32(varint as u32), wire_type) ], Type::Enum => { - let enum_type_name = field_descriptor.type_name.clone().unwrap_or_default(); - let enum_proto = find_enum_by_name_in_message(&descriptor.enum_type, enum_type_name.as_str()) - .or_else(|| find_enum_by_name(descriptors, enum_type_name.as_str())) - .ok_or_else(|| anyhow!("Did not find the enum {} for the field {} in the Protobuf descriptor", enum_type_name, field_num))?; - vec![ (ProtobufFieldData::Enum(varint as i32, enum_proto.clone()), wire_type) ] + vec![ (decode_enum(descriptor, descriptors, &field_descriptor, varint)?, wire_type) ] }, Type::Sint32 => { let value = varint as u32; @@ -333,7 +329,7 @@ pub fn decode_message( Type::Bytes => vec![ (ProtobufFieldData::Bytes(data_buffer.to_vec()), wire_type) ], _ => if should_be_packed_type(t) && is_repeated_field(&field_descriptor) { debug!("Reading length delimited field as a packed repeated field"); - decode_packed_field(field_descriptor, &mut data_buffer)? + decode_packed_field(field_descriptor, descriptor, descriptors, &mut data_buffer)? } else { error!("Was expecting {:?} but received an unknown length-delimited type", t); let mut buf = BytesMut::with_capacity((data_length + 8) as usize); @@ -397,7 +393,25 @@ pub fn decode_message( Ok(fields.iter().sorted_by(|a, b| Ord::cmp(&a.field_num, &b.field_num)).cloned().collect()) } -fn decode_packed_field(field: FieldDescriptorProto, data: &mut Bytes) -> anyhow::Result> { +fn decode_enum( + descriptor: &DescriptorProto, + descriptors: &FileDescriptorSet, + field_descriptor: &FieldDescriptorProto, + varint: u64 +) -> anyhow::Result { + let enum_type_name = field_descriptor.type_name.clone().unwrap_or_default(); + let enum_proto = find_enum_by_name_in_message(&descriptor.enum_type, enum_type_name.as_str()) + .or_else(|| find_enum_by_name(descriptors, enum_type_name.as_str())) + .ok_or_else(|| anyhow!("Did not find the enum {} for the field in the Protobuf descriptor", enum_type_name))?; + Ok(ProtobufFieldData::Enum(varint as i32, enum_proto.clone())) +} + +fn decode_packed_field( + field: FieldDescriptorProto, + descriptor: &DescriptorProto, + descriptors: &FileDescriptorSet, + data: &mut Bytes +) -> anyhow::Result> { let mut values = vec![]; let t: Type = field.r#type(); match t { @@ -429,6 +443,13 @@ fn decode_packed_field(field: FieldDescriptorProto, data: &mut Bytes) -> anyhow: values.push((ProtobufFieldData::Integer32(varint as i32), WireType::Varint)); } } + Type::Enum => { + while data.remaining() > 0 { + let varint = decode_varint(data)?; + let enum_value = decode_enum(descriptor, descriptors, &field, varint)?; + values.push((enum_value, WireType::Varint)); + } + } Type::Fixed64 => { while data.remaining() >= mem::size_of::() { values.push((ProtobufFieldData::UInteger64(data.get_u64_le()), WireType::SixtyFourBit)); @@ -492,6 +513,8 @@ fn find_field_descriptor(field_num: i32, descriptor: &DescriptorProto) -> anyhow #[cfg(test)] mod tests { + use base64::Engine; + use base64::engine::general_purpose::STANDARD as BASE64; use bytes::{BufMut, Bytes, BytesMut}; use expectest::prelude::*; use pact_plugin_driver::proto::InitPluginRequest; @@ -514,6 +537,7 @@ mod tests { }; use crate::message_decoder::{decode_message, ProtobufFieldData}; use crate::protobuf::tests::DESCRIPTOR_WITH_ENUM_BYTES; + use crate::message_builder::tests::REPEATED_ENUM_DESCRIPTORS; const FIELD_1_MESSAGE: [u8; 2] = [8, 1]; const FIELD_2_MESSAGE: [u8; 2] = [16, 55]; @@ -1207,4 +1231,32 @@ mod tests { expect!(field_result.wire_type).to(be_equal_to(WireType::Varint)); expect!(&field_result.data).to(be_equal_to(&ProtobufFieldData::Enum(1, enum_proto.clone()))); } + + #[test_log::test] + fn decode_message_with_repeated_enum_field() { + let bytes = BASE64.decode(REPEATED_ENUM_DESCRIPTORS).unwrap(); + let buffer = Bytes::from(bytes); + let fds: FileDescriptorSet = FileDescriptorSet::decode(buffer).unwrap(); + let main_descriptor = fds.file.iter() + .find(|fd| fd.name.clone().unwrap_or_default() == "repeated_enum.proto") + .unwrap(); + let message_descriptor = main_descriptor.message_type.iter() + .find(|md| md.name.clone().unwrap_or_default() == "BrokenSampleRequest").unwrap(); + let enum_proto = message_descriptor.enum_type.first().unwrap(); + + let message_bytes: &[u8] = &[10, 3, 2, 0, 1]; + let mut buffer = Bytes::from(message_bytes); + let result = decode_message(&mut buffer, &message_descriptor, &fds).unwrap(); + expect!(result.len()).to(be_equal_to(3)); + + expect!(result[0].field_num).to(be_equal_to(1)); + expect!(result[0].wire_type).to(be_equal_to(WireType::Varint)); + expect!(&result[0].data).to(be_equal_to(&ProtobufFieldData::Enum(2, enum_proto.clone()))); + expect!(result[1].field_num).to(be_equal_to(1)); + expect!(result[1].wire_type).to(be_equal_to(WireType::Varint)); + expect!(&result[1].data).to(be_equal_to(&ProtobufFieldData::Enum(0, enum_proto.clone()))); + expect!(result[2].field_num).to(be_equal_to(1)); + expect!(result[2].wire_type).to(be_equal_to(WireType::Varint)); + expect!(&result[2].data).to(be_equal_to(&ProtobufFieldData::Enum(1, enum_proto.clone()))); + } } diff --git a/src/utils.rs b/src/utils.rs index 3950b65..e1cc4d6 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -461,7 +461,7 @@ pub(crate) fn find_service_descriptor<'a>( pub fn should_be_packed_type(field_type: Type) -> bool { matches!(field_type, Type::Double | Type::Float | Type::Int64 | Type::Uint64 | Type::Int32 | Type::Fixed64 | Type::Fixed32 | Type::Uint32 | Type::Sfixed32 | Type::Sfixed64 | Type::Sint32 | - Type::Sint64) + Type::Sint64 | Type::Enum) } /// Tries to convert a Protobuf Value to a Map. Returns an error if the incoming value is not a