Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
address code review feedback: count the arrays themselves
  • Loading branch information
adamsitnik committed Sep 13, 2024
commit 8a61178089d9625e12351adb9eaebca49a02e611
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ private protected ArrayRecord(ArrayInfo arrayInfo)

internal long ValuesToRead { get; private protected set; }

private protected ArrayInfo ArrayInfo { get; }
internal ArrayInfo ArrayInfo { get; }

internal bool IsJagged
=> ArrayInfo.ArrayType == BinaryArrayType.Jagged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,51 +191,42 @@ private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayR

Debug.Assert(jaggedArrayRecord.IsJagged);

// In theory somebody could create a payload that would represent
// a very nested array with total elements count > long.MaxValue.
// That is why this method is using checked arithmetic.
result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves

foreach (object value in jaggedArrayRecord.Values)
{
object item = value is MemberReferenceRecord referenceRecord
? referenceRecord.GetReferencedRecord()
: value;

if (item is not SerializationRecord record)
if (value is not SerializationRecord record)
{
result++;
continue;
}

if (record.RecordType == SerializationRecordType.MemberReference)
{
record = ((MemberReferenceRecord)record).GetReferencedRecord();
}

switch (record.RecordType)
{
case SerializationRecordType.BinaryArray:
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.ArraySingleObject:
case SerializationRecordType.ArraySingleString:
case SerializationRecordType.BinaryArray:
ArrayRecord nestedArrayRecord = (ArrayRecord)record;
if (nestedArrayRecord.IsJagged)
{
(jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord);
}
else
{
Debug.Assert(nestedArrayRecord is not BinaryArrayRecord, "Ensure lack of recursive call");
checked
{
// In theory somebody could create a payload that would represent
// a very nested array with total elements count > long.MaxValue.
result += nestedArrayRecord.FlattenedLength;
}
}
break;
case SerializationRecordType.ObjectNull:
case SerializationRecordType.ObjectNullMultiple256:
case SerializationRecordType.ObjectNullMultiple:
// All nulls need to be included, as it's another form of possible attack.
checked
{
result += ((NullsRecord)item).NullCount;
// Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion,
// just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value.
result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength);
}
break;
default:
result++;
break;
}
}
Expand Down
34 changes: 25 additions & 9 deletions src/libraries/System.Formats.Nrbf/tests/JaggedArraysTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,25 @@ namespace System.Formats.Nrbf.Tests;

public class JaggedArraysTests : ReadTests
{
[Fact]
public void CanReadJaggedArraysOfPrimitiveTypes_2D()
[Theory]
[InlineData(true)]
[InlineData(false)]
public void CanReadJaggedArraysOfPrimitiveTypes_2D(bool useReferences)
{
int[][] input = new int[7][];
int[] same = [1, 2, 3];
for (int i = 0; i < input.Length; i++)
{
input[i] = [i, i, i];
input[i] = useReferences
? same // reuse the same object (represented as a single record that is referenced multiple times)
: [i, i, i]; // create new array
}

var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength);
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
}

[Theory]
Expand All @@ -42,13 +47,17 @@ public void FlattenedLengthIncludesNullArrays(int nullCount)
public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutBeingMarkedAsJagged()
{
int[][][] input = new int[3][][];
long totalElementsCount = 0;
for (int i = 0; i < input.Length; i++)
{
input[i] = new int[4][];
totalElementsCount++; // count the arrays themselves

for (int j = 0; j < input[i].Length; j++)
{
input[i][j] = [i, j, 0, 1, 2];
totalElementsCount += input[i][j].Length;
totalElementsCount++; // count the arrays themselves
}
}

Expand All @@ -67,25 +76,31 @@ public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutB

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(3 * 4 * 5, arrayRecord.FlattenedLength);
Assert.Equal(3 + 3 * 4 + 3 * 4 * 5, totalElementsCount);
Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength);
}

[Fact]
public void CanReadJaggedArraysOfPrimitiveTypes_3D()
{
int[][][] input = new int[7][][];
long totalElementsCount = 0;
for (int i = 0; i < input.Length; i++)
{
totalElementsCount++; // count the arrays themselves
input[i] = new int[1][];
totalElementsCount++; // count the arrays themselves
input[i][0] = [i, i, i];
totalElementsCount += input[i][0].Length;
}

var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(1, arrayRecord.Rank);
Assert.Equal(input.Length * 1 * 3, arrayRecord.FlattenedLength);
Assert.Equal(7 + 7 * 1 + 7 * 1 * 3, totalElementsCount);
Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength);
}

[Fact]
Expand All @@ -110,7 +125,7 @@ public void CanReadJaggedArrayOfRectangularArrays()
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(1, arrayRecord.Rank);
Assert.Equal(input.Length * 3 * 3, arrayRecord.FlattenedLength);
Assert.Equal(input.Length + input.Length * 3 * 3, arrayRecord.FlattenedLength);
}

[Fact]
Expand All @@ -126,7 +141,7 @@ public void CanReadJaggedArraysOfStrings()

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength);
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
}

[Fact]
Expand All @@ -142,7 +157,7 @@ public void CanReadJaggedArraysOfObjects()

Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length * 3, arrayRecord.FlattenedLength);
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
}

[Serializable]
Expand All @@ -160,6 +175,7 @@ public void CanReadJaggedArraysOfComplexTypes()
{
input[i] = Enumerable.Range(0, i + 1).Select(j => new ComplexType { SomeField = j }).ToArray();
totalElementsCount += input[i].Length;
totalElementsCount++; // count the arrays themselves
}

var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
Expand Down