Skip to content
Prev Previous commit
Next Next commit
[NRBF] Comments and bug fixes from internal code review (#107735)
* copy comments and asserts from Levi's internal code review

* apply Levi's suggestion: don't store Array.MaxLength as a const, as it may change in the future

* add missing and fix some of the existing comments

* first bug fix: SerializationRecord.TypeNameMatches should throw ArgumentNullException for null Type argument

* second bug fix: SerializationRecord.TypeNameMatches should know the difference between SZArray and single-dimension, non-zero offset arrays (example: int[] and int[*])

* third bug fix: don't cast bytes to booleans

* fourth bug fix: don't cast bytes to DateTimes

* add one test case that I forgot in the previous PR
# Conflicts:
#	src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecord.cs
  • Loading branch information
adamsitnik committed Sep 16, 2024
commit 8dd423fd83293cf51547660df26c8404f78a3d76
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

namespace System.Formats.Nrbf;

// See [MS-NRBF] Sec. 2.7 for more information.
// https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/ca3ad2bc-777b-413a-a72a-9ba6ced76bc3

[Flags]
internal enum AllowedRecordTypes : uint
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ namespace System.Formats.Nrbf;
[DebuggerDisplay("{ArrayType}, rank={Rank}")]
internal readonly struct ArrayInfo
{
internal const int MaxArrayLength = 2147483591; // Array.MaxLength
#if NET8_0_OR_GREATER
internal static int MaxArrayLength => Array.MaxLength; // dynamic lookup in case the value changes in a future runtime
#else
internal const int MaxArrayLength = 2147483591; // hardcode legacy Array.MaxLength for downlevel runtimes
#endif

internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArrayType arrayType = BinaryArrayType.Single, int rank = 1)
{
Expand Down Expand Up @@ -47,7 +51,7 @@ internal static int ParseValidArrayLength(BinaryReader reader)
{
int length = reader.ReadInt32();

if (length is < 0 or > MaxArrayLength)
if (length < 0 || length > MaxArrayLength)
{
ThrowHelper.ThrowInvalidValue(length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand Down Expand Up @@ -54,6 +55,7 @@ public override TypeName TypeName
}

int nullCount = ((NullsRecord)actual).NullCount;
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
do
{
result[resultIndex++] = null;
Expand All @@ -63,6 +65,8 @@ public override TypeName TypeName
}
}

Debug.Assert(resultIndex == result.Length, "We should have traversed the entirety of the newly created array.");

return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand Down Expand Up @@ -33,13 +34,15 @@ public override TypeName TypeName
{
object?[] values = new object?[Length];

for (int recordIndex = 0, valueIndex = 0; recordIndex < Records.Count; recordIndex++)
int valueIndex = 0;
for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++)
{
SerializationRecord record = Records[recordIndex];

int nullCount = record is NullsRecord nullsRecord ? nullsRecord.NullCount : 0;
if (nullCount == 0)
{
// "new object[] { <SELF> }" is special cased because it allows for storing reference to itself.
values[valueIndex++] = record is MemberReferenceRecord referenceRecord && referenceRecord.Reference.Equals(Id)
? values // a reference to self, and a way to get StackOverflow exception ;)
: record.GetValue();
Expand All @@ -59,6 +62,8 @@ public override TypeName TypeName
while (nullCount > 0);
}

Debug.Assert(valueIndex == values.Length, "We should have traversed the entirety of the newly created array.");

return values;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,32 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
return (List<T>)(object)DecodeDecimals(reader, count);
}

// char[] has a unique representation in NRBF streams. Typical strings are transcoded
// to UTF-8 and prefixed with the number of bytes in the UTF-8 representation. char[]
// is also serialized as UTF-8, but it is instead prefixed with the number of chars
// in the UTF-16 representation, not the number of bytes in the UTF-8 representation.
// This number doesn't directly precede the UTF-8 contents in the NRBF stream; it's
// instead contained within the ArrayInfo structure (passed to this method as the
// 'count' argument).
//
// The practical consequence of this is that we don't actually know how many UTF-8
// bytes we need to consume in order to ensure we've read 'count' chars. We know that
// an n-length UTF-16 string turns into somewhere between [n .. 3n] UTF-8 bytes.
// The best we can do is that when reading an n-element char[], we'll ensure that
// there are at least n bytes remaining in the input stream. We still need to
// account for the fact that, even with this check, we might hit EOF before fully
// populating the char[]. But from a safety perspective, it does appropriately limit our
// allocations to be proportional to the amount of data present in the input stream,
// which is a sufficient defense against DoS.

long requiredBytes = count;
if (typeof(T) != typeof(char)) // the input is UTF8
if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
{
// We can't assume DateTime as represented by the runtime is 8 bytes.
// The only assumption we can make is that it's 8 bytes on the wire.
requiredBytes *= 8;
}
else if (typeof(T) != typeof(char))
{
requiredBytes *= Unsafe.SizeOf<T>();
}
Expand All @@ -79,6 +103,10 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
{
return (T[])(object)reader.ParseChars(count);
}
else if (typeof(T) == typeof(TimeSpan) || typeof(T) == typeof(DateTime))
{
return DecodeTime(reader, count);
}

// It's safe to pre-allocate, as we have ensured there are enough bytes in the stream.
T[] result = new T[count];
Expand Down Expand Up @@ -130,8 +158,7 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
}
#endif
}
else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double)
|| typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double))
{
Span<long> span = MemoryMarshal.Cast<T, long>(result);
#if NET
Expand All @@ -145,6 +172,21 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
}
}

