From 6d7206052cbbe0e366ed153c3f35baa33c076380 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Tue, 6 Dec 2022 23:28:33 -0800 Subject: [PATCH 01/17] Allow Base64Decoder to ignore space chars, add IsValid methods and tests Chars now ignored: 9: Line feed 10: Horizontal tab 13: Carriage return 32: Space -- Vertical tab omitted --- .../tests/Base64/Base64DecoderUnitTests.cs | 73 ++- .../tests/Base64/Base64TestBase.cs | 110 ++++ .../tests/Base64/Base64TestHelper.cs | 14 + .../tests/Base64/Base64ValidationUnitTests.cs | 218 ++++++++ .../tests/System.Memory.Tests.csproj | 2 + .../System.Private.CoreLib.Shared.projitems | 1 + .../src/System/Buffers/Text/Base64Decoder.cs | 529 ++++++++++++------ .../System/Buffers/Text/Base64Validator.cs | 214 +++++++ .../System.Runtime/ref/System.Runtime.cs | 4 + 9 files changed, 985 insertions(+), 180 deletions(-) create mode 100644 src/libraries/System.Memory/tests/Base64/Base64TestBase.cs create mode 100644 src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs create mode 100644 src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs diff --git a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs index 4e11de38490591..883941e549f803 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs @@ -1,12 +1,13 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Linq; using System.Text; using Xunit; namespace System.Buffers.Text.Tests { - public class Base64DecoderUnitTests + public class Base64DecoderUnitTests : Base64TestBase { [Fact] public void BasicDecoding() @@ -144,7 +145,7 @@ public void DecodingOutputTooSmall() Span decodedBytes = new byte[3]; int consumed, written; - if (numBytes % 4 == 0) + if (numBytes >= 8) { Assert.True(OperationStatus.DestinationTooSmall == Base64.DecodeFromUtf8(source, decodedBytes, out consumed, out written), "Number of Input Bytes: " + numBytes); @@ -360,7 +361,9 @@ public void DecodingInvalidBytes(bool isFinalBlock) for (int i = 0; i < invalidBytes.Length; i++) { // Don't test padding (byte 61 i.e. '='), which is tested in DecodingInvalidBytesPadding - if (invalidBytes[i] == Base64TestHelper.EncodingPad) + // Don't test chars to be ignored (spaces: 9, 10, 13, 32 i.e. '\n', '\t', '\r', ' ') + if (invalidBytes[i] == Base64TestHelper.EncodingPad + || Base64TestHelper.IsByteToBeIgnored(invalidBytes[i])) continue; // replace one byte with an invalid input @@ -555,7 +558,9 @@ public void DecodeInPlaceInvalidBytes() Span buffer = "2222PPPP"u8.ToArray(); // valid input // Don't test padding (byte 61 i.e. '='), which is tested in DecodeInPlaceInvalidBytesPadding - if (invalidBytes[i] == Base64TestHelper.EncodingPad) + // Don't test chars to be ignored (spaces: 9, 10, 13, 32 i.e. '\n', '\t', '\r', ' ') + if (invalidBytes[i] == Base64TestHelper.EncodingPad + || Base64TestHelper.IsByteToBeIgnored(invalidBytes[i])) continue; // replace one byte with an invalid input @@ -581,7 +586,7 @@ public void DecodeInPlaceInvalidBytes() { Span buffer = "2222PPP"u8.ToArray(); // incomplete input Assert.Equal(OperationStatus.InvalidData, Base64.DecodeFromUtf8InPlace(buffer, out int bytesWritten)); - Assert.Equal(0, bytesWritten); + Assert.Equal(3, bytesWritten); } } @@ -654,5 +659,63 @@ public void DecodeInPlaceInvalidBytesPadding() } } + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + byte[] resultBytes = new byte[5]; + OperationStatus result = Base64.DecodeFromUtf8(utf8BytesWithByteToBeIgnored, resultBytes, out int bytesConsumed, out int bytesWritten); + + // Control value from Convert.FromBase64String + byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(8, bytesConsumed); + Assert.Equal(expectedBytes.Length, bytesWritten); + Assert.True(expectedBytes.SequenceEqual(resultBytes)); + Assert.True(stringBytes.SequenceEqual(resultBytes)); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void DecodeInPlaceIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes) + { + Span utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + OperationStatus result = Base64.DecodeFromUtf8InPlace(utf8BytesWithByteToBeIgnored, out int bytesWritten); + Span bytesOverwritten = utf8BytesWithByteToBeIgnored.Slice(0, bytesWritten); + byte[] resultBytesArray = bytesOverwritten.ToArray(); + + // Control value from Convert.FromBase64String + byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(expectedBytes.Length, bytesWritten); + Assert.True(expectedBytes.SequenceEqual(resultBytesArray)); + Assert.True(stringBytes.SequenceEqual(resultBytesArray)); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void BasicDecodingWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + byte[] resultBytes = new byte[5]; + OperationStatus result = Base64.DecodeFromUtf8(utf8BytesWithByteToBeIgnored, resultBytes, out int bytesConsumed, out int bytesWritten); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(0, bytesWritten); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void DecodingInPlaceWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgnored) + { + Span utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + OperationStatus result = Base64.DecodeFromUtf8InPlace(utf8BytesWithByteToBeIgnored, out int bytesWritten); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(0, bytesWritten); + } } } diff --git a/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs new file mode 100644 index 00000000000000..21f116da83464a --- /dev/null +++ b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text; + +namespace System.Buffers.Text.Tests +{ + public class Base64TestBase + { + public static IEnumerable ValidBase64Strings_WithCharsThatMustBeIgnored() + { + // Create a Base64 string + string text = "a b c"; + byte[] utf8Bytes = Encoding.UTF8.GetBytes(text); + string base64Utf8String = Convert.ToBase64String(utf8Bytes); + + // Split the base64 string in half + int stringLength = base64Utf8String.Length / 2; + string firstSegment = base64Utf8String.Substring(0, stringLength); + string secondSegment = base64Utf8String.Substring(stringLength, stringLength); + + // Insert ignored chars between the base 64 string + // One will have 1 char, another will have 3 + + // Line feed + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 3), utf8Bytes }; + + // Horizontal tab + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 3), utf8Bytes }; + + // Carriage return + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 3), utf8Bytes }; + + // Space + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 3), utf8Bytes }; + + string GetBase64StringWithPassedCharInsertedInTheMiddle(char charToInsert, int numberOfTimesToInsert) => $"{firstSegment}{new string(charToInsert, numberOfTimesToInsert)}{secondSegment}"; + + // Insert ignored chars at the start of the base 64 string + // One will have 1 char, another will have 3 + + // Line feed + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 3), utf8Bytes }; + + // Horizontal tab + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 3), utf8Bytes }; + + // Carriage return + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 3), utf8Bytes }; + + // Space + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 3), utf8Bytes }; + + string GetBase64StringWithPassedCharInsertedAtTheStart(char charToInsert, int numberOfTimesToInsert) => $"{new string(charToInsert, numberOfTimesToInsert)}{firstSegment}{secondSegment}"; + + // Insert ignored chars at the end of the base 64 string + // One will have 1 char, another will have 3 + + // Line feed + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 3), utf8Bytes }; + + // Horizontal tab + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 3), utf8Bytes }; + + // Carriage return + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 3), utf8Bytes }; + + // Space + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 3), utf8Bytes }; + + string GetBase64StringWithPassedCharInsertedAtTheEnd(char charToInsert, int numberOfTimesToInsert) => $"{firstSegment}{secondSegment}{new string(charToInsert, numberOfTimesToInsert)}"; + } + + public static IEnumerable StringsOnlyWithCharsToBeIgnored() + { + // One will have 1 char, another will have 3 + + // Line feed + yield return new object[] { GetRepeatedChar(Convert.ToChar(9), 1) }; + yield return new object[] { GetRepeatedChar(Convert.ToChar(9), 3) }; + + // Horizontal tab + yield return new object[] { GetRepeatedChar(Convert.ToChar(10), 1) }; + yield return new object[] { GetRepeatedChar(Convert.ToChar(10), 3) }; + + // Carriage return + yield return new object[] { GetRepeatedChar(Convert.ToChar(13), 1) }; + yield return new object[] { GetRepeatedChar(Convert.ToChar(13), 3) }; + + // Space + yield return new object[] { GetRepeatedChar(Convert.ToChar(32), 1) }; + yield return new object[] { GetRepeatedChar(Convert.ToChar(32), 3) }; + + string GetRepeatedChar(char charToInsert, int numberOfTimesToInsert) => new string(charToInsert, numberOfTimesToInsert); + } + } +} diff --git a/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs b/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs index 7715f6b5d4bdf3..f6c9db4bf0af1a 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs @@ -44,6 +44,20 @@ public static class Base64TestHelper -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, }; + public static bool IsByteToBeIgnored(byte charByte) + { + switch (charByte) + { + case 9: // Line feed + case 10: // Horizontal tab + case 13: // Carriage return + case 32: // Space + return true; + default: + return false; + } + } + public const byte EncodingPad = (byte)'='; // '=', for padding public const sbyte InvalidByte = -1; // Designating -1 for invalid bytes in the decoding map diff --git a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs new file mode 100644 index 00000000000000..801a36f3fcd8cd --- /dev/null +++ b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs @@ -0,0 +1,218 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using System.Text; +using Xunit; + +namespace System.Buffers.Text.Tests +{ + public class Base64ValidationUnitTests : Base64TestBase + { + [Fact] + public void BasicValidationBytes() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 0); // ensure we have a valid length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeDecodableBytes(source, numBytes); + + Assert.True(Base64.IsValid(source)); + Assert.True(Base64.IsValid(source, out int decodedLength)); + Assert.True(decodedLength > 0); + } + } + + [Fact] + public void BasicValidationChars() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 0); // ensure we have a valid length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeDecodableBytes(source, numBytes); + Span chars = source + .ToArray() + .Select(Convert.ToChar) + .ToArray() + .AsSpan(); + + Assert.True(Base64.IsValid(chars)); + Assert.True(Base64.IsValid(chars, out int decodedLength)); + Assert.True(decodedLength > 0); + } + } + + [Fact] + public void BasicValidationInvalidInputLengthBytes() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 == 0); // ensure we have a invalid length + + Span source = new byte[numBytes]; + + Assert.False(Base64.IsValid(source)); + Assert.False(Base64.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } + + [Fact] + public void BasicValidationInvalidInputLengthChars() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 == 0); // ensure we have a invalid length + + Span source = new char[numBytes]; + + Assert.False(Base64.IsValid(source)); + Assert.False(Base64.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } + + [Fact] + public void ValidateEmptySpanBytes() + { + Span source = Span.Empty; + + Assert.True(Base64.IsValid(source)); + Assert.True(Base64.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Fact] + public void ValidateEmptySpanChars() + { + Span source = Span.Empty; + + Assert.True(Base64.IsValid(source)); + Assert.True(Base64.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Fact] + public void ValidateGuidBytes() + { + Span source = new byte[24]; + Span decodedBytes = Guid.NewGuid().ToByteArray(); + Base64.EncodeToUtf8(decodedBytes, source, out int _, out int _); + + Assert.True(Base64.IsValid(source)); + Assert.True(Base64.IsValid(source, out int decodedLength)); + Assert.True(decodedLength > 0); + } + + [Fact] + public void ValidateGuidChars() + { + Span source = new byte[24]; + Span decodedBytes = Guid.NewGuid().ToByteArray(); + Base64.EncodeToUtf8(decodedBytes, source, out int _, out int _); + Span chars = source + .ToArray() + .Select(Convert.ToChar) + .ToArray() + .AsSpan(); + + Assert.True(Base64.IsValid(chars)); + Assert.True(Base64.IsValid(chars, out int decodedLength)); + Assert.True(decodedLength > 0); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored, byte[] expectedBytes) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedBytes.Length, decodedLength); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void ValidateBytesIgnoresCharsToBeIgnoredChars(string utf8WithByteToBeIgnored, byte[] expectedBytes) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedBytes.Length, decodedLength); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void ValidateWithOnlyCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void ValidateWithOnlyCharsToBeIgnoredChars(string utf8WithByteToBeIgnored) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YQ==", 1)] + [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + public void ValidateWithPaddingReturnsCorrectCount(string utf8WithByteToBeIgnored, int expectedLength) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + + [Theory] + [InlineData("YQ==", 1)] + [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + public void DecodeEmptySpan(string utf8WithByteToBeIgnored, int expectedLength) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + } +} diff --git a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj index e7758562dcb5b5..e34ecfa3387b86 100644 --- a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj +++ b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj @@ -270,6 +270,8 @@ + + + diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 2ec9959536427e..2c3406d4c9048f 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -46,7 +46,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8)) fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes)) { - int srcLength = utf8.Length & ~0x3; // only decode input up to the closest multiple of 4. + int srcLength = utf8.Length; // only decode input up to the closest multiple of 4. int destLength = bytes.Length; int maxSrcLength = srcLength; int decodedLength = GetMaxDecodedFromUtf8Length(srcLength); @@ -60,9 +60,12 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa byte* src = srcBytes; byte* dest = destBytes; + byte* destEnd = dest + (uint)destLength; byte* srcEnd = srcBytes + (uint)srcLength; byte* srcMax = srcBytes + (uint)maxSrcLength; + int totalBytesIgnored = 0; + if (maxSrcLength >= 24) { byte* end = srcMax - 45; @@ -84,142 +87,199 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa } } - // Last bytes could have padding characters, so process them separately and treat them as valid only if isFinalBlock is true - // if isFinalBlock is false, padding characters are considered invalid - int skipLastChunk = isFinalBlock ? 4 : 0; - - if (destLength >= decodedLength) - { - maxSrcLength = srcLength - skipLastChunk; - } - else - { - // This should never overflow since destLength here is less than int.MaxValue / 4 * 3 (i.e. 1610612733) - // Therefore, (destLength / 3) * 4 will always be less than 2147483641 - Debug.Assert(destLength < (int.MaxValue / 4 * 3)); - maxSrcLength = (destLength / 3) * 4; - } - ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); - srcMax = srcBytes + (uint)maxSrcLength; - while (src < srcMax) - { - int result = Decode(src, ref decodingMap); + // The next src increment is stored as it will be used if the dest has enough space + // or ignored in consumed operations if not. + int pendingSrcIncrement = 0; + while (src + 4 <= srcEnd) + { + // The default increment will be 4 if no bytes that require ignoring are encountered. + pendingSrcIncrement = 4; + byte b0 = src[0]; + byte b1 = src[1]; + byte b2 = src[2]; + byte b3 = src[3]; + + int result = Decode(b0, b1, b2, b3, ref decodingMap); if (result < 0) - goto InvalidDataExit; + { + int firstInvalidIndex = GetIndexOfFirstByteToBeIgnored(src); + if (firstInvalidIndex != -1) + { + int bytesIgnored = 0; + int validBytesSearchIndex = firstInvalidIndex; + bool insufficientValidBytesFound = false; + + for (int currentBlockIndex = firstInvalidIndex; currentBlockIndex < 4; currentBlockIndex++) + { + while (src + validBytesSearchIndex < srcEnd + && IsByteToBeIgnored(src[validBytesSearchIndex])) + { + validBytesSearchIndex++; + bytesIgnored++; + totalBytesIgnored++; + } + + if (src + validBytesSearchIndex >= srcEnd) + { + insufficientValidBytesFound = true; + break; + } + + if (currentBlockIndex == 0) + { + b0 = src[validBytesSearchIndex]; + } + else if (currentBlockIndex == 1) + { + b1 = src[validBytesSearchIndex]; + } + else if (currentBlockIndex == 2) + { + b2 = src[validBytesSearchIndex]; + } + else + { + b3 = src[validBytesSearchIndex]; + } + + validBytesSearchIndex++; + } + + if (insufficientValidBytesFound) + { + break; + } + + result = Decode(b0, b1, b2, b3, ref decodingMap); + if (result < 0 + && !IsBlockEndBytesPadding(b2, b3)) + { + goto InvalidDataExit; + } + + pendingSrcIncrement = validBytesSearchIndex; + } + else + { + if (!IsBlockEndBytesPadding(b2, b3)) + { + goto InvalidDataExit; + } + } + + // Check to see if parsing failed due to padding. There could be 1 or 2 padding chars. + if (result < 0 + && IsBlockEndBytesPadding(b2, b3)) + { + int indexOfBytesAfterPadding = pendingSrcIncrement; + while (src + indexOfBytesAfterPadding + 1 <= srcEnd) + { + if (!IsByteToBeIgnored(src[indexOfBytesAfterPadding++])) + { + // Only bytes to be ignored can be after padding bytes. + goto InvalidDataExit; + } + } + + // If isFinalBlock is false, padding is treaded as invalid. + if (!isFinalBlock) + { + goto InvalidDataExit; + } + + int i0 = Unsafe.Add(ref decodingMap, (IntPtr)b0); + int i1 = Unsafe.Add(ref decodingMap, (IntPtr)b1); + + i0 <<= 18; + i1 <<= 12; + + i0 |= i1; + + if (b2 != EncodingPad) + { + int i2 = Unsafe.Add(ref decodingMap, (IntPtr)b2); + + i2 <<= 6; + + i0 |= i2; + + if (i0 < 0) + goto InvalidDataExit; + if (dest + 2 > destEnd) + goto DestinationTooSmallExit; + + dest[0] = (byte)(i0 >> 16); + dest[1] = (byte)(i0 >> 8); + dest += 2; + } + else + { + if (i0 < 0) + goto InvalidDataExit; + if (dest + 1 > destEnd) + goto DestinationTooSmallExit; + + dest[0] = (byte)(i0 >> 16); + dest += 1; + } + + src += pendingSrcIncrement; + pendingSrcIncrement = 0; + + break; + } + } + + if (dest + 3 > destEnd) + { + goto DestinationTooSmallExit; + } WriteThreeLowOrderBytes(dest, result); - src += 4; + + src += pendingSrcIncrement; + pendingSrcIncrement = 0; dest += 3; } - if (maxSrcLength != srcLength - skipLastChunk) - goto DestinationTooSmallExit; - - // If input is less than 4 bytes, srcLength == sourceIndex == 0 - // If input is not a multiple of 4, sourceIndex == srcLength != 0 - if (src == srcEnd) + if (!isFinalBlock) { - if (isFinalBlock) - goto InvalidDataExit; - - if (src == srcBytes + utf8.Length) - goto DoneExit; - - goto NeedMoreDataExit; - } - - // if isFinalBlock is false, we will never reach this point - - // Handle last four bytes. There are 0, 1, 2 padding chars. - uint t0 = srcEnd[-4]; - uint t1 = srcEnd[-3]; - uint t2 = srcEnd[-2]; - uint t3 = srcEnd[-1]; - - int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); - int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); - - i0 <<= 18; - i1 <<= 12; - - i0 |= i1; - - byte* destMax = destBytes + (uint)destLength; - - if (t3 != EncodingPad) - { - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); - int i3 = Unsafe.Add(ref decodingMap, (IntPtr)t3); - - i2 <<= 6; - - i0 |= i3; - i0 |= i2; - - if (i0 < 0) - goto InvalidDataExit; - if (dest + 3 > destMax) - goto DestinationTooSmallExit; - - WriteThreeLowOrderBytes(dest, i0); - dest += 3; + int remainingBytes = (int)(srcEnd - src); + if (remainingBytes > 0 && remainingBytes < 4) + { + goto NeedMoreDataExit; + } } - else if (t2 != EncodingPad) - { - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); - - i2 <<= 6; - i0 |= i2; - - if (i0 < 0) - goto InvalidDataExit; - if (dest + 2 > destMax) - goto DestinationTooSmallExit; - - dest[0] = (byte)(i0 >> 16); - dest[1] = (byte)(i0 >> 8); - dest += 2; - } - else + int indexOfBytesNotConsumed = pendingSrcIncrement; + while (src + indexOfBytesNotConsumed + 1 <= srcEnd) { - if (i0 < 0) + if (!IsByteToBeIgnored(src[indexOfBytesNotConsumed++])) + { goto InvalidDataExit; - if (dest + 1 > destMax) - goto DestinationTooSmallExit; - - dest[0] = (byte)(i0 >> 16); - dest += 1; + } } - src += 4; - - if (srcLength != utf8.Length) - goto InvalidDataExit; - DoneExit: - bytesConsumed = (int)(src - srcBytes); + bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; bytesWritten = (int)(dest - destBytes); return OperationStatus.Done; DestinationTooSmallExit: - if (srcLength != utf8.Length && isFinalBlock) - goto InvalidDataExit; // if input is not a multiple of 4, and there is no more data, return invalid data instead - - bytesConsumed = (int)(src - srcBytes); + bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; bytesWritten = (int)(dest - destBytes); return OperationStatus.DestinationTooSmall; NeedMoreDataExit: - bytesConsumed = (int)(src - srcBytes); + bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; bytesWritten = (int)(dest - destBytes); return OperationStatus.NeedMoreData; InvalidDataExit: - bytesConsumed = (int)(src - srcBytes); + bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; bytesWritten = (int)(dest - destBytes); return OperationStatus.InvalidData; } @@ -269,75 +329,154 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou uint sourceIndex = 0; uint destIndex = 0; - // only decode input if it is a multiple of 4 - if (bufferLength != ((bufferLength >> 2) * 4)) - goto InvalidExit; + int totalBytesIgnored = 0; + if (bufferLength == 0) goto DoneExit; ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); - while (sourceIndex < bufferLength - 4) + while (sourceIndex <= bufferLength - 4) { - int result = Decode(bufferBytes + sourceIndex, ref decodingMap); + // The default increment will be 4 if no bytes that require ignoring are encountered. + uint nextSoureIndex = sourceIndex + 4; + byte b0 = bufferBytes[sourceIndex]; + byte b1 = bufferBytes[sourceIndex + 1]; + byte b2 = bufferBytes[sourceIndex + 2]; + byte b3 = bufferBytes[sourceIndex + 3]; + + int result = Decode(b0, b1, b2, b3, ref decodingMap); if (result < 0) - goto InvalidExit; - WriteThreeLowOrderBytes(bufferBytes + destIndex, result); - destIndex += 3; - sourceIndex += 4; - } - - uint t0 = bufferBytes[bufferLength - 4]; - uint t1 = bufferBytes[bufferLength - 3]; - uint t2 = bufferBytes[bufferLength - 2]; - uint t3 = bufferBytes[bufferLength - 1]; - - int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); - int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); - - i0 <<= 18; - i1 <<= 12; - - i0 |= i1; - - if (t3 != EncodingPad) - { - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); - int i3 = Unsafe.Add(ref decodingMap, (IntPtr)t3); - - i2 <<= 6; - - i0 |= i3; - i0 |= i2; + { + int firstInvalidIndex = GetIndexOfFirstByteToBeIgnored(bufferBytes + sourceIndex); + if (firstInvalidIndex != -1) + { + int bytesIgnored = 0; + uint validBytesSearchIndex = (uint)firstInvalidIndex + sourceIndex; + bool insufficientValidBytesFound = false; + + for (int currentBlockIndex = firstInvalidIndex; currentBlockIndex < 4; currentBlockIndex++) + { + while (validBytesSearchIndex <= bufferLength - 1 + && IsByteToBeIgnored(bufferBytes[validBytesSearchIndex])) + { + validBytesSearchIndex++; + bytesIgnored++; + totalBytesIgnored++; + } + + if (validBytesSearchIndex > bufferLength - 1) + { + insufficientValidBytesFound = true; + break; + } + + if (currentBlockIndex == 0) + { + b0 = bufferBytes[validBytesSearchIndex]; + } + else if (currentBlockIndex == 1) + { + b1 = bufferBytes[validBytesSearchIndex]; + } + else if (currentBlockIndex == 2) + { + b2 = bufferBytes[validBytesSearchIndex]; + } + else + { + b3 = bufferBytes[validBytesSearchIndex]; + } + + validBytesSearchIndex++; + } + + if (insufficientValidBytesFound) + { + break; + } + + result = Decode(b0, b1, b2, b3, ref decodingMap); + if (result < 0 + && !IsBlockEndBytesPadding(b2, b3)) + { + goto InvalidExit; + } + + nextSoureIndex = validBytesSearchIndex; + } + else + { + if (!IsBlockEndBytesPadding(b2, b3)) + { + goto InvalidExit; + } + } + + // Handle last four bytes. There are 1, 2 padding chars. + if (result < 0 + && IsBlockEndBytesPadding(b2, b3)) + { + uint indexOfBytesAfterPadding = sourceIndex + nextSoureIndex; + while (indexOfBytesAfterPadding + 1 <= bufferLength - 1) + { + if (!IsByteToBeIgnored(bufferBytes[indexOfBytesAfterPadding++])) + { + // Only bytes to be ignored can be after padding bytes. + goto InvalidExit; + } + } + + int i0 = Unsafe.Add(ref decodingMap, (IntPtr)b0); + int i1 = Unsafe.Add(ref decodingMap, (IntPtr)b1); + + i0 <<= 18; + i1 <<= 12; + + i0 |= i1; + + if (b2 != EncodingPad) + { + int i2 = Unsafe.Add(ref decodingMap, (IntPtr)b2); + + i2 <<= 6; + + i0 |= i2; + + if (i0 < 0) + goto InvalidExit; + + bufferBytes[destIndex] = (byte)(i0 >> 16); + bufferBytes[destIndex + 1] = (byte)(i0 >> 8); + destIndex += 2; + } + else + { + if (i0 < 0) + goto InvalidExit; + + bufferBytes[destIndex] = (byte)(i0 >> 16); + destIndex += 1; + } + + sourceIndex = nextSoureIndex; - if (i0 < 0) - goto InvalidExit; + goto DoneExit; + } + } - WriteThreeLowOrderBytes(bufferBytes + destIndex, i0); + WriteThreeLowOrderBytes(bufferBytes + destIndex, result); destIndex += 3; + sourceIndex = nextSoureIndex; } - else if (t2 != EncodingPad) - { - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); - i2 <<= 6; - - i0 |= i2; - - if (i0 < 0) - goto InvalidExit; - - bufferBytes[destIndex] = (byte)(i0 >> 16); - bufferBytes[destIndex + 1] = (byte)(i0 >> 8); - destIndex += 2; - } - else + // Check if there are any bytes that should not be ignored after the last valid block size. + while (sourceIndex <= bufferLength - 1) { - if (i0 < 0) + if (!IsByteToBeIgnored(bufferBytes[sourceIndex++])) + { goto InvalidExit; - - bufferBytes[destIndex] = (byte)(i0 >> 16); - destIndex += 1; + } } DoneExit: @@ -571,15 +710,15 @@ private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destByt // 1111 0x10 andlut 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 // The JIT won't hoist these "constants", so help it - Vector128 lutHi = Vector128.Create(0x02011010, 0x08040804, 0x10101010, 0x10101010).AsByte(); - Vector128 lutLo = Vector128.Create(0x11111115, 0x11111111, 0x1A131111, 0x1A1B1B1B).AsByte(); + Vector128 lutHi = Vector128.Create(0x02011010, 0x08040804, 0x10101010, 0x10101010).AsByte(); + Vector128 lutLo = Vector128.Create(0x11111115, 0x11111111, 0x1A131111, 0x1A1B1B1B).AsByte(); Vector128 lutShift = Vector128.Create(0x04131000, 0xb9b9bfbf, 0x00000000, 0x00000000).AsSByte(); Vector128 packBytesMask = Vector128.Create(0x06000102, 0x090A0405, 0x0C0D0E08, 0xffffffff).AsSByte(); - Vector128 mergeConstant0 = Vector128.Create(0x01400140).AsByte(); + Vector128 mergeConstant0 = Vector128.Create(0x01400140).AsByte(); Vector128 mergeConstant1 = Vector128.Create(0x00011000).AsInt16(); - Vector128 one = Vector128.Create((byte)1); - Vector128 mask2F = Vector128.Create((byte)'/'); - Vector128 mask8F = Vector128.Create((byte)0x8F); + Vector128 one = Vector128.Create((byte)1); + Vector128 mask2F = Vector128.Create((byte)'/'); + Vector128 mask8F = Vector128.Create((byte)0x8F); byte* src = srcBytes; byte* dest = destBytes; @@ -665,13 +804,8 @@ private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destByt } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe int Decode(byte* encodedBytes, ref sbyte decodingMap) + private static unsafe int Decode(uint t0, uint t1, uint t2, uint t3, ref sbyte decodingMap) { - uint t0 = encodedBytes[0]; - uint t1 = encodedBytes[1]; - uint t2 = encodedBytes[2]; - uint t3 = encodedBytes[3]; - int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); @@ -696,6 +830,51 @@ private static unsafe void WriteThreeLowOrderBytes(byte* destination, int value) destination[2] = (byte)(value); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe int GetIndexOfFirstByteToBeIgnored(byte* src) + { + int firstInvalidIndex = -1; + + if (IsByteToBeIgnored(src[0])) + { + firstInvalidIndex = 0; + } + else if (IsByteToBeIgnored(src[1])) + { + firstInvalidIndex = 1; + } + else if (IsByteToBeIgnored(src[2])) + { + firstInvalidIndex = 2; + } + else if (IsByteToBeIgnored(src[3])) + { + firstInvalidIndex = 3; + } + + return firstInvalidIndex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsBlockEndBytesPadding(byte secondToLastByte, byte lastByte) => + lastByte == EncodingPad + || secondToLastByte == EncodingPad && lastByte == EncodingPad; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsByteToBeIgnored(byte charByte) + { + switch (charByte) + { + case 9: // Line feed + case 10: // Horizontal tab + case 13: // Carriage return + case 32: // Space + return true; + default: + return false; + } + } + // Pre-computing this table using a custom string(s_characters) and GenerateDecodingMapAndVerify (found in tests) private static ReadOnlySpan DecodingMap => new sbyte[] { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs new file mode 100644 index 00000000000000..b1bbc57bc571ff --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -0,0 +1,214 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace System.Buffers.Text +{ + public static partial class Base64 + { + /// + /// Validates the span of UTF-8 encoded text represented as base64 into binary data. + /// + /// The input span which contains UTF-8 encoded text in base64 that needs to be validated. + /// true if is decodable; otherwise, false. + public static bool IsValid(ReadOnlySpan base64Text) => IsValid(base64Text, out int _); + + /// + /// Validates the span of UTF-8 encoded text represented as base64 into binary data. + /// + /// The input span which contains UTF-8 encoded text in base64 that needs to be validated. + /// The maximum length (in bytes) if you were to decode the base 64 encoded text within a byte span. + /// true if is decodable; otherwise, false. + public static unsafe bool IsValid(ReadOnlySpan base64Text, out int decodedLength) + { + if (base64Text.IsEmpty) + { + decodedLength = 0; + return true; + } + + // Check for invalid chars + int indexOfFirstNonBase64 = base64Text.IndexOfAnyExcept(ValidBase64CharsSortedAsc); + if (indexOfFirstNonBase64 >= 0) + { + decodedLength = 0; + return false; + } + + int length = base64Text.Length; + int paddingCount = 0; + + // Check if there are chars that need to be ignored while determining the length + if (base64Text.IndexOfAny(CharsToIgnore) > -1) + { + fixed (char* srcChars = &MemoryMarshal.GetReference(base64Text)) + { + int numberOfCharsToIgnore = 0; + char* src = srcChars; + + for (int i = 0; i < length; i++) + { + char charToValidate = *src++; + if (IsCharToBeIgnored(charToValidate)) + { + numberOfCharsToIgnore++; + } + else if (charToValidate == EncodingPad) + { + paddingCount++; + } + } + + length -= numberOfCharsToIgnore; + } + } + else if (length >= 4) + { + if (base64Text[length - 1] == EncodingPad) + { + paddingCount++; + } + if (base64Text[length - 2] == EncodingPad) + { + paddingCount++; + } + } + + if (length % 4 != 0) + { + decodedLength = 0; + return false; + } + + // Remove padding to get exact length + decodedLength = (length / 4 * 3) - paddingCount; + return true; + } + + /// + /// Validates the span of UTF-8 encoded text represented as base64 into binary data. + /// + /// The input span which contains UTF-8 encoded text in base64 that needs to be validated. + /// true if is decodable; otherwise, false. + public static bool IsValid(ReadOnlySpan base64TextUtf8) => IsValid(base64TextUtf8, out int _); + + /// + /// Validates the span of UTF-8 encoded text represented as base64 into binary data. + /// + /// The input span which contains UTF-8 encoded text in base64 that needs to be validated. + /// The maximum length (in bytes) if you were to decode the base 64 encoded text within a byte span. + /// true if is decodable; otherwise, false. + public static unsafe bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLength) + { + if (base64TextUtf8.IsEmpty) + { + decodedLength = 0; + return true; + } + + // Check for invalid chars + int indexOfFirstNonBase64 = base64TextUtf8.IndexOfAnyExcept(ValidBase64BytesSortedAsc); + if (indexOfFirstNonBase64 >= 0) + { + decodedLength = 0; + return false; + } + + int length = base64TextUtf8.Length; + int paddingCount = 0; + + // Check if there are chars that need to be ignored while determining the length + if (base64TextUtf8.IndexOfAny(BytesToIgnore) > -1) + { + fixed (byte* srcBytes = &MemoryMarshal.GetReference(base64TextUtf8)) + { + int numberOfBytesToIgnore = 0; + byte* src = srcBytes; + + for (int i = 0; i < length; i++) + { + byte byteToValidate = *src++; + if (IsByteToBeIgnored(byteToValidate)) + { + numberOfBytesToIgnore++; + } + else if (byteToValidate == EncodingPad) + { + paddingCount++; + } + } + + length -= numberOfBytesToIgnore; + } + } + else if (length >= 4) + { + if (base64TextUtf8[length - 1] == EncodingPad) + { + paddingCount++; + } + if (base64TextUtf8[length - 2] == EncodingPad) + { + paddingCount++; + } + } + + if (length % 4 != 0) + { + decodedLength = 0; + return false; + } + + // Remove padding to get exact length + decodedLength = (length / 4 * 3) - paddingCount; + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsCharToBeIgnored(char aChar) + { + switch (aChar) + { + case '\n': // Line feed + case '\t': // Horizontal tab + case '\r': // Carriage return + case ' ': // Space + return true; + default: + return false; + } + } + + private static IndexOfAnyValues ValidBase64BytesSortedAsc => IndexOfAnyValues.Create(new byte[] { + 9, 10, 13, 32, 43, 47, //Line feed, Horizontal tab, Carriage return, Space, +, / + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, //0..9, + 61, //= + 65, 66, 67, 68, 69, 70, 71, 72, //A..H + 73, 74, 75, 76, 77, 78, 79, 80, //I..P + 81, 82, 83, 84, 85, 86, 87, 88, //Q..X + 89, 90, 97, 98, 99, 100, 101, 102, //Y..Z, a..f + 103, 104, 105, 106, 107, 108, 109, 110, //g..n + 111, 112, 113, 114, 115, 116, 117, 118, //o..v + 119, 120, 121, 122, //w..z + }); + + private static IndexOfAnyValues ValidBase64CharsSortedAsc => IndexOfAnyValues.Create(new char[] { + '\n', '\t', '\r', ' ', '+', '/', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + '=', + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', + 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', + 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', + 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', + 'w', 'x', 'y', 'z', + }); + + private static IndexOfAnyValues BytesToIgnore => IndexOfAnyValues.Create(new byte[] { 9, 10, 13, 32 }); + + private static IndexOfAnyValues CharsToIgnore => IndexOfAnyValues.Create(new char[] { '\n', '\t', '\r', ' ' }); + } +} diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index 713acb7e8120b1..51057d8bf556d8 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -7109,6 +7109,10 @@ public static partial class Base64 public static System.Buffers.OperationStatus EncodeToUtf8InPlace(System.Span buffer, int dataLength, out int bytesWritten) { throw null; } public static int GetMaxDecodedFromUtf8Length(int length) { throw null; } public static int GetMaxEncodedToUtf8Length(int length) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64Text) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64Text, out int decodedLength) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64TextUtf8) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64TextUtf8, out int decodedLength) { throw null; } } } namespace System.CodeDom.Compiler From 35302b4843a167e8705357f6593b8e076e12be4a Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Thu, 8 Dec 2022 01:01:26 -0800 Subject: [PATCH 02/17] Address PR feedback regarding Base64.IsValid --- .../tests/Base64/Base64ValidationUnitTests.cs | 110 +++++++++- .../System/Buffers/Text/Base64Validator.cs | 190 +++++++++--------- 2 files changed, 205 insertions(+), 95 deletions(-) diff --git a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs index 801a36f3fcd8cd..abccac736bb217 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs @@ -193,7 +193,17 @@ public void ValidateWithOnlyCharsToBeIgnoredChars(string utf8WithByteToBeIgnored [InlineData("YQ==", 1)] [InlineData("YWI=", 2)] [InlineData("YWJj", 3)] - public void ValidateWithPaddingReturnsCorrectCount(string utf8WithByteToBeIgnored, int expectedLength) + [InlineData(" YWI=", 2)] + [InlineData("Y WI=", 2)] + [InlineData("YW I=", 2)] + [InlineData("YWI =", 2)] + [InlineData("YWI= ", 2)] + [InlineData(" YQ==", 1)] + [InlineData("Y Q==", 1)] + [InlineData("YQ ==", 1)] + [InlineData("YQ= =", 1)] + [InlineData("YQ== ", 1)] + public void ValidateWithPaddingReturnsCorrectCountBytes(string utf8WithByteToBeIgnored, int expectedLength) { byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); @@ -202,6 +212,29 @@ public void ValidateWithPaddingReturnsCorrectCount(string utf8WithByteToBeIgnore Assert.Equal(expectedLength, decodedLength); } + [Theory] + [InlineData("YQ==", 1)] + [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + [InlineData(" YWI=", 2)] + [InlineData("Y WI=", 2)] + [InlineData("YW I=", 2)] + [InlineData("YWI =", 2)] + [InlineData("YWI= ", 2)] + [InlineData(" YQ==", 1)] + [InlineData("Y Q==", 1)] + [InlineData("YQ ==", 1)] + [InlineData("YQ= =", 1)] + [InlineData("YQ== ", 1)] + public void ValidateWithPaddingReturnsCorrectCountChars(string utf8WithByteToBeIgnored, int expectedLength) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + [Theory] [InlineData("YQ==", 1)] [InlineData("YWI=", 2)] @@ -214,5 +247,80 @@ public void DecodeEmptySpan(string utf8WithByteToBeIgnored, int expectedLength) Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); Assert.Equal(expectedLength, decodedLength); } + + [Theory] + [InlineData("YWJ")] + [InlineData("YW")] + [InlineData("Y")] + public void InvalidSizeBytes(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YWJ")] + [InlineData("YW")] + [InlineData("Y")] + public void InvalidSizeChars(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YQ===")] + [InlineData("YQ=a=")] + [InlineData("YWI=a")] + [InlineData(" aYWI=a")] + [InlineData("a YWI=a")] + [InlineData("aY WI=a")] + [InlineData("aYW I=a")] + [InlineData("aYWI =a")] + [InlineData("aYWI= a")] + [InlineData("a YQ==a")] + [InlineData("aY Q==a")] + [InlineData("aYQ ==a")] + [InlineData("aYQ= =a")] + [InlineData("aYQ== a")] + [InlineData("aYQ==a ")] + public void InvalidBase64Bytes(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YQ===")] + [InlineData("YQ=a=")] + [InlineData("YWI=a")] + [InlineData("a YWI=a")] + [InlineData("aY WI=a")] + [InlineData("aYW I=a")] + [InlineData("aYWI =a")] + [InlineData("aYWI= a")] + [InlineData("a YQ==a")] + [InlineData("aY Q==a")] + [InlineData("aYQ ==a")] + [InlineData("aYQ= =a")] + [InlineData("aYQ== a")] + [InlineData("aYQ==a ")] + public void InvalidBase64Chars(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs index b1bbc57bc571ff..ca223798851576 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -29,50 +29,64 @@ public static unsafe bool IsValid(ReadOnlySpan base64Text, out int decoded return true; } - // Check for invalid chars - int indexOfFirstNonBase64 = base64Text.IndexOfAnyExcept(ValidBase64CharsSortedAsc); - if (indexOfFirstNonBase64 >= 0) - { - decodedLength = 0; - return false; - } - int length = base64Text.Length; int paddingCount = 0; - // Check if there are chars that need to be ignored while determining the length - if (base64Text.IndexOfAny(CharsToIgnore) > -1) + int indexOfPaddingInvalidOrWhitespace = base64Text.IndexOfAnyExcept(validBase64Chars); + if (indexOfPaddingInvalidOrWhitespace >= 0) { - fixed (char* srcChars = &MemoryMarshal.GetReference(base64Text)) + while (indexOfPaddingInvalidOrWhitespace >= 0) { - int numberOfCharsToIgnore = 0; - char* src = srcChars; - - for (int i = 0; i < length; i++) + char charToValidate = base64Text[indexOfPaddingInvalidOrWhitespace]; + if (IsCharToBeIgnored(charToValidate)) { - char charToValidate = *src++; - if (IsCharToBeIgnored(charToValidate)) - { - numberOfCharsToIgnore++; - } - else if (charToValidate == EncodingPad) + // Chars to be ignored (e,g, whitespace...) should not count towards the length. + length--; + } + else if (charToValidate == EncodingPad) + { + // There can be at most 2 padding chars. + if (paddingCount == 2) { - paddingCount++; + decodedLength = 0; + return false; } + + paddingCount++; + } + else + { + // An invalid char was encountered. + decodedLength = 0; + return false; } - length -= numberOfCharsToIgnore; - } - } - else if (length >= 4) - { - if (base64Text[length - 1] == EncodingPad) - { - paddingCount++; + if (indexOfPaddingInvalidOrWhitespace == base64Text.Length - 1) + { + // The end of the input has been reached. + break; + } + + // If no padding is found, slice and use IndexOfAnyExcept to look for the next invalid char. + if (paddingCount == 0) + { + indexOfPaddingInvalidOrWhitespace = base64Text + .Slice(indexOfPaddingInvalidOrWhitespace + 1, base64Text.Length - indexOfPaddingInvalidOrWhitespace - 1) + .IndexOfAnyExcept(validBase64Chars) + + indexOfPaddingInvalidOrWhitespace + 1; // Add current index offset. + } + // If padding is already found, simply increment, as the common case might have 2 sequential padding chars. + else + { + indexOfPaddingInvalidOrWhitespace++; + } } - if (base64Text[length - 2] == EncodingPad) + + // If the invalid chars all consisted of whitespace, the input will be empty. + if (length == 0) { - paddingCount++; + decodedLength = 0; + return true; } } @@ -100,7 +114,7 @@ public static unsafe bool IsValid(ReadOnlySpan base64Text, out int decoded /// The input span which contains UTF-8 encoded text in base64 that needs to be validated. /// The maximum length (in bytes) if you were to decode the base 64 encoded text within a byte span. /// true if is decodable; otherwise, false. - public static unsafe bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLength) + public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLength) { if (base64TextUtf8.IsEmpty) { @@ -108,50 +122,64 @@ public static unsafe bool IsValid(ReadOnlySpan base64TextUtf8, out int dec return true; } - // Check for invalid chars - int indexOfFirstNonBase64 = base64TextUtf8.IndexOfAnyExcept(ValidBase64BytesSortedAsc); - if (indexOfFirstNonBase64 >= 0) - { - decodedLength = 0; - return false; - } - int length = base64TextUtf8.Length; int paddingCount = 0; - // Check if there are chars that need to be ignored while determining the length - if (base64TextUtf8.IndexOfAny(BytesToIgnore) > -1) + int indexOfPaddingInvalidOrWhitespace = base64TextUtf8.IndexOfAnyExcept(validBase64Bytes); + if (indexOfPaddingInvalidOrWhitespace >= 0) { - fixed (byte* srcBytes = &MemoryMarshal.GetReference(base64TextUtf8)) + while (indexOfPaddingInvalidOrWhitespace >= 0) { - int numberOfBytesToIgnore = 0; - byte* src = srcBytes; - - for (int i = 0; i < length; i++) + byte byteToValidate = base64TextUtf8[indexOfPaddingInvalidOrWhitespace]; + if (IsByteToBeIgnored(byteToValidate)) { - byte byteToValidate = *src++; - if (IsByteToBeIgnored(byteToValidate)) - { - numberOfBytesToIgnore++; - } - else if (byteToValidate == EncodingPad) + // Bytes to be ignored (e,g, whitespace...) should not count towards the length. + length--; + } + else if (byteToValidate == EncodingPad) + { + // There can be at most 2 padding chars. + if (paddingCount == 2) { - paddingCount++; + decodedLength = 0; + return false; } + + paddingCount++; + } + else + { + // An invalid char was encountered. + decodedLength = 0; + return false; } - length -= numberOfBytesToIgnore; - } - } - else if (length >= 4) - { - if (base64TextUtf8[length - 1] == EncodingPad) - { - paddingCount++; + if (indexOfPaddingInvalidOrWhitespace == base64TextUtf8.Length - 1) + { + // The end of the input has been reached. + break; + } + + // If no padding is found, slice and use IndexOfAnyExcept to look for the next invalid char. + if (paddingCount == 0) + { + indexOfPaddingInvalidOrWhitespace = base64TextUtf8 + .Slice(indexOfPaddingInvalidOrWhitespace + 1, base64TextUtf8.Length - indexOfPaddingInvalidOrWhitespace - 1) + .IndexOfAnyExcept(validBase64Bytes) + + indexOfPaddingInvalidOrWhitespace + 1; // Add current index offset. + } + // If padding is already found, simply increment, as the common case might have 2 sequential padding chars. + else + { + indexOfPaddingInvalidOrWhitespace++; + } } - if (base64TextUtf8[length - 2] == EncodingPad) + + // If the invalid chars all consisted of whitespace, the input will be empty. + if (length == 0) { - paddingCount++; + decodedLength = 0; + return true; } } @@ -181,34 +209,8 @@ private static bool IsCharToBeIgnored(char aChar) } } - private static IndexOfAnyValues ValidBase64BytesSortedAsc => IndexOfAnyValues.Create(new byte[] { - 9, 10, 13, 32, 43, 47, //Line feed, Horizontal tab, Carriage return, Space, +, / - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, //0..9, - 61, //= - 65, 66, 67, 68, 69, 70, 71, 72, //A..H - 73, 74, 75, 76, 77, 78, 79, 80, //I..P - 81, 82, 83, 84, 85, 86, 87, 88, //Q..X - 89, 90, 97, 98, 99, 100, 101, 102, //Y..Z, a..f - 103, 104, 105, 106, 107, 108, 109, 110, //g..n - 111, 112, 113, 114, 115, 116, 117, 118, //o..v - 119, 120, 121, 122, //w..z - }); - - private static IndexOfAnyValues ValidBase64CharsSortedAsc => IndexOfAnyValues.Create(new char[] { - '\n', '\t', '\r', ' ', '+', '/', - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', - '=', - 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', - 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', - 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', - 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', - 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', - 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', - 'w', 'x', 'y', 'z', - }); - - private static IndexOfAnyValues BytesToIgnore => IndexOfAnyValues.Create(new byte[] { 9, 10, 13, 32 }); - - private static IndexOfAnyValues CharsToIgnore => IndexOfAnyValues.Create(new char[] { '\n', '\t', '\r', ' ' }); + private static readonly IndexOfAnyValues validBase64Bytes = IndexOfAnyValues.Create(EncodingMap); + + private static readonly IndexOfAnyValues validBase64Chars = IndexOfAnyValues.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"); } } From 837b4bd94fa79b40cc50d6b83e5029601c497963 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Mon, 12 Dec 2022 21:17:56 -0800 Subject: [PATCH 03/17] Address PR feedback: General optimizations --- .../src/System/Buffers/Text/Base64Decoder.cs | 31 +++++++++---------- .../System/Buffers/Text/Base64Validator.cs | 21 ++----------- 2 files changed, 17 insertions(+), 35 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 2c3406d4c9048f..d22c3647de886c 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -93,7 +93,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa // or ignored in consumed operations if not. int pendingSrcIncrement = 0; - while (src + 4 <= srcEnd) + while (src <= srcEnd - 4) { // The default increment will be 4 if no bytes that require ignoring are encountered. pendingSrcIncrement = 4; @@ -106,7 +106,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa if (result < 0) { int firstInvalidIndex = GetIndexOfFirstByteToBeIgnored(src); - if (firstInvalidIndex != -1) + if (firstInvalidIndex >= 0) { int bytesIgnored = 0; int validBytesSearchIndex = firstInvalidIndex; @@ -114,7 +114,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa for (int currentBlockIndex = firstInvalidIndex; currentBlockIndex < 4; currentBlockIndex++) { - while (src + validBytesSearchIndex < srcEnd + while (src < srcEnd - validBytesSearchIndex && IsByteToBeIgnored(src[validBytesSearchIndex])) { validBytesSearchIndex++; @@ -122,7 +122,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa totalBytesIgnored++; } - if (src + validBytesSearchIndex >= srcEnd) + if (src >= srcEnd - validBytesSearchIndex) { insufficientValidBytesFound = true; break; @@ -175,7 +175,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa && IsBlockEndBytesPadding(b2, b3)) { int indexOfBytesAfterPadding = pendingSrcIncrement; - while (src + indexOfBytesAfterPadding + 1 <= srcEnd) + while (src <= srcEnd - indexOfBytesAfterPadding - 1) { if (!IsByteToBeIgnored(src[indexOfBytesAfterPadding++])) { @@ -233,7 +233,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa } } - if (dest + 3 > destEnd) + if (dest > destEnd - 3) { goto DestinationTooSmallExit; } @@ -255,7 +255,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa } int indexOfBytesNotConsumed = pendingSrcIncrement; - while (src + indexOfBytesNotConsumed + 1 <= srcEnd) + while (src <= srcEnd - indexOfBytesNotConsumed - 1) { if (!IsByteToBeIgnored(src[indexOfBytesNotConsumed++])) { @@ -349,7 +349,7 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou if (result < 0) { int firstInvalidIndex = GetIndexOfFirstByteToBeIgnored(bufferBytes + sourceIndex); - if (firstInvalidIndex != -1) + if (firstInvalidIndex >= 0) { int bytesIgnored = 0; uint validBytesSearchIndex = (uint)firstInvalidIndex + sourceIndex; @@ -861,18 +861,15 @@ private static bool IsBlockEndBytesPadding(byte secondToLastByte, byte lastByte) || secondToLastByte == EncodingPad && lastByte == EncodingPad; [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsByteToBeIgnored(byte charByte) + private static bool IsByteToBeIgnored(int charByte) { - switch (charByte) + if (charByte < 32) { - case 9: // Line feed - case 10: // Horizontal tab - case 13: // Carriage return - case 32: // Space - return true; - default: - return false; + const int BitMask = (1 << 9) | (1 << 10) | (1 << 13); + return ((1 << charByte) & BitMask) != 0; } + + return charByte == 32; } // Pre-computing this table using a custom string(s_characters) and GenerateDecodingMapAndVerify (found in tests) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs index ca223798851576..fac6df6a870dbe 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -38,7 +38,7 @@ public static unsafe bool IsValid(ReadOnlySpan base64Text, out int decoded while (indexOfPaddingInvalidOrWhitespace >= 0) { char charToValidate = base64Text[indexOfPaddingInvalidOrWhitespace]; - if (IsCharToBeIgnored(charToValidate)) + if (IsByteToBeIgnored(charToValidate)) { // Chars to be ignored (e,g, whitespace...) should not count towards the length. length--; @@ -97,7 +97,7 @@ public static unsafe bool IsValid(ReadOnlySpan base64Text, out int decoded } // Remove padding to get exact length - decodedLength = (length / 4 * 3) - paddingCount; + decodedLength = (int)((uint)length / 4 * 3) - paddingCount; return true; } @@ -190,25 +190,10 @@ public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLen } // Remove padding to get exact length - decodedLength = (length / 4 * 3) - paddingCount; + decodedLength = (int)((uint)length / 4 * 3) - paddingCount; return true; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsCharToBeIgnored(char aChar) - { - switch (aChar) - { - case '\n': // Line feed - case '\t': // Horizontal tab - case '\r': // Carriage return - case ' ': // Space - return true; - default: - return false; - } - } - private static readonly IndexOfAnyValues validBase64Bytes = IndexOfAnyValues.Create(EncodingMap); private static readonly IndexOfAnyValues validBase64Chars = IndexOfAnyValues.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"); From 9b3b5818a4d024c1bfd88b7f8fa0a5d15372fb93 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Tue, 13 Dec 2022 23:51:28 -0800 Subject: [PATCH 04/17] Address PR feedback: Use vectorized decoding while enough src --- .../src/System/Buffers/Text/Base64Decoder.cs | 479 ++++++++++-------- 1 file changed, 277 insertions(+), 202 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index d22c3647de886c..5cf61f219e40f0 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -63,225 +63,109 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa byte* destEnd = dest + (uint)destLength; byte* srcEnd = srcBytes + (uint)srcLength; byte* srcMax = srcBytes + (uint)maxSrcLength; - - int totalBytesIgnored = 0; - - if (maxSrcLength >= 24) - { - byte* end = srcMax - 45; - if (Avx2.IsSupported && (end >= src)) - { - Avx2Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); - - if (src == srcEnd) - goto DoneExit; - } - - end = srcMax - 24; - if ((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian && (end >= src)) - { - Vector128Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); - - if (src == srcEnd) - goto DoneExit; - } - } + byte* end = srcMax - 45; ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); + OperationStatus lastBlockStatus = OperationStatus.Done; + int totalBytesIgnored = 0; // The next src increment is stored as it will be used if the dest has enough space // or ignored in consumed operations if not. - int pendingSrcIncrement = 0; + int pendingSrcIncrement; - while (src <= srcEnd - 4) + if (Avx2.IsSupported) { - // The default increment will be 4 if no bytes that require ignoring are encountered. - pendingSrcIncrement = 4; - byte b0 = src[0]; - byte b1 = src[1]; - byte b2 = src[2]; - byte b3 = src[3]; - - int result = Decode(b0, b1, b2, b3, ref decodingMap); - if (result < 0) + while (end >= src) { - int firstInvalidIndex = GetIndexOfFirstByteToBeIgnored(src); - if (firstInvalidIndex >= 0) - { - int bytesIgnored = 0; - int validBytesSearchIndex = firstInvalidIndex; - bool insufficientValidBytesFound = false; - - for (int currentBlockIndex = firstInvalidIndex; currentBlockIndex < 4; currentBlockIndex++) - { - while (src < srcEnd - validBytesSearchIndex - && IsByteToBeIgnored(src[validBytesSearchIndex])) - { - validBytesSearchIndex++; - bytesIgnored++; - totalBytesIgnored++; - } - - if (src >= srcEnd - validBytesSearchIndex) - { - insufficientValidBytesFound = true; - break; - } - - if (currentBlockIndex == 0) - { - b0 = src[validBytesSearchIndex]; - } - else if (currentBlockIndex == 1) - { - b1 = src[validBytesSearchIndex]; - } - else if (currentBlockIndex == 2) - { - b2 = src[validBytesSearchIndex]; - } - else - { - b3 = src[validBytesSearchIndex]; - } - - validBytesSearchIndex++; - } - - if (insufficientValidBytesFound) - { - break; - } - - result = Decode(b0, b1, b2, b3, ref decodingMap); - if (result < 0 - && !IsBlockEndBytesPadding(b2, b3)) - { - goto InvalidDataExit; - } - - pendingSrcIncrement = validBytesSearchIndex; - } - else + bool isComplete = Avx2Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + if (isComplete) { - if (!IsBlockEndBytesPadding(b2, b3)) - { - goto InvalidDataExit; - } + break; } - // Check to see if parsing failed due to padding. There could be 1 or 2 padding chars. - if (result < 0 - && IsBlockEndBytesPadding(b2, b3)) + // Process 4 bytes, until first set of invalid bytes are skipped. + lastBlockStatus = IgnoreWhitespaceAndTryConsumeValidBytes(ref src, srcEnd, ref dest, destEnd, ref decodingMap, ref totalBytesIgnored, isFinalBlock, lastBlockStatus, true, out pendingSrcIncrement); + if (lastBlockStatus != OperationStatus.Done) { - int indexOfBytesAfterPadding = pendingSrcIncrement; - while (src <= srcEnd - indexOfBytesAfterPadding - 1) - { - if (!IsByteToBeIgnored(src[indexOfBytesAfterPadding++])) - { - // Only bytes to be ignored can be after padding bytes. - goto InvalidDataExit; - } - } - - // If isFinalBlock is false, padding is treaded as invalid. - if (!isFinalBlock) - { - goto InvalidDataExit; - } - - int i0 = Unsafe.Add(ref decodingMap, (IntPtr)b0); - int i1 = Unsafe.Add(ref decodingMap, (IntPtr)b1); - - i0 <<= 18; - i1 <<= 12; - - i0 |= i1; - - if (b2 != EncodingPad) - { - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)b2); - - i2 <<= 6; - - i0 |= i2; - - if (i0 < 0) - goto InvalidDataExit; - if (dest + 2 > destEnd) - goto DestinationTooSmallExit; - - dest[0] = (byte)(i0 >> 16); - dest[1] = (byte)(i0 >> 8); - dest += 2; - } - else - { - if (i0 < 0) - goto InvalidDataExit; - if (dest + 1 > destEnd) - goto DestinationTooSmallExit; - - dest[0] = (byte)(i0 >> 16); - dest += 1; - } - - src += pendingSrcIncrement; - pendingSrcIncrement = 0; - break; } } + } + else if ((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) + && BitConverter.IsLittleEndian) + { + end = srcMax - 24; - if (dest > destEnd - 3) + while (end >= src) { - goto DestinationTooSmallExit; - } - - WriteThreeLowOrderBytes(dest, result); + bool isComplete = Vector128Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); + if (isComplete) + { + break; + } - src += pendingSrcIncrement; - pendingSrcIncrement = 0; - dest += 3; + // Process 4 bytes, until first set of invalid bytes are skipped. + lastBlockStatus = IgnoreWhitespaceAndTryConsumeValidBytes(ref src, srcEnd, ref dest, destEnd, ref decodingMap, ref totalBytesIgnored, isFinalBlock, lastBlockStatus, true, out pendingSrcIncrement); + if (lastBlockStatus != OperationStatus.Done) + { + break; + } + } } - if (!isFinalBlock) + lastBlockStatus = IgnoreWhitespaceAndTryConsumeValidBytes(ref src, srcEnd, ref dest, destEnd, ref decodingMap, ref totalBytesIgnored, isFinalBlock, lastBlockStatus, false, out pendingSrcIncrement); + + // Check if + if (lastBlockStatus == OperationStatus.Done + || lastBlockStatus == OperationStatus.NeedMoreData) { - int remainingBytes = (int)(srcEnd - src); - if (remainingBytes > 0 && remainingBytes < 4) + if (!isFinalBlock) + { + int remainingBytes = (int)(srcEnd - src); + if (remainingBytes > 0 && remainingBytes < 4) + { + lastBlockStatus = OperationStatus.NeedMoreData; + } + } + else { - goto NeedMoreDataExit; + int indexOfBytesNotConsumed = pendingSrcIncrement; + while (src <= srcEnd - indexOfBytesNotConsumed - 1) + { + if (!IsByteToBeIgnored(src[indexOfBytesNotConsumed++])) + { + lastBlockStatus = OperationStatus.InvalidData; + break; + } + } } } - int indexOfBytesNotConsumed = pendingSrcIncrement; - while (src <= srcEnd - indexOfBytesNotConsumed - 1) + switch (lastBlockStatus) { - if (!IsByteToBeIgnored(src[indexOfBytesNotConsumed++])) - { - goto InvalidDataExit; - } + case OperationStatus.Done: + bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; + bytesWritten = (int)(dest - destBytes); + break; + case OperationStatus.DestinationTooSmall: + bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; + bytesWritten = (int)(dest - destBytes); + break; + case OperationStatus.NeedMoreData: + bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; + bytesWritten = (int)(dest - destBytes); + break; + case OperationStatus.InvalidData: + bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; + bytesWritten = (int)(dest - destBytes); + break; + default: + bytesConsumed = 0; + bytesWritten = 0; + break; } - DoneExit: - bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; - bytesWritten = (int)(dest - destBytes); - return OperationStatus.Done; - - DestinationTooSmallExit: - bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; - bytesWritten = (int)(dest - destBytes); - return OperationStatus.DestinationTooSmall; - - NeedMoreDataExit: - bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; - bytesWritten = (int)(dest - destBytes); - return OperationStatus.NeedMoreData; - - InvalidDataExit: - bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; - bytesWritten = (int)(dest - destBytes); - return OperationStatus.InvalidData; + return lastBlockStatus; } } @@ -329,8 +213,6 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou uint sourceIndex = 0; uint destIndex = 0; - int totalBytesIgnored = 0; - if (bufferLength == 0) goto DoneExit; @@ -351,7 +233,6 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou int firstInvalidIndex = GetIndexOfFirstByteToBeIgnored(bufferBytes + sourceIndex); if (firstInvalidIndex >= 0) { - int bytesIgnored = 0; uint validBytesSearchIndex = (uint)firstInvalidIndex + sourceIndex; bool insufficientValidBytesFound = false; @@ -361,8 +242,6 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou && IsByteToBeIgnored(bufferBytes[validBytesSearchIndex])) { validBytesSearchIndex++; - bytesIgnored++; - totalBytesIgnored++; } if (validBytesSearchIndex > bufferLength - 1) @@ -459,8 +338,6 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou destIndex += 1; } - sourceIndex = nextSoureIndex; - goto DoneExit; } } @@ -490,7 +367,189 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe OperationStatus IgnoreWhitespaceAndTryConsumeValidBytes(ref byte* src, byte* srcEnd, ref byte* dest, byte* destEnd, ref sbyte decodingMap, ref int totalBytesIgnored, bool isFinalBlock, OperationStatus lastBlockStatus, bool exitAfterFirstSkippedBytes, out int pendingSrcIncrement) + { + pendingSrcIncrement = 0; + + while (src <= srcEnd - 4) + { + byte* srcBeforeProcessing = src; + lastBlockStatus = IgnoreWhitespaceAndTryConsumeNextValidBytesBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, ref totalBytesIgnored, isFinalBlock, out pendingSrcIncrement); + + if (lastBlockStatus != OperationStatus.Done + // The source was not increased because there were not enough valid bytes. + || srcBeforeProcessing == src + // Exit after consuming more than 4 bytes due to skipping whitespace. + || (exitAfterFirstSkippedBytes && src - srcBeforeProcessing > 4)) + { + break; + } + } + + return lastBlockStatus; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe OperationStatus IgnoreWhitespaceAndTryConsumeNextValidBytesBlock(ref byte* src, byte* srcEnd, ref byte* dest, byte* destEnd, ref sbyte decodingMap, ref int totalBytesIgnored, bool isFinalBlock, out int pendingSrcIncrement) + { + // The default increment will be 4 if no bytes that require ignoring are encountered. + pendingSrcIncrement = 4; + byte b0 = src[0]; + byte b1 = src[1]; + byte b2 = src[2]; + byte b3 = src[3]; + + int result = Decode(b0, b1, b2, b3, ref decodingMap); + if (result < 0) + { + int firstInvalidIndex = GetIndexOfFirstByteToBeIgnored(src); + if (firstInvalidIndex >= 0) + { + int validBytesSearchIndex = firstInvalidIndex; + bool insufficientValidBytesFound = false; + + for (int currentBlockIndex = firstInvalidIndex; currentBlockIndex < 4; currentBlockIndex++) + { + while (src < srcEnd - validBytesSearchIndex + && IsByteToBeIgnored(src[validBytesSearchIndex])) + { + validBytesSearchIndex++; + totalBytesIgnored++; + } + + if (src >= srcEnd - validBytesSearchIndex) + { + insufficientValidBytesFound = true; + break; + } + + if (currentBlockIndex == 0) + { + b0 = src[validBytesSearchIndex]; + } + else if (currentBlockIndex == 1) + { + b1 = src[validBytesSearchIndex]; + } + else if (currentBlockIndex == 2) + { + b2 = src[validBytesSearchIndex]; + } + else + { + b3 = src[validBytesSearchIndex]; + } + + validBytesSearchIndex++; + } + + if (insufficientValidBytesFound) + { + return OperationStatus.Done; + } + + result = Decode(b0, b1, b2, b3, ref decodingMap); + if (result < 0 + && !IsBlockEndBytesPadding(b2, b3)) + { + return OperationStatus.InvalidData; + } + + pendingSrcIncrement = validBytesSearchIndex; + } + else + { + if (!IsBlockEndBytesPadding(b2, b3)) + { + return OperationStatus.InvalidData; + } + } + + // Check to see if parsing failed due to padding. There could be 1 or 2 padding chars. + if (result < 0 + && IsBlockEndBytesPadding(b2, b3)) + { + int indexOfBytesAfterPadding = pendingSrcIncrement; + while (src <= srcEnd - indexOfBytesAfterPadding - 1) + { + if (!IsByteToBeIgnored(src[indexOfBytesAfterPadding++])) + { + // Only bytes to be ignored can be after padding bytes. + return OperationStatus.InvalidData; + } + } + + // If isFinalBlock is false, padding is treaded as invalid. + if (!isFinalBlock) + { + return OperationStatus.InvalidData; + } + + return TryProcessFinalBlockWithPadding(ref src, ref dest, destEnd, ref decodingMap, ref pendingSrcIncrement, b0, b1, b2); + } + } + + if (dest > destEnd - 3) + { + return OperationStatus.DestinationTooSmall; + } + + WriteThreeLowOrderBytes(dest, result); + + src += pendingSrcIncrement; + pendingSrcIncrement = 0; + dest += 3; + + return OperationStatus.Done; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe OperationStatus TryProcessFinalBlockWithPadding(ref byte* src, ref byte* dest, byte* destEnd, ref sbyte decodingMap, ref int pendingSrcIncrement, byte b0, byte b1, byte b2) + { + int i0 = Unsafe.Add(ref decodingMap, (IntPtr)b0); + int i1 = Unsafe.Add(ref decodingMap, (IntPtr)b1); + + i0 <<= 18; + i1 <<= 12; + + i0 |= i1; + + if (b2 != EncodingPad) + { + int i2 = Unsafe.Add(ref decodingMap, (IntPtr)b2); + + i2 <<= 6; + + i0 |= i2; + + if (i0 < 0) + return OperationStatus.InvalidData; + if (dest + 2 > destEnd) + return OperationStatus.DestinationTooSmall; + + dest[0] = (byte)(i0 >> 16); + dest[1] = (byte)(i0 >> 8); + dest += 2; + } + else + { + if (i0 < 0) + return OperationStatus.InvalidData; + if (dest + 1 > destEnd) + return OperationStatus.DestinationTooSmall; + + dest[0] = (byte)(i0 >> 16); + dest += 1; + } + + src += pendingSrcIncrement; + pendingSrcIncrement = 0; + + return OperationStatus.Done; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe bool Avx2Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) { // If we have AVX2 support, pick off 32 bytes at a time for as long as we can, // but make sure that we quit before seeing any == markers at the end of the @@ -570,7 +629,13 @@ private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b Vector256 lo = Avx2.Shuffle(lutLo, loNibbles); if (!Avx.TestZ(lo, hi)) - break; + { + // Record current progress + srcBytes = src; + destBytes = dest; + + return false; + } Vector256 eq2F = Avx2.CompareEqual(str, mask2F); Vector256 shift = Avx2.Shuffle(lutShift, Avx2.Add(eq2F, hiNibbles)); @@ -614,6 +679,8 @@ private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b srcBytes = src; destBytes = dest; + + return true; } // This can be replaced once https://github.com/dotnet/runtime/issues/63331 is implemented. @@ -633,7 +700,7 @@ private static Vector128 SimdShuffle(Vector128 left, Vector128 } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe bool Vector128Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) { Debug.Assert((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian); @@ -738,7 +805,13 @@ private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destByt // Check for invalid input: if any "and" values from lo and hi are not zero, // fall back on bytewise code to do error checking and reporting: if ((lo & hi) != Vector128.Zero) - break; + { + // Record current progress + srcBytes = src; + destBytes = dest; + + return false; + } Vector128 eq2F = Vector128.Equals(str, mask2F); Vector128 shift = SimdShuffle(lutShift.AsByte(), (eq2F + hiNibbles), mask8F); @@ -801,6 +874,8 @@ private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destByt srcBytes = src; destBytes = dest; + + return true; } [MethodImpl(MethodImplOptions.AggressiveInlining)] From 9983889556913d600c85356d63af54bd4f3e2853 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Wed, 14 Dec 2022 20:26:06 -0800 Subject: [PATCH 05/17] Address PR feedback: General optimization --- .../src/System/Buffers/Text/Base64Decoder.cs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 5cf61f219e40f0..7102d5dc6b5640 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -115,20 +115,21 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa lastBlockStatus = IgnoreWhitespaceAndTryConsumeValidBytes(ref src, srcEnd, ref dest, destEnd, ref decodingMap, ref totalBytesIgnored, isFinalBlock, lastBlockStatus, false, out pendingSrcIncrement); - // Check if - if (lastBlockStatus == OperationStatus.Done - || lastBlockStatus == OperationStatus.NeedMoreData) + // Assess the end block and bytes beyond it. + if (lastBlockStatus == OperationStatus.Done) { if (!isFinalBlock) { int remainingBytes = (int)(srcEnd - src); - if (remainingBytes > 0 && remainingBytes < 4) + if (remainingBytes is > 0 and < 4) { + // An incomplete block of bytes was found. lastBlockStatus = OperationStatus.NeedMoreData; } } else { + // Check if there are bytes that should not be ignored beyond the expected end of the valid input range. int indexOfBytesNotConsumed = pendingSrcIncrement; while (src <= srcEnd - indexOfBytesNotConsumed - 1) { From f389289d0849a528e480efba3a96faa14f16c362 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Mon, 19 Dec 2022 23:13:23 -0800 Subject: [PATCH 06/17] Address PR feedback: Optimize for whitespace (\r\n) every 76 bytes --- .../tests/Base64/Base64DecoderUnitTests.cs | 7 +- .../tests/Base64/Base64TestBase.cs | 51 ++++----- .../tests/Base64/Base64ValidationUnitTests.cs | 6 +- .../src/System/Buffers/Text/Base64Decoder.cs | 106 +++++++++++++++--- 4 files changed, 122 insertions(+), 48 deletions(-) diff --git a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs index 883941e549f803..16c48a52dda12a 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs @@ -661,7 +661,7 @@ public void DecodeInPlaceInvalidBytesPadding() [Theory] [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] - public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes) + public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes, int expectedBytesConsumed) { byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); byte[] resultBytes = new byte[5]; @@ -671,7 +671,7 @@ public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored); Assert.Equal(OperationStatus.Done, result); - Assert.Equal(8, bytesConsumed); + Assert.Equal(expectedBytesConsumed, bytesConsumed); Assert.Equal(expectedBytes.Length, bytesWritten); Assert.True(expectedBytes.SequenceEqual(resultBytes)); Assert.True(stringBytes.SequenceEqual(resultBytes)); @@ -679,8 +679,9 @@ public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf [Theory] [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] - public void DecodeInPlaceIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes) + public void DecodeInPlaceIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes, int expectedBytesConsumed) { + _ = expectedBytesConsumed; Span utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); OperationStatus result = Base64.DecodeFromUtf8InPlace(utf8BytesWithByteToBeIgnored, out int bytesWritten); Span bytesOverwritten = utf8BytesWithByteToBeIgnored.Slice(0, bytesWritten); diff --git a/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs index 21f116da83464a..3b1e06e13c2b23 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs @@ -1,5 +1,5 @@ // Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. +// The .NET Foundation licenses this file to you under the MIT license.utf8Bytes, utf8Bytes.Length using System.Collections.Generic; using System.Text; @@ -24,20 +24,20 @@ public static IEnumerable ValidBase64Strings_WithCharsThatMustBeIgnore // One will have 1 char, another will have 3 // Line feed - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 1), utf8Bytes, utf8Bytes.Length + 4 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 3), utf8Bytes, utf8Bytes.Length + 6 }; // Horizontal tab - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 1), utf8Bytes, utf8Bytes.Length + 4 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 3), utf8Bytes, utf8Bytes.Length + 6 }; // Carriage return - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 1), utf8Bytes, utf8Bytes.Length + 4 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 3), utf8Bytes, utf8Bytes.Length + 6 }; // Space - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 1), utf8Bytes, utf8Bytes.Length + 4 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 3), utf8Bytes, utf8Bytes.Length + 6 }; string GetBase64StringWithPassedCharInsertedInTheMiddle(char charToInsert, int numberOfTimesToInsert) => $"{firstSegment}{new string(charToInsert, numberOfTimesToInsert)}{secondSegment}"; @@ -45,41 +45,42 @@ public static IEnumerable ValidBase64Strings_WithCharsThatMustBeIgnore // One will have 1 char, another will have 3 // Line feed - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 1), utf8Bytes, utf8Bytes.Length + 4 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 3), utf8Bytes, utf8Bytes.Length + 6 }; // Horizontal tab - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 1), utf8Bytes, utf8Bytes.Length + 4 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 3), utf8Bytes, utf8Bytes.Length + 6 }; // Carriage return - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 1), utf8Bytes, utf8Bytes.Length + 4 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 3), utf8Bytes, utf8Bytes.Length + 6 }; // Space - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 1), utf8Bytes, utf8Bytes.Length + 4 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 3), utf8Bytes, utf8Bytes.Length + 6 }; string GetBase64StringWithPassedCharInsertedAtTheStart(char charToInsert, int numberOfTimesToInsert) => $"{new string(charToInsert, numberOfTimesToInsert)}{firstSegment}{secondSegment}"; // Insert ignored chars at the end of the base 64 string // One will have 1 char, another will have 3 + // Whitespace after end/padding is not included in consumed bytes // Line feed - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 1), utf8Bytes, utf8Bytes.Length + 3 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 3), utf8Bytes, utf8Bytes.Length + 3 }; // Horizontal tab - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 1), utf8Bytes, utf8Bytes.Length + 3 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 3), utf8Bytes, utf8Bytes.Length + 3 }; // Carriage return - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 1), utf8Bytes, utf8Bytes.Length + 3 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 3), utf8Bytes, utf8Bytes.Length + 3 }; // Space - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 1), utf8Bytes }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 3), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 1), utf8Bytes, utf8Bytes.Length + 3 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 3), utf8Bytes, utf8Bytes.Length + 3 }; string GetBase64StringWithPassedCharInsertedAtTheEnd(char charToInsert, int numberOfTimesToInsert) => $"{firstSegment}{secondSegment}{new string(charToInsert, numberOfTimesToInsert)}"; } diff --git a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs index abccac736bb217..f8ffba91639ac1 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs @@ -147,8 +147,9 @@ public void ValidateGuidChars() [Theory] [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] - public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored, byte[] expectedBytes) + public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored, byte[] expectedBytes, int expectedBytesConsumed) { + _ = expectedBytesConsumed; byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); @@ -158,8 +159,9 @@ public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgn [Theory] [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] - public void ValidateBytesIgnoresCharsToBeIgnoredChars(string utf8WithByteToBeIgnored, byte[] expectedBytes) + public void ValidateBytesIgnoresCharsToBeIgnoredChars(string utf8WithByteToBeIgnored, byte[] expectedBytes, int expectedBytesConsumed) { + _ = expectedBytesConsumed; ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 7102d5dc6b5640..9d4894beeacba3 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -46,7 +46,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8)) fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes)) { - int srcLength = utf8.Length; // only decode input up to the closest multiple of 4. + int srcLength = utf8.Length; int destLength = bytes.Length; int maxSrcLength = srcLength; int decodedLength = GetMaxDecodedFromUtf8Length(srcLength); @@ -67,10 +67,6 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); OperationStatus lastBlockStatus = OperationStatus.Done; - int totalBytesIgnored = 0; - - // The next src increment is stored as it will be used if the dest has enough space - // or ignored in consumed operations if not. int pendingSrcIncrement; if (Avx2.IsSupported) @@ -83,8 +79,14 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa break; } + isComplete = TryDecodeCurrentGroupIfWhitespaceIsSeparatingValidBytesInCommonLocation(utf8, ref src, ref dest, ref end, destLength, srcBytes, destBytes, 32); + if (isComplete) + { + continue; + } + // Process 4 bytes, until first set of invalid bytes are skipped. - lastBlockStatus = IgnoreWhitespaceAndTryConsumeValidBytes(ref src, srcEnd, ref dest, destEnd, ref decodingMap, ref totalBytesIgnored, isFinalBlock, lastBlockStatus, true, out pendingSrcIncrement); + lastBlockStatus = IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, isFinalBlock, lastBlockStatus, true, out pendingSrcIncrement); if (lastBlockStatus != OperationStatus.Done) { break; @@ -104,8 +106,14 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa break; } + isComplete = TryDecodeCurrentGroupIfWhitespaceIsSeparatingValidBytesInCommonLocation(utf8, ref src, ref dest, ref end, destLength, srcBytes, destBytes, 16); + if (isComplete) + { + continue; + } + // Process 4 bytes, until first set of invalid bytes are skipped. - lastBlockStatus = IgnoreWhitespaceAndTryConsumeValidBytes(ref src, srcEnd, ref dest, destEnd, ref decodingMap, ref totalBytesIgnored, isFinalBlock, lastBlockStatus, true, out pendingSrcIncrement); + lastBlockStatus = IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, isFinalBlock, lastBlockStatus, true, out pendingSrcIncrement); if (lastBlockStatus != OperationStatus.Done) { break; @@ -113,7 +121,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa } } - lastBlockStatus = IgnoreWhitespaceAndTryConsumeValidBytes(ref src, srcEnd, ref dest, destEnd, ref decodingMap, ref totalBytesIgnored, isFinalBlock, lastBlockStatus, false, out pendingSrcIncrement); + lastBlockStatus = IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, isFinalBlock, lastBlockStatus, false, out pendingSrcIncrement); // Assess the end block and bytes beyond it. if (lastBlockStatus == OperationStatus.Done) @@ -145,19 +153,19 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa switch (lastBlockStatus) { case OperationStatus.Done: - bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; + bytesConsumed = (int)(src - srcBytes); bytesWritten = (int)(dest - destBytes); break; case OperationStatus.DestinationTooSmall: - bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; + bytesConsumed = (int)(src - srcBytes); bytesWritten = (int)(dest - destBytes); break; case OperationStatus.NeedMoreData: - bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; + bytesConsumed = (int)(src - srcBytes); bytesWritten = (int)(dest - destBytes); break; case OperationStatus.InvalidData: - bytesConsumed = ((int)(src - srcBytes)) - totalBytesIgnored; + bytesConsumed = (int)(src - srcBytes); bytesWritten = (int)(dest - destBytes); break; default: @@ -368,14 +376,14 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus IgnoreWhitespaceAndTryConsumeValidBytes(ref byte* src, byte* srcEnd, ref byte* dest, byte* destEnd, ref sbyte decodingMap, ref int totalBytesIgnored, bool isFinalBlock, OperationStatus lastBlockStatus, bool exitAfterFirstSkippedBytes, out int pendingSrcIncrement) + private static unsafe OperationStatus IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref byte* src, byte* srcEnd, ref byte* dest, byte* destEnd, ref sbyte decodingMap, bool isFinalBlock, OperationStatus lastBlockStatus, bool exitAfterFirstSkippedBytes, out int pendingSrcIncrement) { pendingSrcIncrement = 0; while (src <= srcEnd - 4) { byte* srcBeforeProcessing = src; - lastBlockStatus = IgnoreWhitespaceAndTryConsumeNextValidBytesBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, ref totalBytesIgnored, isFinalBlock, out pendingSrcIncrement); + lastBlockStatus = IgnoreWhitespaceAndTryConsumeNextValidBytesBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, isFinalBlock, out pendingSrcIncrement); if (lastBlockStatus != OperationStatus.Done // The source was not increased because there were not enough valid bytes. @@ -391,7 +399,7 @@ private static unsafe OperationStatus IgnoreWhitespaceAndTryConsumeValidBytes(re } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus IgnoreWhitespaceAndTryConsumeNextValidBytesBlock(ref byte* src, byte* srcEnd, ref byte* dest, byte* destEnd, ref sbyte decodingMap, ref int totalBytesIgnored, bool isFinalBlock, out int pendingSrcIncrement) + private static unsafe OperationStatus IgnoreWhitespaceAndTryConsumeNextValidBytesBlock(ref byte* src, byte* srcEnd, ref byte* dest, byte* destEnd, ref sbyte decodingMap, bool isFinalBlock, out int pendingSrcIncrement) { // The default increment will be 4 if no bytes that require ignoring are encountered. pendingSrcIncrement = 4; @@ -415,7 +423,6 @@ private static unsafe OperationStatus IgnoreWhitespaceAndTryConsumeNextValidByte && IsByteToBeIgnored(src[validBytesSearchIndex])) { validBytesSearchIndex++; - totalBytesIgnored++; } if (src >= srcEnd - validBytesSearchIndex) @@ -486,7 +493,7 @@ private static unsafe OperationStatus IgnoreWhitespaceAndTryConsumeNextValidByte return OperationStatus.InvalidData; } - return TryProcessFinalBlockWithPadding(ref src, ref dest, destEnd, ref decodingMap, ref pendingSrcIncrement, b0, b1, b2); + return TryWriteFinalUnevenBlock(ref src, ref dest, destEnd, ref decodingMap, ref pendingSrcIncrement, b0, b1, b2); } } @@ -505,7 +512,7 @@ private static unsafe OperationStatus IgnoreWhitespaceAndTryConsumeNextValidByte } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus TryProcessFinalBlockWithPadding(ref byte* src, ref byte* dest, byte* destEnd, ref sbyte decodingMap, ref int pendingSrcIncrement, byte b0, byte b1, byte b2) + private static unsafe OperationStatus TryWriteFinalUnevenBlock(ref byte* src, ref byte* dest, byte* destEnd, ref sbyte decodingMap, ref int pendingSrcIncrement, byte b0, byte b1, byte b2) { int i0 = Unsafe.Add(ref decodingMap, (IntPtr)b0); int i1 = Unsafe.Add(ref decodingMap, (IntPtr)b1); @@ -549,6 +556,69 @@ private static unsafe OperationStatus TryProcessFinalBlockWithPadding(ref byte* return OperationStatus.Done; } + private static unsafe bool TryDecodeCurrentGroupIfWhitespaceIsSeparatingValidBytesInCommonLocation(ReadOnlySpan utf8, ref byte* src, ref byte* dest, ref byte* end, int destLength, byte* srcBytes, byte* destBytes, byte groupSize) + { + int consumed = (int)(src - srcBytes); + int numberOfWhiteSpaceBytesSkipped = (consumed / 76) * 2; + + if (end < src + numberOfWhiteSpaceBytesSkipped) + { + // Max potential index is out of range. + return false; + } + + int srcIndexInCurrent76CharsSequence = consumed % 76; + int potentialIndeOfFirstWhitespace = (76 - srcIndexInCurrent76CharsSequence + numberOfWhiteSpaceBytesSkipped) % 78; + if (potentialIndeOfFirstWhitespace > groupSize) + { + // The potential whitespace index should have existed in the group that just failed vectorized decoding. + return false; + } + + byte firstWhiteSpace = src[potentialIndeOfFirstWhitespace]; + byte secondWhiteSpace = src[potentialIndeOfFirstWhitespace + 1]; + if (firstWhiteSpace != 13 + || secondWhiteSpace != 10) + { + return false; + } + + // The two slice and copy below avoids a loop that skips the whitespace indexes to fill the new span. + Span groupOfBlocksWithoutWhitespace = stackalloc byte[groupSize]; + // Slice and copy next 34 bytes + utf8.Slice(consumed + 2, groupSize).CopyTo(groupOfBlocksWithoutWhitespace); + // Overwrite whitespace and valid bytes before + utf8.Slice(consumed, potentialIndeOfFirstWhitespace).CopyTo(groupOfBlocksWithoutWhitespace); + + fixed (byte* groupOfBlocksWithoutWhitespaceSrcBytes = &MemoryMarshal.GetReference(utf8)) + { + byte* groupOfBlocksWithoutWhitespaceSrcBytesOriginal = groupOfBlocksWithoutWhitespaceSrcBytes; + byte* groupOfBlocksWithoutWhitespaceSrc = groupOfBlocksWithoutWhitespaceSrcBytes; + + // Note: Set srcEnd = start, so the while loop will fail during decoding, so that only one pass is possible. + bool isComplete; + if (groupSize == 32) + { + isComplete = Avx2Decode(ref groupOfBlocksWithoutWhitespaceSrc, ref dest, groupOfBlocksWithoutWhitespaceSrcBytes, groupOfBlocksWithoutWhitespace.Length, destLength, groupOfBlocksWithoutWhitespaceSrcBytes, destBytes); + } + else + { + isComplete = Vector128Decode(ref groupOfBlocksWithoutWhitespaceSrc, ref dest, groupOfBlocksWithoutWhitespaceSrcBytes, groupOfBlocksWithoutWhitespace.Length, destLength, groupOfBlocksWithoutWhitespaceSrcBytes, destBytes); + } + + if (groupOfBlocksWithoutWhitespaceSrc == groupOfBlocksWithoutWhitespaceSrcBytesOriginal) + { + // Decoding failed because there was more whitespace or invalid bytes. + return false; + } + + // + 2 for the 2 whitespace chars + src += groupSize + 2; + + return true; + } + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe bool Avx2Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) { From 8482bbd051beb7f7ac64646c910959013c72acf1 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Thu, 22 Dec 2022 21:28:14 -0800 Subject: [PATCH 07/17] Address PR feedback: Implement driver/worker pattern with original code --- .../src/System/Buffers/Text/Base64Decoder.cs | 971 ++++++++++-------- .../System/Buffers/Text/Base64Validator.cs | 6 +- 2 files changed, 557 insertions(+), 420 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 9d4894beeacba3..88ace1273a62ca 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -34,152 +34,56 @@ public static partial class Base64 /// - InvalidData - if the input contains bytes outside of the expected base64 range, or if it contains invalid/more than two padding characters, /// or if the input is incomplete (i.e. not a multiple of 4) and is . /// - public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) + public static OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) { - if (utf8.IsEmpty) - { - bytesConsumed = 0; - bytesWritten = 0; - return OperationStatus.Done; - } + OperationStatus status = OperationStatus.Done; + bytesConsumed = 0; + bytesWritten = 0; - fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8)) - fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes)) + while (utf8.Length > 0) { - int srcLength = utf8.Length; - int destLength = bytes.Length; - int maxSrcLength = srcLength; - int decodedLength = GetMaxDecodedFromUtf8Length(srcLength); + status = DecodeFromUtf8Core(utf8, bytes, out int localConsumed, out int localWritten, isFinalBlock); + bytesConsumed += localConsumed; + bytesWritten += localWritten; - // max. 2 padding chars - if (destLength < decodedLength - 2) + if (status is OperationStatus.Done or OperationStatus.NeedMoreData) { - // For overflow see comment below - maxSrcLength = destLength / 3 * 4; + break; } - byte* src = srcBytes; - byte* dest = destBytes; - byte* destEnd = dest + (uint)destLength; - byte* srcEnd = srcBytes + (uint)srcLength; - byte* srcMax = srcBytes + (uint)maxSrcLength; - byte* end = srcMax - 45; - - ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); - OperationStatus lastBlockStatus = OperationStatus.Done; - int pendingSrcIncrement; - - if (Avx2.IsSupported) - { - while (end >= src) - { - bool isComplete = Avx2Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); - if (isComplete) - { - break; - } - - isComplete = TryDecodeCurrentGroupIfWhitespaceIsSeparatingValidBytesInCommonLocation(utf8, ref src, ref dest, ref end, destLength, srcBytes, destBytes, 32); - if (isComplete) - { - continue; - } + utf8 = utf8.Slice(localConsumed); + bytes = bytes.Slice(localWritten); - // Process 4 bytes, until first set of invalid bytes are skipped. - lastBlockStatus = IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, isFinalBlock, lastBlockStatus, true, out pendingSrcIncrement); - if (lastBlockStatus != OperationStatus.Done) - { - break; - } - } - } - else if ((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) - && BitConverter.IsLittleEndian) + if (AllBytesRemainingAreWhitespace(utf8, out localConsumed)) { - end = srcMax - 24; - - while (end >= src) + if (localConsumed > 0) { - bool isComplete = Vector128Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); - if (isComplete) - { - break; - } - - isComplete = TryDecodeCurrentGroupIfWhitespaceIsSeparatingValidBytesInCommonLocation(utf8, ref src, ref dest, ref end, destLength, srcBytes, destBytes, 16); - if (isComplete) - { - continue; - } - - // Process 4 bytes, until first set of invalid bytes are skipped. - lastBlockStatus = IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, isFinalBlock, lastBlockStatus, true, out pendingSrcIncrement); - if (lastBlockStatus != OperationStatus.Done) - { - break; - } + status = OperationStatus.Done; } - } - lastBlockStatus = IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, isFinalBlock, lastBlockStatus, false, out pendingSrcIncrement); - - // Assess the end block and bytes beyond it. - if (lastBlockStatus == OperationStatus.Done) - { - if (!isFinalBlock) - { - int remainingBytes = (int)(srcEnd - src); - if (remainingBytes is > 0 and < 4) - { - // An incomplete block of bytes was found. - lastBlockStatus = OperationStatus.NeedMoreData; - } - } - else - { - // Check if there are bytes that should not be ignored beyond the expected end of the valid input range. - int indexOfBytesNotConsumed = pendingSrcIncrement; - while (src <= srcEnd - indexOfBytesNotConsumed - 1) - { - if (!IsByteToBeIgnored(src[indexOfBytesNotConsumed++])) - { - lastBlockStatus = OperationStatus.InvalidData; - break; - } - } - } + // The end of the input has been reached. + break; } - switch (lastBlockStatus) + if (localConsumed == 0) { - case OperationStatus.Done: - bytesConsumed = (int)(src - srcBytes); - bytesWritten = (int)(dest - destBytes); - break; - case OperationStatus.DestinationTooSmall: - bytesConsumed = (int)(src - srcBytes); - bytesWritten = (int)(dest - destBytes); - break; - case OperationStatus.NeedMoreData: - bytesConsumed = (int)(src - srcBytes); - bytesWritten = (int)(dest - destBytes); - break; - case OperationStatus.InvalidData: - bytesConsumed = (int)(src - srcBytes); - bytesWritten = (int)(dest - destBytes); - break; - default: - bytesConsumed = 0; - bytesWritten = 0; - break; + // First char isn't whitespace, but we didn't consume anything, + // thus the input may have whitespace anywhere in between. So fall back to block-wise decoding. + status = DecodeWithWhitespaceBlockwise(utf8, bytes, out localConsumed, out localWritten, isFinalBlock); + bytesConsumed += localConsumed; + bytesWritten += localWritten; + break; } - return lastBlockStatus; + bytesConsumed += localConsumed; + utf8 = utf8.Slice(localConsumed); } + + return status; } /// - /// Returns the maximum length (in bytes) of the result if you were to deocde base 64 encoded text within a byte span of size "length". + /// Returns the maximum length (in bytes) of the result if you were to decode base 64 encoded text within a byte span of size "length". /// /// /// Thrown when the specified is less than 0. @@ -208,312 +112,366 @@ public static int GetMaxDecodedFromUtf8Length(int length) /// It does not return NeedMoreData since this method tramples the data in the buffer and /// hence can only be called once with all the data in the buffer. /// - public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten) + public static OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten) { - if (buffer.IsEmpty) + OperationStatus status = DecodeFromUtf8InPlaceCore(buffer, out bytesWritten, out uint sourceIndex); + + if (status is OperationStatus.InvalidData or OperationStatus.DestinationTooSmall) + { + // The input may have whitespace, attempt to decode while ignoring whitespace. + status = DecodeWithWhitespaceFromUtf8InPlace(buffer, ref bytesWritten, sourceIndex); + } + + return status; + } + + /// + /// Untouched original DecodeFromUtf8 method + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe OperationStatus DecodeFromUtf8Core(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) + { + if (utf8.IsEmpty) { + bytesConsumed = 0; bytesWritten = 0; return OperationStatus.Done; } - fixed (byte* bufferBytes = &MemoryMarshal.GetReference(buffer)) + fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8)) + fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes)) { - int bufferLength = buffer.Length; - uint sourceIndex = 0; - uint destIndex = 0; + int srcLength = utf8.Length & ~0x3; // only decode input up to the closest multiple of 4. + int destLength = bytes.Length; + int maxSrcLength = srcLength; + int decodedLength = GetMaxDecodedFromUtf8Length(srcLength); - if (bufferLength == 0) - goto DoneExit; + // max. 2 padding chars + if (destLength < decodedLength - 2) + { + // For overflow see comment below + maxSrcLength = destLength / 3 * 4; + } - ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); + byte* src = srcBytes; + byte* dest = destBytes; + byte* srcEnd = srcBytes + (uint)srcLength; + byte* srcMax = srcBytes + (uint)maxSrcLength; - while (sourceIndex <= bufferLength - 4) + if (maxSrcLength >= 24) { - // The default increment will be 4 if no bytes that require ignoring are encountered. - uint nextSoureIndex = sourceIndex + 4; - byte b0 = bufferBytes[sourceIndex]; - byte b1 = bufferBytes[sourceIndex + 1]; - byte b2 = bufferBytes[sourceIndex + 2]; - byte b3 = bufferBytes[sourceIndex + 3]; + byte* end = srcMax - 45; + if (Avx2.IsSupported && (end >= src)) + { + Avx2Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); - int result = Decode(b0, b1, b2, b3, ref decodingMap); - if (result < 0) + if (src == srcEnd) + goto DoneExit; + } + + end = srcMax - 24; + if ((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian && (end >= src)) { - int firstInvalidIndex = GetIndexOfFirstByteToBeIgnored(bufferBytes + sourceIndex); - if (firstInvalidIndex >= 0) - { - uint validBytesSearchIndex = (uint)firstInvalidIndex + sourceIndex; - bool insufficientValidBytesFound = false; + Vector128Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); - for (int currentBlockIndex = firstInvalidIndex; currentBlockIndex < 4; currentBlockIndex++) - { - while (validBytesSearchIndex <= bufferLength - 1 - && IsByteToBeIgnored(bufferBytes[validBytesSearchIndex])) - { - validBytesSearchIndex++; - } + if (src == srcEnd) + goto DoneExit; + } + } - if (validBytesSearchIndex > bufferLength - 1) - { - insufficientValidBytesFound = true; - break; - } + // Last bytes could have padding characters, so process them separately and treat them as valid only if isFinalBlock is true + // if isFinalBlock is false, padding characters are considered invalid + int skipLastChunk = isFinalBlock ? 4 : 0; - if (currentBlockIndex == 0) - { - b0 = bufferBytes[validBytesSearchIndex]; - } - else if (currentBlockIndex == 1) - { - b1 = bufferBytes[validBytesSearchIndex]; - } - else if (currentBlockIndex == 2) - { - b2 = bufferBytes[validBytesSearchIndex]; - } - else - { - b3 = bufferBytes[validBytesSearchIndex]; - } + if (destLength >= decodedLength) + { + maxSrcLength = srcLength - skipLastChunk; + } + else + { + // This should never overflow since destLength here is less than int.MaxValue / 4 * 3 (i.e. 1610612733) + // Therefore, (destLength / 3) * 4 will always be less than 2147483641 + Debug.Assert(destLength < (int.MaxValue / 4 * 3)); + maxSrcLength = (destLength / 3) * 4; + } - validBytesSearchIndex++; - } + ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); + srcMax = srcBytes + (uint)maxSrcLength; - if (insufficientValidBytesFound) - { - break; - } + while (src < srcMax) + { + int result = Decode(src, ref decodingMap); - result = Decode(b0, b1, b2, b3, ref decodingMap); - if (result < 0 - && !IsBlockEndBytesPadding(b2, b3)) - { - goto InvalidExit; - } + if (result < 0) + goto InvalidDataExit; - nextSoureIndex = validBytesSearchIndex; - } - else - { - if (!IsBlockEndBytesPadding(b2, b3)) - { - goto InvalidExit; - } - } + WriteThreeLowOrderBytes(dest, result); + src += 4; + dest += 3; + } - // Handle last four bytes. There are 1, 2 padding chars. - if (result < 0 - && IsBlockEndBytesPadding(b2, b3)) - { - uint indexOfBytesAfterPadding = sourceIndex + nextSoureIndex; - while (indexOfBytesAfterPadding + 1 <= bufferLength - 1) - { - if (!IsByteToBeIgnored(bufferBytes[indexOfBytesAfterPadding++])) - { - // Only bytes to be ignored can be after padding bytes. - goto InvalidExit; - } - } + if (maxSrcLength != srcLength - skipLastChunk) + goto DestinationTooSmallExit; - int i0 = Unsafe.Add(ref decodingMap, (IntPtr)b0); - int i1 = Unsafe.Add(ref decodingMap, (IntPtr)b1); + // If input is less than 4 bytes, srcLength == sourceIndex == 0 + // If input is not a multiple of 4, sourceIndex == srcLength != 0 + if (src == srcEnd) + { + if (isFinalBlock) + goto InvalidDataExit; - i0 <<= 18; - i1 <<= 12; + if (src == srcBytes + utf8.Length) + goto DoneExit; - i0 |= i1; + goto NeedMoreDataExit; + } - if (b2 != EncodingPad) - { - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)b2); + // if isFinalBlock is false, we will never reach this point - i2 <<= 6; + // Handle last four bytes. There are 0, 1, 2 padding chars. + uint t0 = srcEnd[-4]; + uint t1 = srcEnd[-3]; + uint t2 = srcEnd[-2]; + uint t3 = srcEnd[-1]; - i0 |= i2; + int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); + int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); - if (i0 < 0) - goto InvalidExit; + i0 <<= 18; + i1 <<= 12; - bufferBytes[destIndex] = (byte)(i0 >> 16); - bufferBytes[destIndex + 1] = (byte)(i0 >> 8); - destIndex += 2; - } - else - { - if (i0 < 0) - goto InvalidExit; + i0 |= i1; - bufferBytes[destIndex] = (byte)(i0 >> 16); - destIndex += 1; - } + byte* destMax = destBytes + (uint)destLength; - goto DoneExit; - } - } + if (t3 != EncodingPad) + { + int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); + int i3 = Unsafe.Add(ref decodingMap, (IntPtr)t3); - WriteThreeLowOrderBytes(bufferBytes + destIndex, result); - destIndex += 3; - sourceIndex = nextSoureIndex; + i2 <<= 6; + + i0 |= i3; + i0 |= i2; + + if (i0 < 0) + goto InvalidDataExit; + if (dest + 3 > destMax) + goto DestinationTooSmallExit; + + WriteThreeLowOrderBytes(dest, i0); + dest += 3; } + else if (t2 != EncodingPad) + { + int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); - // Check if there are any bytes that should not be ignored after the last valid block size. - while (sourceIndex <= bufferLength - 1) + i2 <<= 6; + + i0 |= i2; + + if (i0 < 0) + goto InvalidDataExit; + if (dest + 2 > destMax) + goto DestinationTooSmallExit; + + dest[0] = (byte)(i0 >> 16); + dest[1] = (byte)(i0 >> 8); + dest += 2; + } + else { - if (!IsByteToBeIgnored(bufferBytes[sourceIndex++])) - { - goto InvalidExit; - } + if (i0 < 0) + goto InvalidDataExit; + if (dest + 1 > destMax) + goto DestinationTooSmallExit; + + dest[0] = (byte)(i0 >> 16); + dest += 1; } - DoneExit: - bytesWritten = (int)destIndex; + src += 4; + + if (srcLength != utf8.Length) + goto InvalidDataExit; + + DoneExit: + bytesConsumed = (int)(src - srcBytes); + bytesWritten = (int)(dest - destBytes); return OperationStatus.Done; - InvalidExit: - bytesWritten = (int)destIndex; + DestinationTooSmallExit: + if (srcLength != utf8.Length && isFinalBlock) + goto InvalidDataExit; // if input is not a multiple of 4, and there is no more data, return invalid data instead + + bytesConsumed = (int)(src - srcBytes); + bytesWritten = (int)(dest - destBytes); + return OperationStatus.DestinationTooSmall; + + NeedMoreDataExit: + bytesConsumed = (int)(src - srcBytes); + bytesWritten = (int)(dest - destBytes); + return OperationStatus.NeedMoreData; + + InvalidDataExit: + bytesConsumed = (int)(src - srcBytes); + bytesWritten = (int)(dest - destBytes); return OperationStatus.InvalidData; } } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref byte* src, byte* srcEnd, ref byte* dest, byte* destEnd, ref sbyte decodingMap, bool isFinalBlock, OperationStatus lastBlockStatus, bool exitAfterFirstSkippedBytes, out int pendingSrcIncrement) + private static bool AllBytesRemainingAreWhitespace(ReadOnlySpan encoded, out int consumed) { - pendingSrcIncrement = 0; + consumed = 0; - while (src <= srcEnd - 4) + while (consumed < encoded.Length) { - byte* srcBeforeProcessing = src; - lastBlockStatus = IgnoreWhitespaceAndTryConsumeNextValidBytesBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, isFinalBlock, out pendingSrcIncrement); - - if (lastBlockStatus != OperationStatus.Done - // The source was not increased because there were not enough valid bytes. - || srcBeforeProcessing == src - // Exit after consuming more than 4 bytes due to skipping whitespace. - || (exitAfterFirstSkippedBytes && src - srcBeforeProcessing > 4)) + if (!IsByteWhitespace(encoded[consumed])) { - break; + return false; } + + consumed++; } - return lastBlockStatus; + return true; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus IgnoreWhitespaceAndTryConsumeNextValidBytesBlock(ref byte* src, byte* srcEnd, ref byte* dest, byte* destEnd, ref sbyte decodingMap, bool isFinalBlock, out int pendingSrcIncrement) + private static unsafe OperationStatus DecodeWithWhitespaceBlockwise(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) { - // The default increment will be 4 if no bytes that require ignoring are encountered. - pendingSrcIncrement = 4; - byte b0 = src[0]; - byte b1 = src[1]; - byte b2 = src[2]; - byte b3 = src[3]; - - int result = Decode(b0, b1, b2, b3, ref decodingMap); - if (result < 0) + fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8)) + fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes)) { - int firstInvalidIndex = GetIndexOfFirstByteToBeIgnored(src); - if (firstInvalidIndex >= 0) - { - int validBytesSearchIndex = firstInvalidIndex; - bool insufficientValidBytesFound = false; + int srcLength = utf8.Length; + int destLength = bytes.Length; + byte* src = srcBytes; + byte* dest = destBytes; + byte* destEnd = dest + (uint)destLength; + byte* srcEnd = srcBytes + (uint)srcLength; - for (int currentBlockIndex = firstInvalidIndex; currentBlockIndex < 4; currentBlockIndex++) - { - while (src < srcEnd - validBytesSearchIndex - && IsByteToBeIgnored(src[validBytesSearchIndex])) - { - validBytesSearchIndex++; - } + ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); + + OperationStatus lastBlockStatus = IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, isFinalBlock, out int pendingSrcIncrement); - if (src >= srcEnd - validBytesSearchIndex) + // Assess the end block and bytes beyond it. + if (lastBlockStatus == OperationStatus.Done) + { + // Check if there are bytes that should not be ignored beyond the expected end of the valid input range. + while (src + pendingSrcIncrement <= srcEnd - 1) + { + if (!IsByteWhitespace(src[pendingSrcIncrement++])) { - insufficientValidBytesFound = true; + lastBlockStatus = OperationStatus.InvalidData; break; } + } + } - if (currentBlockIndex == 0) - { - b0 = src[validBytesSearchIndex]; - } - else if (currentBlockIndex == 1) - { - b1 = src[validBytesSearchIndex]; - } - else if (currentBlockIndex == 2) - { - b2 = src[validBytesSearchIndex]; - } - else - { - b3 = src[validBytesSearchIndex]; - } + bytesConsumed = (int)(src - srcBytes); + bytesWritten = (int)(dest - destBytes); + return lastBlockStatus; + } + } - validBytesSearchIndex++; - } - if (insufficientValidBytesFound) - { - return OperationStatus.Done; - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe OperationStatus IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref byte* src, byte* srcEnd, ref byte* dest, byte* destEnd, ref sbyte decodingMap, bool isFinalBlock, out int pendingSrcIncrement) + { + pendingSrcIncrement = 0; + Span buffer = stackalloc byte[4]; - result = Decode(b0, b1, b2, b3, ref decodingMap); - if (result < 0 - && !IsBlockEndBytesPadding(b2, b3)) - { - return OperationStatus.InvalidData; - } + while (src <= srcEnd - 4) + { + // The default increment will be 4 if no bytes that require ignoring are encountered. + pendingSrcIncrement = 4; - pendingSrcIncrement = validBytesSearchIndex; - } - else + buffer[0] = src[0]; + buffer[1] = src[1]; + buffer[2] = src[2]; + buffer[3] = src[3]; + + int result = Decode(buffer, ref decodingMap); + if (result < 0) { - if (!IsBlockEndBytesPadding(b2, b3)) + // Start searching for valid bytes from the first invalid byte. + int indexOfFirstWhitespace = GetIndexOfFirstWhitespace(src); + if (indexOfFirstWhitespace >= 0) + { + // Try to fill remainder of the buffer with valid bytes. + pendingSrcIncrement = indexOfFirstWhitespace; + + for (int currentBlockIndex = indexOfFirstWhitespace; currentBlockIndex < 4; currentBlockIndex++) + { + while (src < srcEnd - pendingSrcIncrement + && IsByteWhitespace(src[pendingSrcIncrement])) + { + pendingSrcIncrement++; + } + + if (src >= srcEnd - pendingSrcIncrement) + { + // Insufficient valid bytes found. + return OperationStatus.Done; + } + + buffer[currentBlockIndex] = src[pendingSrcIncrement]; + pendingSrcIncrement++; + } + + result = Decode(buffer, ref decodingMap); + } + else if (!IsBlockEndBytesPadding(buffer)) { + // Found invalid non-whitespace bytes. return OperationStatus.InvalidData; } - } - // Check to see if parsing failed due to padding. There could be 1 or 2 padding chars. - if (result < 0 - && IsBlockEndBytesPadding(b2, b3)) - { - int indexOfBytesAfterPadding = pendingSrcIncrement; - while (src <= srcEnd - indexOfBytesAfterPadding - 1) + // Check to see if parsing failed due to padding. There could be 1 or 2 padding chars. + if (IsBlockEndBytesPadding(buffer)) { - if (!IsByteToBeIgnored(src[indexOfBytesAfterPadding++])) + // If isFinalBlock is false, padding is treated as invalid. + if (!isFinalBlock) { - // Only bytes to be ignored can be after padding bytes. return OperationStatus.InvalidData; } - } - // If isFinalBlock is false, padding is treaded as invalid. - if (!isFinalBlock) - { - return OperationStatus.InvalidData; + int indexOfBytesAfterPadding = pendingSrcIncrement; + while (src <= srcEnd - indexOfBytesAfterPadding - 1) + { + if (!IsByteWhitespace(src[indexOfBytesAfterPadding++])) + { + // Only whitespace can be after padding bytes. + // Exit before consuming more bytes. + return OperationStatus.InvalidData; + } + } + + return TryWriteFinalBlockWithPadding(ref dest, destEnd, ref decodingMap, buffer); } + } - return TryWriteFinalUnevenBlock(ref src, ref dest, destEnd, ref decodingMap, ref pendingSrcIncrement, b0, b1, b2); + if (dest > destEnd - 3) + { + return OperationStatus.DestinationTooSmall; } - } - if (dest > destEnd - 3) - { - return OperationStatus.DestinationTooSmall; + WriteThreeLowOrderBytes(dest, result); + src += pendingSrcIncrement; + dest += 3; } - WriteThreeLowOrderBytes(dest, result); - - src += pendingSrcIncrement; - pendingSrcIncrement = 0; - dest += 3; - return OperationStatus.Done; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus TryWriteFinalUnevenBlock(ref byte* src, ref byte* dest, byte* destEnd, ref sbyte decodingMap, ref int pendingSrcIncrement, byte b0, byte b1, byte b2) + private static unsafe OperationStatus TryWriteFinalBlockWithPadding(ref byte* dest, byte* destEnd, ref sbyte decodingMap, Span buffer) { + byte b0 = buffer[0]; + byte b1 = buffer[1]; + byte b2 = buffer[2]; + int i0 = Unsafe.Add(ref decodingMap, (IntPtr)b0); int i1 = Unsafe.Add(ref decodingMap, (IntPtr)b1); @@ -550,77 +508,245 @@ private static unsafe OperationStatus TryWriteFinalUnevenBlock(ref byte* src, re dest += 1; } - src += pendingSrcIncrement; - pendingSrcIncrement = 0; - return OperationStatus.Done; } - private static unsafe bool TryDecodeCurrentGroupIfWhitespaceIsSeparatingValidBytesInCommonLocation(ReadOnlySpan utf8, ref byte* src, ref byte* dest, ref byte* end, int destLength, byte* srcBytes, byte* destBytes, byte groupSize) + /// + /// Untouched original DecodeFromUtf8InPlace method + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe OperationStatus DecodeFromUtf8InPlaceCore(Span buffer, out int bytesWritten, out uint sourceIndex) { - int consumed = (int)(src - srcBytes); - int numberOfWhiteSpaceBytesSkipped = (consumed / 76) * 2; - - if (end < src + numberOfWhiteSpaceBytesSkipped) + if (buffer.IsEmpty) { - // Max potential index is out of range. - return false; + bytesWritten = 0; + sourceIndex = 0; + return OperationStatus.Done; } - int srcIndexInCurrent76CharsSequence = consumed % 76; - int potentialIndeOfFirstWhitespace = (76 - srcIndexInCurrent76CharsSequence + numberOfWhiteSpaceBytesSkipped) % 78; - if (potentialIndeOfFirstWhitespace > groupSize) + fixed (byte* bufferBytes = &MemoryMarshal.GetReference(buffer)) { - // The potential whitespace index should have existed in the group that just failed vectorized decoding. - return false; - } + int bufferLength = buffer.Length; + sourceIndex = 0; + uint destIndex = 0; - byte firstWhiteSpace = src[potentialIndeOfFirstWhitespace]; - byte secondWhiteSpace = src[potentialIndeOfFirstWhitespace + 1]; - if (firstWhiteSpace != 13 - || secondWhiteSpace != 10) - { - return false; - } + // only decode input if it is a multiple of 4 + if (bufferLength != ((bufferLength >> 2) * 4)) + goto InvalidExit; + if (bufferLength == 0) + goto DoneExit; - // The two slice and copy below avoids a loop that skips the whitespace indexes to fill the new span. - Span groupOfBlocksWithoutWhitespace = stackalloc byte[groupSize]; - // Slice and copy next 34 bytes - utf8.Slice(consumed + 2, groupSize).CopyTo(groupOfBlocksWithoutWhitespace); - // Overwrite whitespace and valid bytes before - utf8.Slice(consumed, potentialIndeOfFirstWhitespace).CopyTo(groupOfBlocksWithoutWhitespace); + ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); - fixed (byte* groupOfBlocksWithoutWhitespaceSrcBytes = &MemoryMarshal.GetReference(utf8)) - { - byte* groupOfBlocksWithoutWhitespaceSrcBytesOriginal = groupOfBlocksWithoutWhitespaceSrcBytes; - byte* groupOfBlocksWithoutWhitespaceSrc = groupOfBlocksWithoutWhitespaceSrcBytes; + while (sourceIndex < bufferLength - 4) + { + int result = Decode(bufferBytes + sourceIndex, ref decodingMap); + if (result < 0) + goto InvalidExit; + WriteThreeLowOrderBytes(bufferBytes + destIndex, result); + destIndex += 3; + sourceIndex += 4; + } + + uint t0 = bufferBytes[bufferLength - 4]; + uint t1 = bufferBytes[bufferLength - 3]; + uint t2 = bufferBytes[bufferLength - 2]; + uint t3 = bufferBytes[bufferLength - 1]; + + int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); + int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); + + i0 <<= 18; + i1 <<= 12; - // Note: Set srcEnd = start, so the while loop will fail during decoding, so that only one pass is possible. - bool isComplete; - if (groupSize == 32) + i0 |= i1; + + if (t3 != EncodingPad) { - isComplete = Avx2Decode(ref groupOfBlocksWithoutWhitespaceSrc, ref dest, groupOfBlocksWithoutWhitespaceSrcBytes, groupOfBlocksWithoutWhitespace.Length, destLength, groupOfBlocksWithoutWhitespaceSrcBytes, destBytes); + int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); + int i3 = Unsafe.Add(ref decodingMap, (IntPtr)t3); + + i2 <<= 6; + + i0 |= i3; + i0 |= i2; + + if (i0 < 0) + goto InvalidExit; + + WriteThreeLowOrderBytes(bufferBytes + destIndex, i0); + destIndex += 3; + } + else if (t2 != EncodingPad) + { + int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); + + i2 <<= 6; + + i0 |= i2; + + if (i0 < 0) + goto InvalidExit; + + bufferBytes[destIndex] = (byte)(i0 >> 16); + bufferBytes[destIndex + 1] = (byte)(i0 >> 8); + destIndex += 2; } else { - isComplete = Vector128Decode(ref groupOfBlocksWithoutWhitespaceSrc, ref dest, groupOfBlocksWithoutWhitespaceSrcBytes, groupOfBlocksWithoutWhitespace.Length, destLength, groupOfBlocksWithoutWhitespaceSrcBytes, destBytes); + if (i0 < 0) + goto InvalidExit; + + bufferBytes[destIndex] = (byte)(i0 >> 16); + destIndex += 1; + } + + DoneExit: + bytesWritten = (int)destIndex; + return OperationStatus.Done; + + InvalidExit: + bytesWritten = (int)destIndex; + return OperationStatus.InvalidData; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe OperationStatus DecodeWithWhitespaceFromUtf8InPlace(Span inputBuffer, ref int destIndex, uint sourceIndex) + { + fixed (byte* bufferBytes = &MemoryMarshal.GetReference(inputBuffer)) + { + int bufferLength = inputBuffer.Length; + Span bufferWithoutWhitespace = stackalloc byte[4]; + + ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); + + while (sourceIndex <= bufferLength - 4) + { + // The default increment will be 4 if no bytes that require ignoring are encountered. + uint nextSrcIndex = sourceIndex + 4; + + bufferWithoutWhitespace[0] = bufferBytes[sourceIndex]; + bufferWithoutWhitespace[1] = bufferBytes[sourceIndex + 1]; + bufferWithoutWhitespace[2] = bufferBytes[sourceIndex + 2]; + bufferWithoutWhitespace[3] = bufferBytes[sourceIndex + 3]; + + int result = Decode(bufferWithoutWhitespace, ref decodingMap); + if (result < 0) + { + // Start searching for valid bytes from the first invalid byte. + int indexOfFirstWhitespace = GetIndexOfFirstWhitespace(bufferBytes + sourceIndex); + if (indexOfFirstWhitespace >= 0) + { + // Try to fill remainder of the buffer with valid bytes. + nextSrcIndex = (uint)indexOfFirstWhitespace + sourceIndex; + + for (int currentBlockIndex = indexOfFirstWhitespace; currentBlockIndex < 4; currentBlockIndex++) + { + while (nextSrcIndex <= bufferLength - 1 + && IsByteWhitespace(bufferBytes[nextSrcIndex])) + { + nextSrcIndex++; + } + + if (nextSrcIndex > bufferLength - 1) + { + return OperationStatus.Done; + } + + bufferWithoutWhitespace[currentBlockIndex] = bufferBytes[nextSrcIndex]; + nextSrcIndex++; + } + + result = Decode(bufferWithoutWhitespace, ref decodingMap); + } + else if (!IsBlockEndBytesPadding(bufferWithoutWhitespace)) + { + // Found invalid non-whitespace bytes. + return OperationStatus.InvalidData; + } + + // Check to see if parsing failed due to padding. There could be 1 or 2 padding chars. + if (IsBlockEndBytesPadding(bufferWithoutWhitespace)) + { + uint indexOfBytesAfterPadding = sourceIndex + nextSrcIndex; + while (indexOfBytesAfterPadding + 1 <= bufferLength - 1) + { + if (!IsByteWhitespace(bufferBytes[indexOfBytesAfterPadding++])) + { + // Only whitespace can be after padding bytes. + // Exit before writing more bytes. + return OperationStatus.InvalidData; + } + } + + return TryWriteFinalBlockWithPadding(bufferBytes, ref decodingMap, ref destIndex, bufferWithoutWhitespace); + } + } + + WriteThreeLowOrderBytes(bufferBytes + destIndex, result); + destIndex += 3; + sourceIndex = nextSrcIndex; } - if (groupOfBlocksWithoutWhitespaceSrc == groupOfBlocksWithoutWhitespaceSrcBytesOriginal) + // Check if there are any bytes that should not be ignored after the last valid block size. + while (sourceIndex <= bufferLength - 1) { - // Decoding failed because there was more whitespace or invalid bytes. - return false; + if (!IsByteWhitespace(bufferBytes[sourceIndex++])) + { + return OperationStatus.InvalidData; + } } - // + 2 for the 2 whitespace chars - src += groupSize + 2; + return OperationStatus.Done; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe OperationStatus TryWriteFinalBlockWithPadding(byte* dest, ref sbyte decodingMap, ref int destIndex, Span buffer) + { + byte b0 = buffer[0]; + byte b1 = buffer[1]; + byte b2 = buffer[2]; + + // Try read final block with padding + int i0 = Unsafe.Add(ref decodingMap, (IntPtr)b0); + int i1 = Unsafe.Add(ref decodingMap, (IntPtr)b1); + + i0 <<= 18; + i1 <<= 12; + + i0 |= i1; + + if (b2 != EncodingPad) + { + int i2 = Unsafe.Add(ref decodingMap, (IntPtr)b2); + + i2 <<= 6; + + i0 |= i2; + + if (i0 < 0) + return OperationStatus.InvalidData; + + dest[destIndex] = (byte)(i0 >> 16); + dest[destIndex + 1] = (byte)(i0 >> 8); + destIndex += 2; + } + else + { + if (i0 < 0) + return OperationStatus.InvalidData; - return true; + dest[destIndex] = (byte)(i0 >> 16); + destIndex += 1; } + + return OperationStatus.Done; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe bool Avx2Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) { // If we have AVX2 support, pick off 32 bytes at a time for as long as we can, // but make sure that we quit before seeing any == markers at the end of the @@ -705,7 +831,7 @@ private static unsafe bool Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b srcBytes = src; destBytes = dest; - return false; + return; } Vector256 eq2F = Avx2.CompareEqual(str, mask2F); @@ -751,7 +877,7 @@ private static unsafe bool Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b srcBytes = src; destBytes = dest; - return true; + return; } // This can be replaced once https://github.com/dotnet/runtime/issues/63331 is implemented. @@ -771,7 +897,7 @@ private static Vector128 SimdShuffle(Vector128 left, Vector128 } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe bool Vector128Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) + private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) { Debug.Assert((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian); @@ -881,7 +1007,7 @@ private static unsafe bool Vector128Decode(ref byte* srcBytes, ref byte* destByt srcBytes = src; destBytes = dest; - return false; + return; } Vector128 eq2F = Vector128.Equals(str, mask2F); @@ -946,11 +1072,33 @@ private static unsafe bool Vector128Decode(ref byte* srcBytes, ref byte* destByt srcBytes = src; destBytes = dest; - return true; + return; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe int Decode(uint t0, uint t1, uint t2, uint t3, ref sbyte decodingMap) + private static unsafe int Decode(byte* encodedBytes, ref sbyte decodingMap) + { + uint t0 = encodedBytes[0]; + uint t1 = encodedBytes[1]; + uint t2 = encodedBytes[2]; + uint t3 = encodedBytes[3]; + + return Decode(t0, t1, t2, t3, ref decodingMap); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int Decode(ReadOnlySpan encodedBytes, ref sbyte decodingMap) + { + uint t0 = encodedBytes[0]; + uint t1 = encodedBytes[1]; + uint t2 = encodedBytes[2]; + uint t3 = encodedBytes[3]; + + return Decode(t0, t1, t2, t3, ref decodingMap); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int Decode(uint t0, uint t1, uint t2, uint t3, ref sbyte decodingMap) { int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); @@ -977,37 +1125,26 @@ private static unsafe void WriteThreeLowOrderBytes(byte* destination, int value) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe int GetIndexOfFirstByteToBeIgnored(byte* src) + private static unsafe int GetIndexOfFirstWhitespace(byte* src) { - int firstInvalidIndex = -1; - - if (IsByteToBeIgnored(src[0])) - { - firstInvalidIndex = 0; - } - else if (IsByteToBeIgnored(src[1])) + for (int i = 0; i < 4; i++) { - firstInvalidIndex = 1; - } - else if (IsByteToBeIgnored(src[2])) - { - firstInvalidIndex = 2; - } - else if (IsByteToBeIgnored(src[3])) - { - firstInvalidIndex = 3; + if (IsByteWhitespace(src[i])) + { + return i; + } } - return firstInvalidIndex; + return -1; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsBlockEndBytesPadding(byte secondToLastByte, byte lastByte) => - lastByte == EncodingPad - || secondToLastByte == EncodingPad && lastByte == EncodingPad; + private static bool IsBlockEndBytesPadding(Span buffer) => + buffer[3] == EncodingPad + || buffer[2] == EncodingPad && buffer[3] == EncodingPad; [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsByteToBeIgnored(int charByte) + private static bool IsByteWhitespace(int charByte) { if (charByte < 32) { diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs index fac6df6a870dbe..be0a317453b3f6 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -21,7 +21,7 @@ public static partial class Base64 /// The input span which contains UTF-8 encoded text in base64 that needs to be validated. /// The maximum length (in bytes) if you were to decode the base 64 encoded text within a byte span. /// true if is decodable; otherwise, false. - public static unsafe bool IsValid(ReadOnlySpan base64Text, out int decodedLength) + public static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) { if (base64Text.IsEmpty) { @@ -38,7 +38,7 @@ public static unsafe bool IsValid(ReadOnlySpan base64Text, out int decoded while (indexOfPaddingInvalidOrWhitespace >= 0) { char charToValidate = base64Text[indexOfPaddingInvalidOrWhitespace]; - if (IsByteToBeIgnored(charToValidate)) + if (IsByteWhitespace(charToValidate)) { // Chars to be ignored (e,g, whitespace...) should not count towards the length. length--; @@ -131,7 +131,7 @@ public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLen while (indexOfPaddingInvalidOrWhitespace >= 0) { byte byteToValidate = base64TextUtf8[indexOfPaddingInvalidOrWhitespace]; - if (IsByteToBeIgnored(byteToValidate)) + if (IsByteWhitespace(byteToValidate)) { // Bytes to be ignored (e,g, whitespace...) should not count towards the length. length--; From dca7ee6d6bdc28d8d2af0a111731f8584baaba0f Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Fri, 23 Dec 2022 17:09:31 -0800 Subject: [PATCH 08/17] Address PR feedback: Reuse existing decoding method with whitespace --- .../tests/Base64/Base64DecoderUnitTests.cs | 7 +- .../tests/Base64/Base64TestBase.cs | 48 +- .../tests/Base64/Base64ValidationUnitTests.cs | 6 +- .../src/System/Buffers/Text/Base64Decoder.cs | 459 ++++++------------ .../System/Buffers/Text/Base64Validator.cs | 4 +- 5 files changed, 170 insertions(+), 354 deletions(-) diff --git a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs index 16c48a52dda12a..7df0ad93aa6f30 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs @@ -661,7 +661,7 @@ public void DecodeInPlaceInvalidBytesPadding() [Theory] [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] - public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes, int expectedBytesConsumed) + public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes) { byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); byte[] resultBytes = new byte[5]; @@ -671,7 +671,7 @@ public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored); Assert.Equal(OperationStatus.Done, result); - Assert.Equal(expectedBytesConsumed, bytesConsumed); + Assert.Equal(utf8WithCharsToBeIgnored.Length, bytesConsumed); Assert.Equal(expectedBytes.Length, bytesWritten); Assert.True(expectedBytes.SequenceEqual(resultBytes)); Assert.True(stringBytes.SequenceEqual(resultBytes)); @@ -679,9 +679,8 @@ public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf [Theory] [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] - public void DecodeInPlaceIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes, int expectedBytesConsumed) + public void DecodeInPlaceIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes) { - _ = expectedBytesConsumed; Span utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); OperationStatus result = Base64.DecodeFromUtf8InPlace(utf8BytesWithByteToBeIgnored, out int bytesWritten); Span bytesOverwritten = utf8BytesWithByteToBeIgnored.Slice(0, bytesWritten); diff --git a/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs index 3b1e06e13c2b23..882db3026722ea 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs @@ -24,20 +24,20 @@ public static IEnumerable ValidBase64Strings_WithCharsThatMustBeIgnore // One will have 1 char, another will have 3 // Line feed - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 1), utf8Bytes, utf8Bytes.Length + 4 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 3), utf8Bytes, utf8Bytes.Length + 6 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 3), utf8Bytes }; // Horizontal tab - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 1), utf8Bytes, utf8Bytes.Length + 4 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 3), utf8Bytes, utf8Bytes.Length + 6 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 3), utf8Bytes }; // Carriage return - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 1), utf8Bytes, utf8Bytes.Length + 4 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 3), utf8Bytes, utf8Bytes.Length + 6 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 3), utf8Bytes }; // Space - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 1), utf8Bytes, utf8Bytes.Length + 4 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 3), utf8Bytes, utf8Bytes.Length + 6 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 3), utf8Bytes }; string GetBase64StringWithPassedCharInsertedInTheMiddle(char charToInsert, int numberOfTimesToInsert) => $"{firstSegment}{new string(charToInsert, numberOfTimesToInsert)}{secondSegment}"; @@ -45,20 +45,20 @@ public static IEnumerable ValidBase64Strings_WithCharsThatMustBeIgnore // One will have 1 char, another will have 3 // Line feed - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 1), utf8Bytes, utf8Bytes.Length + 4 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 3), utf8Bytes, utf8Bytes.Length + 6 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 3), utf8Bytes }; // Horizontal tab - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 1), utf8Bytes, utf8Bytes.Length + 4 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 3), utf8Bytes, utf8Bytes.Length + 6 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 3), utf8Bytes }; // Carriage return - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 1), utf8Bytes, utf8Bytes.Length + 4 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 3), utf8Bytes, utf8Bytes.Length + 6 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 3), utf8Bytes }; // Space - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 1), utf8Bytes, utf8Bytes.Length + 4 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 3), utf8Bytes, utf8Bytes.Length + 6 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 3), utf8Bytes }; string GetBase64StringWithPassedCharInsertedAtTheStart(char charToInsert, int numberOfTimesToInsert) => $"{new string(charToInsert, numberOfTimesToInsert)}{firstSegment}{secondSegment}"; @@ -67,20 +67,20 @@ public static IEnumerable ValidBase64Strings_WithCharsThatMustBeIgnore // Whitespace after end/padding is not included in consumed bytes // Line feed - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 1), utf8Bytes, utf8Bytes.Length + 3 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 3), utf8Bytes, utf8Bytes.Length + 3 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 3), utf8Bytes }; // Horizontal tab - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 1), utf8Bytes, utf8Bytes.Length + 3 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 3), utf8Bytes, utf8Bytes.Length + 3 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 3), utf8Bytes }; // Carriage return - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 1), utf8Bytes, utf8Bytes.Length + 3 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 3), utf8Bytes, utf8Bytes.Length + 3 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 3), utf8Bytes }; // Space - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 1), utf8Bytes, utf8Bytes.Length + 3 }; - yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 3), utf8Bytes, utf8Bytes.Length + 3 }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 3), utf8Bytes }; string GetBase64StringWithPassedCharInsertedAtTheEnd(char charToInsert, int numberOfTimesToInsert) => $"{firstSegment}{secondSegment}{new string(charToInsert, numberOfTimesToInsert)}"; } diff --git a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs index f8ffba91639ac1..abccac736bb217 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs @@ -147,9 +147,8 @@ public void ValidateGuidChars() [Theory] [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] - public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored, byte[] expectedBytes, int expectedBytesConsumed) + public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored, byte[] expectedBytes) { - _ = expectedBytesConsumed; byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); @@ -159,9 +158,8 @@ public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgn [Theory] [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] - public void ValidateBytesIgnoresCharsToBeIgnoredChars(string utf8WithByteToBeIgnored, byte[] expectedBytes, int expectedBytesConsumed) + public void ValidateBytesIgnoresCharsToBeIgnoredChars(string utf8WithByteToBeIgnored, byte[] expectedBytes) { - _ = expectedBytesConsumed; ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 88ace1273a62ca..7b72624a3aa645 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -46,7 +46,7 @@ public static OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytesConsumed += localConsumed; bytesWritten += localWritten; - if (status is OperationStatus.Done or OperationStatus.NeedMoreData) + if (status is not OperationStatus.InvalidData or OperationStatus.DestinationTooSmall) { break; } @@ -54,10 +54,11 @@ public static OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span utf8 = utf8.Slice(localConsumed); bytes = bytes.Slice(localWritten); - if (AllBytesRemainingAreWhitespace(utf8, out localConsumed)) + if (!TrySkipWhitespace(utf8, out localConsumed)) { if (localConsumed > 0) { + bytesConsumed += localConsumed; status = OperationStatus.Done; } @@ -69,9 +70,7 @@ public static OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span { // First char isn't whitespace, but we didn't consume anything, // thus the input may have whitespace anywhere in between. So fall back to block-wise decoding. - status = DecodeWithWhitespaceBlockwise(utf8, bytes, out localConsumed, out localWritten, isFinalBlock); - bytesConsumed += localConsumed; - bytesWritten += localWritten; + status = DecodeWithWhitespaceBlockwise(utf8, bytes, out bytesConsumed, out bytesWritten, isFinalBlock); break; } @@ -119,7 +118,7 @@ public static OperationStatus DecodeFromUtf8InPlace(Span buffer, out int b if (status is OperationStatus.InvalidData or OperationStatus.DestinationTooSmall) { // The input may have whitespace, attempt to decode while ignoring whitespace. - status = DecodeWithWhitespaceFromUtf8InPlace(buffer, ref bytesWritten, sourceIndex); + status = DecodeWithWhitespaceFromUtf8InPlace(buffer, ref bytesWritten, (int)sourceIndex); } return status; @@ -295,7 +294,7 @@ private static unsafe OperationStatus DecodeFromUtf8Core(ReadOnlySpan utf8 if (srcLength != utf8.Length) goto InvalidDataExit; - DoneExit: + DoneExit: bytesConsumed = (int)(src - srcBytes); bytesWritten = (int)(dest - destBytes); return OperationStatus.Done; @@ -321,194 +320,122 @@ private static unsafe OperationStatus DecodeFromUtf8Core(ReadOnlySpan utf8 } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool AllBytesRemainingAreWhitespace(ReadOnlySpan encoded, out int consumed) + private static bool TrySkipWhitespace(ReadOnlySpan encoded, out int consumed) { - consumed = 0; + int i = 0; - while (consumed < encoded.Length) + for (; i < encoded.Length; ++i) { - if (!IsByteWhitespace(encoded[consumed])) + if (!IsWhitespace(encoded[i])) { - return false; + consumed = i; + return true; } - - consumed++; } - return true; + consumed = i; + return false; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus DecodeWithWhitespaceBlockwise(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) + private static OperationStatus DecodeWithWhitespaceBlockwise(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) { - fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8)) - fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes)) - { - int srcLength = utf8.Length; - int destLength = bytes.Length; - byte* src = srcBytes; - byte* dest = destBytes; - byte* destEnd = dest + (uint)destLength; - byte* srcEnd = srcBytes + (uint)srcLength; + Unsafe.SkipInit(out bytesConsumed); + Unsafe.SkipInit(out bytesWritten); - ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); + const int BlockSize = 4; + Span buffer = stackalloc byte[BlockSize]; + OperationStatus status = OperationStatus.Done; - OperationStatus lastBlockStatus = IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref src, srcEnd, ref dest, destEnd, ref decodingMap, isFinalBlock, out int pendingSrcIncrement); + while (!utf8.IsEmpty) + { + int encodedIdx = 0; + int bufferIdx = 0; + int skipped = 0; - // Assess the end block and bytes beyond it. - if (lastBlockStatus == OperationStatus.Done) + for (; encodedIdx < utf8.Length && (uint)bufferIdx < (uint)buffer.Length; ++encodedIdx) { - // Check if there are bytes that should not be ignored beyond the expected end of the valid input range. - while (src + pendingSrcIncrement <= srcEnd - 1) + if (IsWhitespace(utf8[encodedIdx])) { - if (!IsByteWhitespace(src[pendingSrcIncrement++])) - { - lastBlockStatus = OperationStatus.InvalidData; - break; - } + skipped++; + } + else + { + buffer[bufferIdx] = utf8[encodedIdx]; + bufferIdx++; } } - bytesConsumed = (int)(src - srcBytes); - bytesWritten = (int)(dest - destBytes); - return lastBlockStatus; - } - } - - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus IgnoreWhitespaceAndConsumeValidBytesUntilInvalidBlock(ref byte* src, byte* srcEnd, ref byte* dest, byte* destEnd, ref sbyte decodingMap, bool isFinalBlock, out int pendingSrcIncrement) - { - pendingSrcIncrement = 0; - Span buffer = stackalloc byte[4]; + utf8 = utf8.Slice(encodedIdx); + bytesConsumed += skipped; - while (src <= srcEnd - 4) - { - // The default increment will be 4 if no bytes that require ignoring are encountered. - pendingSrcIncrement = 4; + if (bufferIdx == 0) + { + continue; + } - buffer[0] = src[0]; - buffer[1] = src[1]; - buffer[2] = src[2]; - buffer[3] = src[3]; + bool hasAnotherBlock = utf8.Length >= BlockSize && bufferIdx == BlockSize; + bool localIsFinalBlock = !hasAnotherBlock; - int result = Decode(buffer, ref decodingMap); - if (result < 0) + // If this block contains padding and there's another block, then only whitespace may follow for being valid. + if (hasAnotherBlock) { - // Start searching for valid bytes from the first invalid byte. - int indexOfFirstWhitespace = GetIndexOfFirstWhitespace(src); - if (indexOfFirstWhitespace >= 0) + int paddingCount = GetPaddingCount(ref buffer[^1]); + if (paddingCount > 0) { - // Try to fill remainder of the buffer with valid bytes. - pendingSrcIncrement = indexOfFirstWhitespace; + hasAnotherBlock = false; + localIsFinalBlock = true; + } + } - for (int currentBlockIndex = indexOfFirstWhitespace; currentBlockIndex < 4; currentBlockIndex++) - { - while (src < srcEnd - pendingSrcIncrement - && IsByteWhitespace(src[pendingSrcIncrement])) - { - pendingSrcIncrement++; - } - - if (src >= srcEnd - pendingSrcIncrement) - { - // Insufficient valid bytes found. - return OperationStatus.Done; - } - - buffer[currentBlockIndex] = src[pendingSrcIncrement]; - pendingSrcIncrement++; - } + if (localIsFinalBlock && !isFinalBlock) + { + localIsFinalBlock = false; + } - result = Decode(buffer, ref decodingMap); - } - else if (!IsBlockEndBytesPadding(buffer)) - { - // Found invalid non-whitespace bytes. - return OperationStatus.InvalidData; - } + status = DecodeFromUtf8Core(buffer.Slice(0, bufferIdx), bytes, out int localConsumed, out int localWritten, localIsFinalBlock); + bytesConsumed += localConsumed; + bytesWritten += localWritten; - // Check to see if parsing failed due to padding. There could be 1 or 2 padding chars. - if (IsBlockEndBytesPadding(buffer)) + if (status != OperationStatus.Done) + { + return status; + } + + // The remaining data must all be whitespace in order to be valid. + if (!hasAnotherBlock) + { + for (int i = 0; i < utf8.Length; ++i) { - // If isFinalBlock is false, padding is treated as invalid. - if (!isFinalBlock) + if (!IsWhitespace(utf8[i])) { - return OperationStatus.InvalidData; - } + // Revert previous dest increment, since an invalid state followed. + bytesConsumed -= localConsumed; + bytesWritten -= localWritten; - int indexOfBytesAfterPadding = pendingSrcIncrement; - while (src <= srcEnd - indexOfBytesAfterPadding - 1) - { - if (!IsByteWhitespace(src[indexOfBytesAfterPadding++])) - { - // Only whitespace can be after padding bytes. - // Exit before consuming more bytes. - return OperationStatus.InvalidData; - } + return OperationStatus.InvalidData; } - return TryWriteFinalBlockWithPadding(ref dest, destEnd, ref decodingMap, buffer); + bytesConsumed++; } - } - if (dest > destEnd - 3) - { - return OperationStatus.DestinationTooSmall; + break; } - WriteThreeLowOrderBytes(dest, result); - src += pendingSrcIncrement; - dest += 3; + bytes = bytes.Slice(localWritten); } - return OperationStatus.Done; + return status; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus TryWriteFinalBlockWithPadding(ref byte* dest, byte* destEnd, ref sbyte decodingMap, Span buffer) + private static int GetPaddingCount(ref byte ptrToLastElement) { - byte b0 = buffer[0]; - byte b1 = buffer[1]; - byte b2 = buffer[2]; - - int i0 = Unsafe.Add(ref decodingMap, (IntPtr)b0); - int i1 = Unsafe.Add(ref decodingMap, (IntPtr)b1); + int padding = 0; - i0 <<= 18; - i1 <<= 12; - - i0 |= i1; + if (ptrToLastElement == EncodingPad) padding++; + if (Unsafe.Subtract(ref ptrToLastElement, 1) == EncodingPad) padding++; - if (b2 != EncodingPad) - { - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)b2); - - i2 <<= 6; - - i0 |= i2; - - if (i0 < 0) - return OperationStatus.InvalidData; - if (dest + 2 > destEnd) - return OperationStatus.DestinationTooSmall; - - dest[0] = (byte)(i0 >> 16); - dest[1] = (byte)(i0 >> 8); - dest += 2; - } - else - { - if (i0 < 0) - return OperationStatus.InvalidData; - if (dest + 1 > destEnd) - return OperationStatus.DestinationTooSmall; - - dest[0] = (byte)(i0 >> 16); - dest += 1; - } - - return OperationStatus.Done; + return padding; } /// @@ -611,138 +538,71 @@ private static unsafe OperationStatus DecodeFromUtf8InPlaceCore(Span buffe } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus DecodeWithWhitespaceFromUtf8InPlace(Span inputBuffer, ref int destIndex, uint sourceIndex) + private static OperationStatus DecodeWithWhitespaceFromUtf8InPlace(Span utf8, ref int destIndex, int sourceIndex) { - fixed (byte* bufferBytes = &MemoryMarshal.GetReference(inputBuffer)) - { - int bufferLength = inputBuffer.Length; - Span bufferWithoutWhitespace = stackalloc byte[4]; + const int BlockSize = 4; + int length = utf8.Length; + Span buffer = stackalloc byte[BlockSize]; - ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); - - while (sourceIndex <= bufferLength - 4) - { - // The default increment will be 4 if no bytes that require ignoring are encountered. - uint nextSrcIndex = sourceIndex + 4; + OperationStatus status = OperationStatus.Done; + int localDestIndex = destIndex; + bool hasPaddingBeenProcessed = false; + int localBytesWritten = 0; - bufferWithoutWhitespace[0] = bufferBytes[sourceIndex]; - bufferWithoutWhitespace[1] = bufferBytes[sourceIndex + 1]; - bufferWithoutWhitespace[2] = bufferBytes[sourceIndex + 2]; - bufferWithoutWhitespace[3] = bufferBytes[sourceIndex + 3]; + while (sourceIndex < length) + { + int bufferIdx = 0; - int result = Decode(bufferWithoutWhitespace, ref decodingMap); - if (result < 0) + while (sourceIndex < length + && bufferIdx < BlockSize) + { + if (!IsWhitespace(utf8[sourceIndex])) { - // Start searching for valid bytes from the first invalid byte. - int indexOfFirstWhitespace = GetIndexOfFirstWhitespace(bufferBytes + sourceIndex); - if (indexOfFirstWhitespace >= 0) - { - // Try to fill remainder of the buffer with valid bytes. - nextSrcIndex = (uint)indexOfFirstWhitespace + sourceIndex; - - for (int currentBlockIndex = indexOfFirstWhitespace; currentBlockIndex < 4; currentBlockIndex++) - { - while (nextSrcIndex <= bufferLength - 1 - && IsByteWhitespace(bufferBytes[nextSrcIndex])) - { - nextSrcIndex++; - } - - if (nextSrcIndex > bufferLength - 1) - { - return OperationStatus.Done; - } - - bufferWithoutWhitespace[currentBlockIndex] = bufferBytes[nextSrcIndex]; - nextSrcIndex++; - } - - result = Decode(bufferWithoutWhitespace, ref decodingMap); - } - else if (!IsBlockEndBytesPadding(bufferWithoutWhitespace)) - { - // Found invalid non-whitespace bytes. - return OperationStatus.InvalidData; - } - - // Check to see if parsing failed due to padding. There could be 1 or 2 padding chars. - if (IsBlockEndBytesPadding(bufferWithoutWhitespace)) - { - uint indexOfBytesAfterPadding = sourceIndex + nextSrcIndex; - while (indexOfBytesAfterPadding + 1 <= bufferLength - 1) - { - if (!IsByteWhitespace(bufferBytes[indexOfBytesAfterPadding++])) - { - // Only whitespace can be after padding bytes. - // Exit before writing more bytes. - return OperationStatus.InvalidData; - } - } - - return TryWriteFinalBlockWithPadding(bufferBytes, ref decodingMap, ref destIndex, bufferWithoutWhitespace); - } + buffer[bufferIdx] = utf8[sourceIndex]; + bufferIdx++; } - WriteThreeLowOrderBytes(bufferBytes + destIndex, result); - destIndex += 3; - sourceIndex = nextSrcIndex; + sourceIndex++; } - // Check if there are any bytes that should not be ignored after the last valid block size. - while (sourceIndex <= bufferLength - 1) + if (bufferIdx == 0) { - if (!IsByteWhitespace(bufferBytes[sourceIndex++])) - { - return OperationStatus.InvalidData; - } + continue; } - return OperationStatus.Done; - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus TryWriteFinalBlockWithPadding(byte* dest, ref sbyte decodingMap, ref int destIndex, Span buffer) - { - byte b0 = buffer[0]; - byte b1 = buffer[1]; - byte b2 = buffer[2]; - - // Try read final block with padding - int i0 = Unsafe.Add(ref decodingMap, (IntPtr)b0); - int i1 = Unsafe.Add(ref decodingMap, (IntPtr)b1); - - i0 <<= 18; - i1 <<= 12; - - i0 |= i1; - - if (b2 != EncodingPad) - { - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)b2); - - i2 <<= 6; + if (bufferIdx != 4) + { + status = OperationStatus.InvalidData; + break; + } - i0 |= i2; + if (hasPaddingBeenProcessed) + { + // Padding has already been processed, a new valid block cannot be processed. + // Revert previous dest increment, since an invalid state followed. + localDestIndex -= localBytesWritten; + status = OperationStatus.InvalidData; + break; + } - if (i0 < 0) - return OperationStatus.InvalidData; + status = DecodeFromUtf8InPlaceCore(buffer, out localBytesWritten, out _); + localDestIndex += localBytesWritten; + hasPaddingBeenProcessed = localBytesWritten < 3; - dest[destIndex] = (byte)(i0 >> 16); - dest[destIndex + 1] = (byte)(i0 >> 8); - destIndex += 2; - } - else - { - if (i0 < 0) - return OperationStatus.InvalidData; + if (status != OperationStatus.Done) + { + return status; + } - dest[destIndex] = (byte)(i0 >> 16); - destIndex += 1; + // Write result to source span in place. + for (int i = 0; i < localBytesWritten; i++) + { + utf8[localDestIndex - localBytesWritten + i] = buffer[i]; + } } - return OperationStatus.Done; + destIndex = localDestIndex; + return status; } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -826,13 +686,7 @@ private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b Vector256 lo = Avx2.Shuffle(lutLo, loNibbles); if (!Avx.TestZ(lo, hi)) - { - // Record current progress - srcBytes = src; - destBytes = dest; - - return; - } + break; Vector256 eq2F = Avx2.CompareEqual(str, mask2F); Vector256 shift = Avx2.Shuffle(lutShift, Avx2.Add(eq2F, hiNibbles)); @@ -876,8 +730,6 @@ private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b srcBytes = src; destBytes = dest; - - return; } // This can be replaced once https://github.com/dotnet/runtime/issues/63331 is implemented. @@ -1002,13 +854,7 @@ private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destByt // Check for invalid input: if any "and" values from lo and hi are not zero, // fall back on bytewise code to do error checking and reporting: if ((lo & hi) != Vector128.Zero) - { - // Record current progress - srcBytes = src; - destBytes = dest; - - return; - } + break; Vector128 eq2F = Vector128.Equals(str, mask2F); Vector128 shift = SimdShuffle(lutShift.AsByte(), (eq2F + hiNibbles), mask8F); @@ -1071,8 +917,6 @@ private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destByt srcBytes = src; destBytes = dest; - - return; } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -1083,23 +927,6 @@ private static unsafe int Decode(byte* encodedBytes, ref sbyte decodingMap) uint t2 = encodedBytes[2]; uint t3 = encodedBytes[3]; - return Decode(t0, t1, t2, t3, ref decodingMap); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int Decode(ReadOnlySpan encodedBytes, ref sbyte decodingMap) - { - uint t0 = encodedBytes[0]; - uint t1 = encodedBytes[1]; - uint t2 = encodedBytes[2]; - uint t3 = encodedBytes[3]; - - return Decode(t0, t1, t2, t3, ref decodingMap); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int Decode(uint t0, uint t1, uint t2, uint t3, ref sbyte decodingMap) - { int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); @@ -1125,34 +952,26 @@ private static unsafe void WriteThreeLowOrderBytes(byte* destination, int value) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe int GetIndexOfFirstWhitespace(byte* src) + private static bool IsWhitespace(int value) { - for (int i = 0; i < 4; i++) + if (Environment.Is64BitProcess) { - if (IsByteWhitespace(src[i])) - { - return i; - } - } + const ulong MagicConstant = 0xC800010000000000UL; - return -1; - } + ulong i = (uint)value - '\t'; + ulong shift = MagicConstant << (int)i; + ulong mask = i - 64; - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsBlockEndBytesPadding(Span buffer) => - buffer[3] == EncodingPad - || buffer[2] == EncodingPad && buffer[3] == EncodingPad; + return (long)(shift & mask) < 0; + } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsByteWhitespace(int charByte) - { - if (charByte < 32) + if (value < 32) { const int BitMask = (1 << 9) | (1 << 10) | (1 << 13); - return ((1 << charByte) & BitMask) != 0; + return ((1 << value) & BitMask) != 0; } - return charByte == 32; + return value == 32; } // Pre-computing this table using a custom string(s_characters) and GenerateDecodingMapAndVerify (found in tests) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs index be0a317453b3f6..6b72d39e196f99 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -38,7 +38,7 @@ public static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) while (indexOfPaddingInvalidOrWhitespace >= 0) { char charToValidate = base64Text[indexOfPaddingInvalidOrWhitespace]; - if (IsByteWhitespace(charToValidate)) + if (IsWhitespace(charToValidate)) { // Chars to be ignored (e,g, whitespace...) should not count towards the length. length--; @@ -131,7 +131,7 @@ public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLen while (indexOfPaddingInvalidOrWhitespace >= 0) { byte byteToValidate = base64TextUtf8[indexOfPaddingInvalidOrWhitespace]; - if (IsByteWhitespace(byteToValidate)) + if (IsWhitespace(byteToValidate)) { // Bytes to be ignored (e,g, whitespace...) should not count towards the length. length--; From 261885ca14381b97327b43e522c82b7d301e9eae Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Sat, 7 Jan 2023 16:01:18 -0800 Subject: [PATCH 09/17] Address PR feedback: Remove redundant empty buffer check --- .../src/System/Buffers/Text/Base64Decoder.cs | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index ec001ed181add4..fa16955e17b1d5 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -130,13 +130,6 @@ public static OperationStatus DecodeFromUtf8InPlace(Span buffer, out int b [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe OperationStatus DecodeFromUtf8Core(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) { - if (utf8.IsEmpty) - { - bytesConsumed = 0; - bytesWritten = 0; - return OperationStatus.Done; - } - fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8)) fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes)) { @@ -444,13 +437,6 @@ private static int GetPaddingCount(ref byte ptrToLastElement) [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe OperationStatus DecodeFromUtf8InPlaceCore(Span buffer, out int bytesWritten, out uint sourceIndex) { - if (buffer.IsEmpty) - { - bytesWritten = 0; - sourceIndex = 0; - return OperationStatus.Done; - } - fixed (byte* bufferBytes = &MemoryMarshal.GetReference(buffer)) { int bufferLength = buffer.Length; From fb9e8de707c69361e286bf11c401e8d1c7718a86 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Sat, 7 Jan 2023 16:02:11 -0800 Subject: [PATCH 10/17] Address PR Feedback: Add missing magic constant comment --- .../src/System/Buffers/Text/Base64Decoder.cs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index fa16955e17b1d5..4b57dbb44d6190 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -942,6 +942,20 @@ private static bool IsWhitespace(int value) { if (Environment.Is64BitProcess) { + // For description see https://github.com/dotnet/runtime/blob/48e74187cb15386c29eedaa046a5ee2c7ddef161/src/libraries/Common/src/System/HexConverter.cs#L314-L330 + /* Constant created with + using System; + string validValues = "\t\n\r "; + ulong magic = 0; + foreach (char c in validValues) + { + int idx = c - '\t'; // lowest value of allowed set + magic |= 1UL << (64 - 1 - idx); + } + Console.WriteLine(magic); + Console.WriteLine($"0x{magic:X16}"); + */ + const ulong MagicConstant = 0xC800010000000000UL; ulong i = (uint)value - '\t'; From b68d36c0ac6e8e8d978bbde5b42cb785ab9420c4 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Wed, 1 Mar 2023 23:34:01 -0800 Subject: [PATCH 11/17] Address PR Feedback: Avoid validation logic duplication --- .../src/System/Buffers/Text/Base64Decoder.cs | 5 +- .../src/System/Buffers/Text/Base64Encoder.cs | 4 +- .../System/Buffers/Text/Base64Validator.cs | 132 ++++++------------ 3 files changed, 46 insertions(+), 95 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 4b57dbb44d6190..390dbba29b7a69 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -70,8 +70,7 @@ public static OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span { // First char isn't whitespace, but we didn't consume anything, // thus the input may have whitespace anywhere in between. So fall back to block-wise decoding. - status = DecodeWithWhitespaceBlockwise(utf8, bytes, out bytesConsumed, out bytesWritten, isFinalBlock); - break; + return DecodeWithWhitespaceBlockwise(utf8, bytes, out bytesConsumed, out bytesWritten, isFinalBlock); } bytesConsumed += localConsumed; @@ -938,7 +937,7 @@ private static unsafe void WriteThreeLowOrderBytes(byte* destination, int value) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsWhitespace(int value) + internal static bool IsWhitespace(int value) { if (Environment.Is64BitProcess) { diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs index 9a5ff88eb583f1..7c7a3ce5f44e4e 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs @@ -583,10 +583,10 @@ private static unsafe uint EncodeAndPadTwo(byte* oneByte, ref byte encodingMap) } } - private const uint EncodingPad = '='; // '=', for padding + internal const uint EncodingPad = '='; // '=', for padding private const int MaximumEncodeLength = (int.MaxValue / 4) * 3; // 1610612733 - private static ReadOnlySpan EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8; + internal static ReadOnlySpan EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8; } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs index 6b72d39e196f99..2b9c414152f278 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; namespace System.Buffers.Text { @@ -23,82 +22,7 @@ public static partial class Base64 /// true if is decodable; otherwise, false. public static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) { - if (base64Text.IsEmpty) - { - decodedLength = 0; - return true; - } - - int length = base64Text.Length; - int paddingCount = 0; - - int indexOfPaddingInvalidOrWhitespace = base64Text.IndexOfAnyExcept(validBase64Chars); - if (indexOfPaddingInvalidOrWhitespace >= 0) - { - while (indexOfPaddingInvalidOrWhitespace >= 0) - { - char charToValidate = base64Text[indexOfPaddingInvalidOrWhitespace]; - if (IsWhitespace(charToValidate)) - { - // Chars to be ignored (e,g, whitespace...) should not count towards the length. - length--; - } - else if (charToValidate == EncodingPad) - { - // There can be at most 2 padding chars. - if (paddingCount == 2) - { - decodedLength = 0; - return false; - } - - paddingCount++; - } - else - { - // An invalid char was encountered. - decodedLength = 0; - return false; - } - - if (indexOfPaddingInvalidOrWhitespace == base64Text.Length - 1) - { - // The end of the input has been reached. - break; - } - - // If no padding is found, slice and use IndexOfAnyExcept to look for the next invalid char. - if (paddingCount == 0) - { - indexOfPaddingInvalidOrWhitespace = base64Text - .Slice(indexOfPaddingInvalidOrWhitespace + 1, base64Text.Length - indexOfPaddingInvalidOrWhitespace - 1) - .IndexOfAnyExcept(validBase64Chars) - + indexOfPaddingInvalidOrWhitespace + 1; // Add current index offset. - } - // If padding is already found, simply increment, as the common case might have 2 sequential padding chars. - else - { - indexOfPaddingInvalidOrWhitespace++; - } - } - - // If the invalid chars all consisted of whitespace, the input will be empty. - if (length == 0) - { - decodedLength = 0; - return true; - } - } - - if (length % 4 != 0) - { - decodedLength = 0; - return false; - } - - // Remove padding to get exact length - decodedLength = (int)((uint)length / 4 * 3) - paddingCount; - return true; + return IsValid(base64Text, out decodedLength); } /// @@ -116,27 +40,34 @@ public static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) /// true if is decodable; otherwise, false. public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLength) { - if (base64TextUtf8.IsEmpty) + return IsValid(base64TextUtf8, out decodedLength); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) + where T2 : IBase64Validatable + { + if (base64Text.IsEmpty) { decodedLength = 0; return true; } - int length = base64TextUtf8.Length; + int length = base64Text.Length; int paddingCount = 0; - int indexOfPaddingInvalidOrWhitespace = base64TextUtf8.IndexOfAnyExcept(validBase64Bytes); + int indexOfPaddingInvalidOrWhitespace = T2.IndexOfAnyExcept(base64Text); if (indexOfPaddingInvalidOrWhitespace >= 0) { while (indexOfPaddingInvalidOrWhitespace >= 0) { - byte byteToValidate = base64TextUtf8[indexOfPaddingInvalidOrWhitespace]; - if (IsWhitespace(byteToValidate)) + T charToValidate = base64Text[indexOfPaddingInvalidOrWhitespace]; + if (T2.IsWhitespace(charToValidate)) { - // Bytes to be ignored (e,g, whitespace...) should not count towards the length. + // Chars to be ignored (e,g, whitespace...) should not count towards the length. length--; } - else if (byteToValidate == EncodingPad) + else if (T2.IsEncodingPad(charToValidate)) { // There can be at most 2 padding chars. if (paddingCount == 2) @@ -154,7 +85,7 @@ public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLen return false; } - if (indexOfPaddingInvalidOrWhitespace == base64TextUtf8.Length - 1) + if (indexOfPaddingInvalidOrWhitespace == base64Text.Length - 1) { // The end of the input has been reached. break; @@ -163,9 +94,9 @@ public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLen // If no padding is found, slice and use IndexOfAnyExcept to look for the next invalid char. if (paddingCount == 0) { - indexOfPaddingInvalidOrWhitespace = base64TextUtf8 - .Slice(indexOfPaddingInvalidOrWhitespace + 1, base64TextUtf8.Length - indexOfPaddingInvalidOrWhitespace - 1) - .IndexOfAnyExcept(validBase64Bytes) + ReadOnlySpan slicedSpan = base64Text.Slice(indexOfPaddingInvalidOrWhitespace + 1); + indexOfPaddingInvalidOrWhitespace = + T2.IndexOfAnyExcept(slicedSpan) + indexOfPaddingInvalidOrWhitespace + 1; // Add current index offset. } // If padding is already found, simply increment, as the common case might have 2 sequential padding chars. @@ -194,8 +125,29 @@ public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLen return true; } - private static readonly IndexOfAnyValues validBase64Bytes = IndexOfAnyValues.Create(EncodingMap); + internal interface IBase64Validatable + { + static abstract int IndexOfAnyExcept(ReadOnlySpan span); + static abstract bool IsWhitespace(T value); + static abstract bool IsEncodingPad(T value); + } + + internal readonly struct Base64CharValidationHandler : IBase64Validatable + { + private static readonly IndexOfAnyValues s_validBase64Chars = IndexOfAnyValues.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"); + + public static int IndexOfAnyExcept(ReadOnlySpan span) => span.IndexOfAnyExcept(s_validBase64Chars); + public static bool IsWhitespace(char value) => Base64.IsWhitespace(value); + public static bool IsEncodingPad(char value) => value == Base64.EncodingPad; + } + + internal readonly struct Base64ByteValidationHandler : IBase64Validatable + { + private static readonly IndexOfAnyValues s_validBase64Chars = IndexOfAnyValues.Create(Base64.EncodingMap); - private static readonly IndexOfAnyValues validBase64Chars = IndexOfAnyValues.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"); + public static int IndexOfAnyExcept(ReadOnlySpan span) => span.IndexOfAnyExcept(s_validBase64Chars); + public static bool IsWhitespace(byte value) => Base64.IsWhitespace(value); + public static bool IsEncodingPad(byte value) => value == Base64.EncodingPad; + } } } From 745fd411ed20f3e7b20ef8c2560f0a2395efa68b Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Thu, 2 Mar 2023 20:48:58 -0800 Subject: [PATCH 12/17] Throw Base64FormatException when whitespace should not be ignored This fixes a failing test. --- .../System/Security/Cryptography/Base64Transforms.cs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/Base64Transforms.cs b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/Base64Transforms.cs index 126f758005f2b0..a55b86b90ff250 100644 --- a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/Base64Transforms.cs +++ b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/Base64Transforms.cs @@ -215,12 +215,16 @@ public byte[] TransformFinalBlock(byte[] inputBuffer, int inputOffset, int input private Span GetTempBuffer(Span inputBuffer, Span tmpBuffer) { - if (_whitespaces == FromBase64TransformMode.DoNotIgnoreWhiteSpaces) + Span tempBuffer = DiscardWhiteSpaces(inputBuffer, tmpBuffer); + + if (_whitespaces == FromBase64TransformMode.DoNotIgnoreWhiteSpaces + && inputBuffer.Length != tempBuffer.Length) { - return inputBuffer; + // Base64.DecodeFromUtf8() does not return OperationStatus.InvalidData when decoding whitespace. + ThrowHelper.ThrowBase64FormatException(); } - return DiscardWhiteSpaces(inputBuffer, tmpBuffer); + return tempBuffer; } [MethodImpl(MethodImplOptions.AggressiveInlining)] From b35ce124db5bf0badadac7b429e9e7c39c0aac99 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Thu, 9 Mar 2023 22:12:50 -0800 Subject: [PATCH 13/17] Adress PR feedback: Improve naming of Base64Validator.cs internals --- .../System/Buffers/Text/Base64Validator.cs | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs index 2b9c414152f278..47b23c7988b1d3 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -12,7 +12,8 @@ public static partial class Base64 /// /// The input span which contains UTF-8 encoded text in base64 that needs to be validated. /// true if is decodable; otherwise, false. - public static bool IsValid(ReadOnlySpan base64Text) => IsValid(base64Text, out int _); + public static bool IsValid(ReadOnlySpan base64Text) => + IsValid(base64Text, out _); /// /// Validates the span of UTF-8 encoded text represented as base64 into binary data. @@ -20,17 +21,16 @@ public static partial class Base64 /// The input span which contains UTF-8 encoded text in base64 that needs to be validated. /// The maximum length (in bytes) if you were to decode the base 64 encoded text within a byte span. /// true if is decodable; otherwise, false. - public static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) - { - return IsValid(base64Text, out decodedLength); - } + public static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) => + IsValid(base64Text, out decodedLength); /// /// Validates the span of UTF-8 encoded text represented as base64 into binary data. /// /// The input span which contains UTF-8 encoded text in base64 that needs to be validated. /// true if is decodable; otherwise, false. - public static bool IsValid(ReadOnlySpan base64TextUtf8) => IsValid(base64TextUtf8, out int _); + public static bool IsValid(ReadOnlySpan base64TextUtf8) => + IsValid(base64TextUtf8, out _); /// /// Validates the span of UTF-8 encoded text represented as base64 into binary data. @@ -38,14 +38,12 @@ public static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) /// The input span which contains UTF-8 encoded text in base64 that needs to be validated. /// The maximum length (in bytes) if you were to decode the base 64 encoded text within a byte span. /// true if is decodable; otherwise, false. - public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLength) - { - return IsValid(base64TextUtf8, out decodedLength); - } + public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLength) => + IsValid(base64TextUtf8, out decodedLength); [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) - where T2 : IBase64Validatable + private static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) + where TBase64Validatable : IBase64Validatable { if (base64Text.IsEmpty) { @@ -56,18 +54,18 @@ private static bool IsValid(ReadOnlySpan base64Text, out int decodedLe int length = base64Text.Length; int paddingCount = 0; - int indexOfPaddingInvalidOrWhitespace = T2.IndexOfAnyExcept(base64Text); + int indexOfPaddingInvalidOrWhitespace = TBase64Validatable.IndexOfAnyExcept(base64Text); if (indexOfPaddingInvalidOrWhitespace >= 0) { while (indexOfPaddingInvalidOrWhitespace >= 0) { T charToValidate = base64Text[indexOfPaddingInvalidOrWhitespace]; - if (T2.IsWhitespace(charToValidate)) + if (TBase64Validatable.IsWhitespace(charToValidate)) { // Chars to be ignored (e,g, whitespace...) should not count towards the length. length--; } - else if (T2.IsEncodingPad(charToValidate)) + else if (TBase64Validatable.IsEncodingPad(charToValidate)) { // There can be at most 2 padding chars. if (paddingCount == 2) @@ -96,7 +94,7 @@ private static bool IsValid(ReadOnlySpan base64Text, out int decodedLe { ReadOnlySpan slicedSpan = base64Text.Slice(indexOfPaddingInvalidOrWhitespace + 1); indexOfPaddingInvalidOrWhitespace = - T2.IndexOfAnyExcept(slicedSpan) + TBase64Validatable.IndexOfAnyExcept(slicedSpan) + indexOfPaddingInvalidOrWhitespace + 1; // Add current index offset. } // If padding is already found, simply increment, as the common case might have 2 sequential padding chars. @@ -132,7 +130,7 @@ internal interface IBase64Validatable static abstract bool IsEncodingPad(T value); } - internal readonly struct Base64CharValidationHandler : IBase64Validatable + internal readonly struct Base64CharValidatable : IBase64Validatable { private static readonly IndexOfAnyValues s_validBase64Chars = IndexOfAnyValues.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"); @@ -141,7 +139,7 @@ internal interface IBase64Validatable public static bool IsEncodingPad(char value) => value == Base64.EncodingPad; } - internal readonly struct Base64ByteValidationHandler : IBase64Validatable + internal readonly struct Base64ByteValidatable : IBase64Validatable { private static readonly IndexOfAnyValues s_validBase64Chars = IndexOfAnyValues.Create(Base64.EncodingMap); From 5848058e6327ed5cc139192d93ac8854f5d28844 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Thu, 9 Mar 2023 22:13:40 -0800 Subject: [PATCH 14/17] Adress PR feedback: Add test to demonstrate extra whitespace is counted --- .../tests/Base64/Base64DecoderUnitTests.cs | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs index c08b9569c25ecd..448d6ccd4ed902 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs @@ -730,5 +730,38 @@ public void DecodingInPlaceWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgno Assert.Equal(OperationStatus.Done, result); Assert.Equal(0, bytesWritten); } + + [Theory] + [InlineData("AQ==", 4, 1)] + [InlineData("AQ== ", 5, 1)] + [InlineData("AQ== ", 6, 1)] + [InlineData("AQ== ", 7, 1)] + [InlineData("AQ== ", 8, 1)] + [InlineData("AQ== ", 9, 1)] + [InlineData("AQ==\n", 5, 1)] + [InlineData("AQ==\n\n", 6, 1)] + [InlineData("AQ==\n\n\n", 7, 1)] + [InlineData("AQ==\n\n\n\n", 8, 1)] + [InlineData("AQ==\n\n\n\n\n", 9, 1)] + [InlineData("AQ==\t", 5, 1)] + [InlineData("AQ==\t\t", 6, 1)] + [InlineData("AQ==\t\t\t", 7, 1)] + [InlineData("AQ==\t\t\t\t", 8, 1)] + [InlineData("AQ==\t\t\t\t\t", 9, 1)] + [InlineData("AQ==\r", 5, 1)] + [InlineData("AQ==\r\r", 6, 1)] + [InlineData("AQ==\r\r\r", 7, 1)] + [InlineData("AQ==\r\r\r\r", 8, 1)] + [InlineData("AQ==\r\r\r\r\r", 9, 1)] + public void BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes(string inputString, int expectedConsumed, int expectedWritten) + { + Span source = Encoding.ASCII.GetBytes(inputString); + Span decodedBytes = new byte[Base64.GetMaxDecodedFromUtf8Length(source.Length)]; + + Assert.Equal(OperationStatus.Done, Base64.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); + Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes)); + } } } From 7c022b030e31cf84ef42f4d99338ec716e5658af Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Mon, 13 Mar 2023 20:59:22 -0700 Subject: [PATCH 15/17] Address PR feedback: avoid bound-check --- .../src/System/Buffers/Text/Base64Decoder.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 390dbba29b7a69..45a35f67ece180 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -538,7 +538,7 @@ private static OperationStatus DecodeWithWhitespaceFromUtf8InPlace(Span ut { int bufferIdx = 0; - while (sourceIndex < length + while ((uint)sourceIndex < (uint)length && bufferIdx < BlockSize) { if (!IsWhitespace(utf8[sourceIndex])) From 04989b53434a22f73b9d67339298654b2c42a369 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Mon, 13 Mar 2023 21:25:16 -0700 Subject: [PATCH 16/17] Address PR feedback: Base64.IsValid: Return when no more invalid chars --- .../tests/Base64/Base64ValidationUnitTests.cs | 13 +++++++++++++ .../src/System/Buffers/Text/Base64Validator.cs | 10 +++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs index abccac736bb217..c7f164ad9b7f5f 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs @@ -314,6 +314,19 @@ public void InvalidBase64Bytes(string utf8WithByteToBeIgnored) [InlineData("aYQ= =a")] [InlineData("aYQ== a")] [InlineData("aYQ==a ")] + [InlineData("a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData(" a ")] + [InlineData(" a ")] + [InlineData(" a ")] + [InlineData(" a ")] public void InvalidBase64Chars(string utf8WithByteToBeIgnored) { byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs index 47b23c7988b1d3..717aab3536842b 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -93,9 +93,13 @@ private static bool IsValid(ReadOnlySpan base64Text, o if (paddingCount == 0) { ReadOnlySpan slicedSpan = base64Text.Slice(indexOfPaddingInvalidOrWhitespace + 1); - indexOfPaddingInvalidOrWhitespace = - TBase64Validatable.IndexOfAnyExcept(slicedSpan) - + indexOfPaddingInvalidOrWhitespace + 1; // Add current index offset. + int nextIndexOfPaddingInvalidOrWhitespace = TBase64Validatable.IndexOfAnyExcept(slicedSpan); + if (nextIndexOfPaddingInvalidOrWhitespace == -1) + { + // No more invalid chars found. + break; + } + indexOfPaddingInvalidOrWhitespace = nextIndexOfPaddingInvalidOrWhitespace + indexOfPaddingInvalidOrWhitespace + 1; // Add current index offset. } // If padding is already found, simply increment, as the common case might have 2 sequential padding chars. else From 5007df7b50353bb8a7e3fa1ad06bd03e3a5a1468 Mon Sep 17 00:00:00 2001 From: Heath Baron-Morgan Date: Tue, 14 Mar 2023 00:19:50 -0700 Subject: [PATCH 17/17] Address PR feedback: Refactor Bas64.IsValid method --- .../System/Buffers/Text/Base64Validator.cs | 89 +++++++++---------- 1 file changed, 44 insertions(+), 45 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs index 717aab3536842b..8c47cfc3998281 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -51,24 +51,52 @@ private static bool IsValid(ReadOnlySpan base64Text, o return true; } - int length = base64Text.Length; + int length = 0; + bool isPaddingFound = false; + + while (true) + { + int index = TBase64Validatable.IndexOfAnyExcept(base64Text); + if ((uint)index >= (uint)base64Text.Length) + { + length += base64Text.Length; + break; + } + + length += index; + + T charToValidate = base64Text[index]; + base64Text = base64Text.Slice(index + 1); + + if (TBase64Validatable.IsWhitespace(charToValidate)) + { + continue; + } + + if (!TBase64Validatable.IsEncodingPad(charToValidate)) + { + // Invalid char was found. + decodedLength = 0; + return false; + } + + // Encoding pad found, determine if padding is valid below. + isPaddingFound = true; + break; + } + int paddingCount = 0; - int indexOfPaddingInvalidOrWhitespace = TBase64Validatable.IndexOfAnyExcept(base64Text); - if (indexOfPaddingInvalidOrWhitespace >= 0) + if (isPaddingFound) { - while (indexOfPaddingInvalidOrWhitespace >= 0) + paddingCount = 1; + + foreach (T charToValidateInPadding in base64Text) { - T charToValidate = base64Text[indexOfPaddingInvalidOrWhitespace]; - if (TBase64Validatable.IsWhitespace(charToValidate)) - { - // Chars to be ignored (e,g, whitespace...) should not count towards the length. - length--; - } - else if (TBase64Validatable.IsEncodingPad(charToValidate)) + if (TBase64Validatable.IsEncodingPad(charToValidateInPadding)) { // There can be at most 2 padding chars. - if (paddingCount == 2) + if (paddingCount >= 2) { decodedLength = 0; return false; @@ -76,44 +104,15 @@ private static bool IsValid(ReadOnlySpan base64Text, o paddingCount++; } - else + else if (!TBase64Validatable.IsWhitespace(charToValidateInPadding)) { - // An invalid char was encountered. + // Invalid char was found. decodedLength = 0; return false; } - - if (indexOfPaddingInvalidOrWhitespace == base64Text.Length - 1) - { - // The end of the input has been reached. - break; - } - - // If no padding is found, slice and use IndexOfAnyExcept to look for the next invalid char. - if (paddingCount == 0) - { - ReadOnlySpan slicedSpan = base64Text.Slice(indexOfPaddingInvalidOrWhitespace + 1); - int nextIndexOfPaddingInvalidOrWhitespace = TBase64Validatable.IndexOfAnyExcept(slicedSpan); - if (nextIndexOfPaddingInvalidOrWhitespace == -1) - { - // No more invalid chars found. - break; - } - indexOfPaddingInvalidOrWhitespace = nextIndexOfPaddingInvalidOrWhitespace + indexOfPaddingInvalidOrWhitespace + 1; // Add current index offset. - } - // If padding is already found, simply increment, as the common case might have 2 sequential padding chars. - else - { - indexOfPaddingInvalidOrWhitespace++; - } } - // If the invalid chars all consisted of whitespace, the input will be empty. - if (length == 0) - { - decodedLength = 0; - return true; - } + length += paddingCount; } if (length % 4 != 0) @@ -122,7 +121,7 @@ private static bool IsValid(ReadOnlySpan base64Text, o return false; } - // Remove padding to get exact length + // Remove padding to get exact length. decodedLength = (int)((uint)length / 4 * 3) - paddingCount; return true; }