4949# include <google/protobuf/stubs/casts.h>
5050# include <google/protobuf/stubs/strutil.h>
5151# include <google/protobuf/stubs/stl_util.h>
52+ # include <google/protobuf/dynamic_message.h>
53+
5254
5355// clang-format off
5456# include <google/protobuf/port_def.inc>
@@ -580,28 +582,54 @@ TEST(WireFormatTest, ParseMessageSet) {
580582 EXPECT_EQ (message_set.DebugString (), dynamic_message_set.DebugString ());
581583}
582584
583- TEST (WireFormatTest, ParseMessageSetWithReverseTagOrder) {
585+ namespace {
586+ std ::string BuildMessageSetItemStart () {
584587 std ::string data;
585588 {
586- UNITTEST ::TestMessageSetExtension1 message;
587- message.set_i (123 );
588- // Build a MessageSet manually with its message content put before its
589- // type_id.
590589 io ::StringOutputStream output_stream (& data);
591590 io ::CodedOutputStream coded_output (& output_stream);
592591 coded_output.WriteTag (WireFormatLite ::kMessageSetItemStartTag);
592+ }
593+ return data;
594+ }
595+ std ::string BuildMessageSetItemEnd () {
596+ std ::string data;
597+ {
598+ io ::StringOutputStream output_stream (& data);
599+ io ::CodedOutputStream coded_output (& output_stream);
600+ coded_output.WriteTag (WireFormatLite ::kMessageSetItemEndTag);
601+ }
602+ return data;
603+ }
604+ std ::string BuildMessageSetTestExtension1 (int value = 123 ) {
605+ std ::string data;
606+ {
607+ UNITTEST ::TestMessageSetExtension1 message;
608+ message.set_i (value);
609+ io ::StringOutputStream output_stream (& data);
610+ io ::CodedOutputStream coded_output (& output_stream);
593611 // Write the message content first.
594612 WireFormatLite ::WriteTag (WireFormatLite ::kMessageSetMessageNumber,
595613 WireFormatLite ::WIRETYPE_LENGTH_DELIMITED,
596614 & coded_output);
597615 coded_output.WriteVarint32 (message.ByteSizeLong ());
598616 message.SerializeWithCachedSizes (& coded_output);
599- // Write the type id.
600- uint32 type_id = message.GetDescriptor ()- > extension (0 )- > number ();
617+ }
618+ return data;
619+ }
620+ std ::string BuildMessageSetItemTypeId (int extension_number) {
621+ std ::string data;
622+ {
623+ io ::StringOutputStream output_stream (& data);
624+ io ::CodedOutputStream coded_output (& output_stream);
601625 WireFormatLite ::WriteUInt32 (WireFormatLite ::kMessageSetTypeIdNumber,
602- type_id, & coded_output);
603- coded_output.WriteTag (WireFormatLite ::kMessageSetItemEndTag);
626+ extension_number, & coded_output);
604627 }
628+ return data;
629+ }
630+ void ValidateTestMessageSet (const std ::string& test_case,
631+ const std ::string& data) {
632+ SCOPED_TRACE (test_case);
605633 {
606634 PROTO2_WIREFORMAT_UNITTEST ::TestMessageSet message_set;
607635 ASSERT_TRUE (message_set.ParseFromString (data));
@@ -611,6 +639,11 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
611639 .GetExtension (
612640 UNITTEST ::TestMessageSetExtension1 ::message_set_extension)
613641 .i ());
642+
643+ // Make sure it does not contain anything else.
644+ message_set.ClearExtension (
645+ UNITTEST ::TestMessageSetExtension1 ::message_set_extension);
646+ EXPECT_EQ (message_set.SerializeAsString (), " " );
614647 }
615648 {
616649 // Test parse the message via Reflection.
@@ -626,6 +659,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
626659 UNITTEST ::TestMessageSetExtension1 ::message_set_extension)
627660 .i ());
628661 }
662+ {
663+ // Test parse the message via DynamicMessage.
664+ DynamicMessageFactory factory;
665+ std ::unique_ptr< Message> msg (
666+ factory
667+ .GetPrototype (
668+ PROTO2_WIREFORMAT_UNITTEST ::TestMessageSet ::descriptor ())
669+ - > New ());
670+ msg- > ParseFromString (data);
671+ auto* reflection = msg- > GetReflection ();
672+ std ::vector< const FieldDescriptor* > fields;
673+ reflection- > ListFields (* msg, & fields);
674+ ASSERT_EQ (fields.size (), 1 );
675+ const auto& sub = reflection- > GetMessage (* msg, fields[0 ]);
676+ reflection = sub.GetReflection ();
677+ EXPECT_EQ (123 , reflection- > GetInt32 (
678+ sub, sub.GetDescriptor ()- > FindFieldByName (" i" )));
679+ }
680+ }
681+ } // namespace
682+
683+ TEST (WireFormatTest, ParseMessageSetWithAnyTagOrder) {
684+ std ::string start = BuildMessageSetItemStart ();
685+ std ::string end = BuildMessageSetItemEnd ();
686+ std ::string id = BuildMessageSetItemTypeId (
687+ UNITTEST ::TestMessageSetExtension1 ::descriptor ()- > extension (0 )- > number ());
688+ std ::string message = BuildMessageSetTestExtension1 ();
689+
690+ ValidateTestMessageSet (" id + message" , start + id + message + end);
691+ ValidateTestMessageSet (" message + id" , start + message + id + end);
692+ }
693+
694+ TEST (WireFormatTest, ParseMessageSetWithDuplicateTags) {
695+ std ::string start = BuildMessageSetItemStart ();
696+ std ::string end = BuildMessageSetItemEnd ();
697+ std ::string id = BuildMessageSetItemTypeId (
698+ UNITTEST ::TestMessageSetExtension1 ::descriptor ()- > extension (0 )- > number ());
699+ std ::string other_id = BuildMessageSetItemTypeId (123456 );
700+ std ::string message = BuildMessageSetTestExtension1 ();
701+ std ::string other_message = BuildMessageSetTestExtension1 (321 );
702+
703+ // Double id
704+ ValidateTestMessageSet (" id + other_id + message" ,
705+ start + id + other_id + message + end);
706+ ValidateTestMessageSet (" id + message + other_id" ,
707+ start + id + message + other_id + end);
708+ ValidateTestMessageSet (" message + id + other_id" ,
709+ start + message + id + other_id + end);
710+ // Double message
711+ ValidateTestMessageSet (" id + message + other_message" ,
712+ start + id + message + other_message + end);
713+ ValidateTestMessageSet (" message + id + other_message" ,
714+ start + message + id + other_message + end);
715+ ValidateTestMessageSet (" message + other_message + id" ,
716+ start + message + other_message + id + end);
629717}
630718
631719void SerializeReverseOrder (
0 commit comments