Skip to content
Prev Previous commit
Next Next commit
[NRBF] More bug fixes (#107682)
- Don't use `Debug.Fail` not followed by an exception (it may cause problems for apps deployed in Debug)
- avoid Int32 overflow
- throw for unexpected enum values just in case parsing has not rejected them
- validate the number of chars read by BinaryReader.ReadChars
- pass serialization record id to ex message
- return false rather than throw EndOfStreamException when provided Stream has not enough data
- don't restore the position in finally 
- limit max SZ and MD array length to Array.MaxLength, stop using LinkedList<T> as List<T> will be able to hold all elements now
- remove internal enum values that were always illegal, but needed to be handled everywhere
- Fix DebuggerDisplay
  • Loading branch information
adamsitnik committed Sep 13, 2024
commit 143384631c4817cedaccab47d25550771106e44e
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace System.Formats.Nrbf;
/// <remarks>
/// ArrayInfo structures are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/8fac763f-e46d-43a1-b360-80eb83d2c5fb">[MS-NRBF] 2.4.2.1</see>.
/// </remarks>
[DebuggerDisplay("Length={Length}, {ArrayType}, rank={Rank}")]
[DebuggerDisplay("{ArrayType}, rank={Rank}")]
internal readonly struct ArrayInfo
{
internal const int MaxArrayLength = 2147483591; // Array.MaxLength
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,9 @@ internal ArraySinglePrimitiveRecord(ArrayInfo arrayInfo, IReadOnlyList<T> values
public override T[] GetArray(bool allowNulls = true)
=> (T[])(_arrayNullsNotAllowed ??= (Values is T[] array ? array : Values.ToArray()));

internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
{
Debug.Fail("GetAllowedRecordType should never be called on ArraySinglePrimitiveRecord");
throw new InvalidOperationException();
}
internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException();

private protected override void AddValue(object value)
{
Debug.Fail("AddValue should never be called on ArraySinglePrimitiveRecord");
throw new InvalidOperationException();
}
private protected override void AddValue(object value) => throw new InvalidOperationException();

internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int count)
{
Expand Down Expand Up @@ -94,7 +86,7 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
#if NET
reader.BaseStream.ReadExactly(resultAsBytes);
#else
byte[] bytes = ArrayPool<byte>.Shared.Rent(Math.Min(count * Unsafe.SizeOf<T>(), 256_000));
byte[] bytes = ArrayPool<byte>.Shared.Rent((int)Math.Min(requiredBytes, 256_000));

while (!resultAsBytes.IsEmpty)
{
Expand Down Expand Up @@ -159,31 +151,10 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
{
List<decimal> values = new();
#if NET
Span<byte> buffer = stackalloc byte[256];
for (int i = 0; i < count; i++)
{
int stringLength = reader.Read7BitEncodedInt();
if (!(stringLength > 0 && stringLength <= buffer.Length))
{
ThrowHelper.ThrowInvalidValue(stringLength);
}

reader.BaseStream.ReadExactly(buffer.Slice(0, stringLength));

if (!decimal.TryParse(buffer.Slice(0, stringLength), NumberStyles.Number, CultureInfo.InvariantCulture, out decimal value))
{
ThrowHelper.ThrowInvalidFormat();
}

values.Add(value);
}
#else
for (int i = 0; i < count; i++)
{
values.Add(reader.ParseDecimal());
}
#endif
return values;
}

Expand Down Expand Up @@ -244,12 +215,14 @@ private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int coun
{
values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64()));
}
else
else if (typeof(T) == typeof(TimeSpan))
{
Debug.Assert(typeof(T) == typeof(TimeSpan));

values.Add((T)(object)new TimeSpan(reader.ReadInt64()));
}
else
{
throw new InvalidOperationException();
}
}

return values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ internal sealed class ArraySingleStringRecord : SZArrayRecord<string?>
public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleString;

/// <inheritdoc />
public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(PrimitiveType.String);
public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.StringPrimitiveType);

private List<SerializationRecord> Records { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
totalElementCount *= lengths[i];

if (totalElementCount > uint.MaxValue)
if (totalElementCount > ArrayInfo.MaxArrayLength)
{
ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,7 @@ private BinaryLibraryRecord(SerializationRecordId libraryId, AssemblyNameInfo li

public override SerializationRecordType RecordType => SerializationRecordType.BinaryLibrary;

public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on BinaryLibraryRecord");
return TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan());

internal string? RawLibraryName { get; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
case BinaryType.Class:
info[i] = (type, ClassTypeInfo.Decode(reader, options, recordMap));
break;
default:
// Other types have no additional data.
Debug.Assert(type is BinaryType.String or BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.Object);
case BinaryType.String:
case BinaryType.StringArray:
case BinaryType.Object:
case BinaryType.ObjectArray:
// These types have no additional data.
break;
default:
throw new InvalidOperationException();
}
}

Expand Down Expand Up @@ -97,7 +101,8 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
BinaryType.PrimitiveArray => (PrimitiveArray, default),
BinaryType.Class => (NonSystemClass, default),
BinaryType.SystemClass => (SystemClass, default),
_ => (ObjectArray, default)
BinaryType.ObjectArray => (ObjectArray, default),
_ => throw new InvalidOperationException()
};
}