if (typeof(T) == typeof(bool))
{
// See DontCastBytesToBooleans test to see what could go wrong.
bool[] booleans = (bool[])(object)result;
resultAsBytes = MemoryMarshal.AsBytes<T>(result);
for (int i = 0; i < booleans.Length; i++)
{
// We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this.
if (resultAsBytes[i] != 0) // it can be any byte different than 0
{
booleans[i] = true; // set it to 1 in explicit way
}
}
}

return result;
}

Expand All @@ -158,8 +200,34 @@ private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
return values;
}

/// <summary>
/// Reads <paramref name="count"/> consecutive 64-bit values from <paramref name="reader"/>
/// and converts each of them to <typeparamref name="T"/>, which must be either
/// <see cref="DateTime"/> or <see cref="TimeSpan"/>.
/// </summary>
/// <param name="reader">The reader positioned at the first element.</param>
/// <param name="count">The number of elements to read.</param>
/// <returns>A newly allocated array holding the decoded elements.</returns>
/// <exception cref="InvalidOperationException">
/// An element had to be decoded, but <typeparamref name="T"/> is neither
/// <see cref="DateTime"/> nor <see cref="TimeSpan"/>.
/// </exception>
private static T[] DecodeTime(BinaryReader reader, int count)
{
    T[] decoded = new T[count];

    int index = 0;
    while (index < decoded.Length)
    {
        // The type test stays inside the loop on purpose: an unsupported T
        // combined with count == 0 must still yield an empty array, not throw.
        if (typeof(T) == typeof(TimeSpan))
        {
            // A TimeSpan is read as a signed 64-bit value passed to the TimeSpan(long) constructor.
            decoded[index] = (T)(object)new TimeSpan(reader.ReadInt64());
        }
        else if (typeof(T) == typeof(DateTime))
        {
            // A DateTime is read as a raw unsigned 64-bit value and converted
            // by the shared helper instead of being reinterpreted in memory.
            decoded[index] = (T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64());
        }
        else
        {
            throw new InvalidOperationException();
        }

        index++;
    }

    return decoded;
}

private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int count)
{
// The count arg could originate from untrusted input, so we shouldn't
// pass it as-is to the ctor's capacity arg. We'll instead rely on
// List<T>.Add's O(1) amortization to keep the entire loop O(count).

List<T> values = new List<T>(Math.Min(count, 4));
for (int i = 0; i < count; i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand Down Expand Up @@ -47,7 +48,8 @@ internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetA
{
string?[] values = new string?[Length];

for (int recordIndex = 0, valueIndex = 0; recordIndex < Records.Count; recordIndex++)
int valueIndex = 0;
for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++)
{
SerializationRecord record = Records[recordIndex];

Expand All @@ -73,6 +75,7 @@ record = memberReference.GetReferencedRecord();
}

int nullCount = ((NullsRecord)record).NullCount;
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
do
{
values[valueIndex++] = null;
Expand All @@ -81,6 +84,8 @@ record = memberReference.GetReferencedRecord();
while (nullCount > 0);
}

Debug.Assert(valueIndex == values.Length, "We should have traversed the entirety of the newly created array.");

return values;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

Expand Down Expand Up @@ -84,6 +85,10 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.ArraySingleObject:
case SerializationRecordType.ArraySingleString:

// Recursion depth is bounded by the depth of arrayType, which is
// a trustworthy Type instance. Don't need to worry about stack overflow.

ArrayRecord nestedArrayRecord = (ArrayRecord)record;
Array nestedArray = nestedArrayRecord.GetArray(actualElementType, allowNulls);
array.SetValue(nestedArray, resultIndex++);
Expand All @@ -97,6 +102,7 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
}

int nullCount = ((NullsRecord)item).NullCount;
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
do
{
array.SetValue(null, resultIndex++);
Expand All @@ -110,6 +116,8 @@ private protected override Array Deserialize(Type arrayType, bool allowNulls)
}
}