Expand Down Expand Up @@ -144,15 +149,15 @@ internal TypeName GetArrayTypeName(ArrayInfo arrayInfo)

TypeName elementTypeName = binaryType switch
{
BinaryType.String => TypeNameHelpers.GetPrimitiveTypeName(PrimitiveType.String),
BinaryType.StringArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(PrimitiveType.String),
BinaryType.String => TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.StringPrimitiveType),
BinaryType.StringArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.StringPrimitiveType),
BinaryType.Primitive => TypeNameHelpers.GetPrimitiveTypeName((PrimitiveType)additionalInfo!),
BinaryType.PrimitiveArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName((PrimitiveType)additionalInfo!),
BinaryType.Object => TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.ObjectPrimitiveType),
BinaryType.ObjectArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.ObjectPrimitiveType),
BinaryType.SystemClass => (TypeName)additionalInfo!,
BinaryType.Class => ((ClassTypeInfo)additionalInfo!).TypeName,
_ => throw new ArgumentOutOfRangeException(paramName: nameof(binaryType), actualValue: binaryType, message: null)
_ => throw new InvalidOperationException()
};

// In general, arrayRank == 1 may have two different meanings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,5 @@ private MessageEndRecord()

public override SerializationRecordId Id => SerializationRecordId.NoId;

public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on MessageEndRecord");
return TypeName.Parse(nameof(MessageEndRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(MessageEndRecord).AsSpan());
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,22 @@ public static bool StartsWithPayloadHeader(Stream stream)
return false;
}

try
byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
int offset = 0;
while (offset < buffer.Length)
{
#if NET
Span<byte> buffer = stackalloc byte[SerializedStreamHeaderRecord.Size];
stream.ReadExactly(buffer);
#else
byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
int offset = 0;
while (offset < buffer.Length)
int read = stream.Read(buffer, offset, buffer.Length - offset);
if (read == 0)
{
int read = stream.Read(buffer, offset, buffer.Length - offset);
if (read == 0)
throw new EndOfStreamException();
offset += read;
stream.Position = beginning;
return false;
}
#endif
return StartsWithPayloadHeader(buffer);
}
finally
{
stream.Position = beginning;
offset += read;
}

bool result = StartsWithPayloadHeader(buffer);
stream.Position = beginning;
return result;
}

/// <summary>
Expand Down Expand Up @@ -241,7 +235,8 @@ private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap rec
SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader),
SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader),
SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader),
_ => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
SerializationRecordType.SystemClassWithMembersAndTypes => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
_ => throw new InvalidOperationException()
};

recordMap.Add(record);
Expand Down Expand Up @@ -269,8 +264,8 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader
PrimitiveType.Double => new MemberPrimitiveTypedRecord<double>(reader.ReadDouble()),
PrimitiveType.Decimal => new MemberPrimitiveTypedRecord<decimal>(reader.ParseDecimal()),
PrimitiveType.DateTime => new MemberPrimitiveTypedRecord<DateTime>(Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64())),
// String is handled with a record, never on it's own
_ => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
PrimitiveType.TimeSpan => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
_ => throw new InvalidOperationException()
};
}

Expand All @@ -295,7 +290,8 @@ private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader
PrimitiveType.Double => Decode<double>(info, reader),
PrimitiveType.Decimal => Decode<decimal>(info, reader),
PrimitiveType.DateTime => Decode<DateTime>(info, reader),
_ => Decode<TimeSpan>(info, reader),
PrimitiveType.TimeSpan => Decode<TimeSpan>(info, reader),
_ => throw new InvalidOperationException()
};

static SerializationRecord Decode<T>(ArrayInfo info, BinaryReader reader) where T : unmanaged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,5 @@ internal abstract class NullsRecord : SerializationRecord

public override SerializationRecordId Id => SerializationRecordId.NoId;

public override TypeName TypeName
{
get
{
Debug.Fail($"TypeName should never be called on {GetType().Name}");
return TypeName.Parse(GetType().Name.AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(GetType().Name.AsSpan());
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ namespace System.Formats.Nrbf;
/// </remarks>
internal enum PrimitiveType : byte
{
/// <summary>
/// Used internally to express no value
/// </summary>
None = 0,
Boolean = 1,
Byte = 2,
Char = 3,
Expand All @@ -30,7 +26,19 @@ internal enum PrimitiveType : byte
DateTime = 13,
UInt16 = 14,
UInt32 = 15,
UInt64 = 16,
Null = 17,
String = 18
UInt64 = 16
// This internal enum no longer contains Null and String as they were always illegal:
// - In case of BinaryArray (NRBF 2.4.3.1):
// "If the BinaryTypeEnum value is Primitive, the PrimitiveTypeEnumeration
// value in AdditionalTypeInfo MUST NOT be Null (17) or String (18)."
// - In case of MemberPrimitiveTyped (NRBF 2.5.1):
// "PrimitiveTypeEnum (1 byte): A PrimitiveTypeEnumeration
// value that specifies the Primitive Type of data that is being transmitted.
// This field MUST NOT contain a value of 17 (Null) or 18 (String)."
// - In case of ArraySinglePrimitive (NRBF 2.4.3.3):
// "A PrimitiveTypeEnumeration value that identifies the Primitive Type
// of the items of the Array. The value MUST NOT be 17 (Null) or 18 (String)."
// - In case of MemberTypeInfo (NRBF 2.3.1.2):
// "When the BinaryTypeEnum value is Primitive, the PrimitiveTypeEnumeration
// value in AdditionalInfo MUST NOT be Null (17) or String (18)."
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ internal void Add(SerializationRecord record)
return;
}
#endif
throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id));
throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id._id));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace System.Formats.Nrbf;
internal sealed class RectangularArrayRecord : ArrayRecord
{
private readonly int[] _lengths;
private readonly ICollection<object> _values;
private readonly List<object> _values;
private TypeName? _typeName;

private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
Expand All @@ -24,18 +24,8 @@ private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
MemberTypeInfo = memberTypeInfo;
_lengths = lengths;

// A List<T> can hold as many objects as an array, so for multi-dimensional arrays
// with more elements than Array.MaxLength we use LinkedList.
// Testing that many elements takes a LOT of time, so to ensure that both code paths are tested,
// we always use LinkedList code path for Debug builds.
#if DEBUG
_values = new LinkedList<object>();
#else
_values = arrayInfo.TotalElementsCount <= ArrayInfo.MaxArrayLength
? new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()))
: new LinkedList<object>();
#endif

// ArrayInfo.GetSZArrayLength ensures to return a value <= Array.MaxLength
_values = new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()));
}

public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;
Expand Down Expand Up @@ -108,6 +98,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
else if (ElementType == typeof(TimeSpan)) CopyTo<TimeSpan>(_values, result);
else if (ElementType == typeof(DateTime)) CopyTo<DateTime>(_values, result);
else if (ElementType == typeof(decimal)) CopyTo<decimal>(_values, result);
else throw new InvalidOperationException();
}
else
{
Expand All @@ -116,7 +107,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)

return result;

static void CopyTo<T>(ICollection<object> list, Array array)
static void CopyTo<T>(List<object> list, Array array)
{
ref byte arrayDataRef = ref MemoryMarshal.GetArrayDataReference(array);
ref T firstElementRef = ref Unsafe.As<byte, T>(ref arrayDataRef);
Expand Down Expand Up @@ -176,7 +167,10 @@ internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arr
PrimitiveType.Int64 => sizeof(long),
PrimitiveType.UInt64 => sizeof(ulong),
PrimitiveType.Double => sizeof(double),
_ => -1
PrimitiveType.TimeSpan => sizeof(ulong),
PrimitiveType.DateTime => sizeof(ulong),
PrimitiveType.Decimal => -1, // represented as variable-length string
_ => throw new InvalidOperationException()
};

if (sizeOfSingleValue > 0)
Expand Down Expand Up @@ -215,7 +209,8 @@ private static Type MapPrimitive(PrimitiveType primitiveType)
PrimitiveType.DateTime => typeof(DateTime),
PrimitiveType.UInt16 => typeof(ushort),
PrimitiveType.UInt32 => typeof(uint),
_ => typeof(ulong)
PrimitiveType.UInt64 => typeof(ulong),
_ => throw new InvalidOperationException()
};

private static Type MapPrimitiveArray(PrimitiveType primitiveType)
Expand All @@ -235,7 +230,8 @@ private static Type MapPrimitiveArray(PrimitiveType primitiveType)
PrimitiveType.DateTime => typeof(DateTime[]),
PrimitiveType.UInt16 => typeof(ushort[]),
PrimitiveType.UInt32 => typeof(uint[]),
_ => typeof(ulong[]),
PrimitiveType.UInt64 => typeof(ulong[]),
_ => throw new InvalidOperationException()
};

private static object? GetActualValue(object value)
Expand Down
Loading