Debug.Assert(resultIndex == array.Length, "We should have traversed the entirety of the newly created array.");

return array;
}

Expand All @@ -122,6 +130,7 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
bool isRectangular = arrayType is BinaryArrayType.Rectangular;

// It is an arbitrary limit in the current CoreCLR type loader.
// Don't change this value without reviewing the loop a few lines below.
const int MaxSupportedArrayRank = 32;

if (rank < 1 || rank > MaxSupportedArrayRank
Expand All @@ -132,18 +141,26 @@ internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, Pay
}

int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32
long totalElementCount = 1;
long totalElementCount = 1; // to avoid integer overflow during the multiplication below
for (int i = 0; i < lengths.Length; i++)
{
lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
totalElementCount *= lengths[i];

// n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]"
// but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But
// that's the same behavior that newarr and Array.CreateInstance exhibit, so at least
// we're consistent.

if (totalElementCount > ArrayInfo.MaxArrayLength)
{
ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
}
}

// Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so
// we don't need to read the NRBF stream 'LowerBounds' field here.

MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap);
ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank);

Expand Down Expand Up @@ -186,6 +203,9 @@ private static Type MapElementType(Type arrayType, out bool isClassRecord)
Type elementType = arrayType;
int arrayNestingDepth = 0;

// Loop iteration counts are bounded by the nesting depth of arrayType,
// which is a trustworthy input. No DoS concerns.

while (elementType.IsArray)
{
elementType = elementType.GetElementType()!;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ internal static ClassInfo Decode(BinaryReader reader)

// Use Dictionary instead of List so that searching for member IDs by name
// is O(n) instead of O(m * n), where m = memberCount and n = memberNameLength,
// in degenerate cases.
// in degenerate cases. Since memberCount may be hostile, don't allow it to be
// used as the initial capacity in the collection instance.
Dictionary<string, int> memberNames = new(StringComparer.Ordinal);
for (int i = 0; i < memberCount; i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace System.Formats.Nrbf;

/// <summary>
/// Identifies a class by it's name and library id.
/// Identifies a class by its name and library id.
/// </summary>
/// <remarks>
/// ClassTypeInfo structures are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/844b24dd-9f82-426e-9b98-05334307a239">[MS-NRBF] 2.1.1.8</see>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ internal bool ShouldBeRepresentedAsArrayOfClassRecords()
{
// This library tries to minimize the number of concepts the users need to learn to use it.
// Since SZArrays are most common, it provides an SZArrayRecord<T> abstraction.
// Every other array (jagged, multi-dimensional etc) is represented using SZArrayRecord.
// Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord.
// The goal of this method is to determine whether the given array can be represented as SZArrayRecord<ClassRecord>.

(BinaryType binaryType, object? additionalInfo) = Infos[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,5 @@ internal NextInfo(AllowedRecordTypes allowed, SerializationRecord parent,
internal PrimitiveType PrimitiveType { get; }

internal NextInfo With(AllowedRecordTypes allowed, PrimitiveType primitiveType)
=> allowed == Allowed && primitiveType == PrimitiveType
? this // previous record was of the same type
: new(allowed, Parent, Stack, primitiveType);
=> new(allowed, Parent, Stack, primitiveType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public static class NrbfDecoder
// The header consists of:
// - a byte that describes the record type (SerializationRecordType.SerializedStreamHeader)
// - four 32 bit integers:
// - root Id (every value is valid)
// - root Id (every value except 0 is valid)
// - header Id (value is ignored)
// - major version, it has to be equal to 1.
// - minor version, it has to be equal to 0.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ public PayloadOptions() { }
/// </summary>
/// <value><see langword="true" /> if truncated type names should be reassembled; otherwise, <see langword="false" />.</value>
/// <remarks>
/// <para>
/// Example:
/// TypeName: "Namespace.TypeName`1[[Namespace.GenericArgName"
/// LibraryName: "AssemblyName]]"
/// Is combined into "Namespace.TypeName`1[[Namespace.GenericArgName, AssemblyName]]"
/// </para>
/// <para>
/// Setting this to <see langword="true" /> can render <see cref="NrbfDecoder"/> susceptible to Denial of Service
/// attacks when parsing or handling malicious input.
/// </para>
/// <para>The default value is <see langword="false" />.</para>
/// </remarks>
public bool UndoTruncatedTypeNames { get; set; }
}
Loading