diff --git a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs index 0d792b26682d6..bbd28fb092b01 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs @@ -1,12 +1,15 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections; +using System.Collections.Generic; +using System.Linq; using System.Text; using Xunit; namespace System.Buffers.Text.Tests { - public class Base64DecoderUnitTests + public class Base64DecoderUnitTests : Base64TestBase { [Fact] public void BasicDecoding() @@ -157,7 +160,7 @@ public void DecodingOutputTooSmall() Span decodedBytes = new byte[3]; int consumed, written; - if (numBytes % 4 == 0) + if (numBytes >= 8) { Assert.True(OperationStatus.DestinationTooSmall == Base64.DecodeFromUtf8(source, decodedBytes, out consumed, out written), "Number of Input Bytes: " + numBytes); @@ -373,8 +376,12 @@ public void DecodingInvalidBytes(bool isFinalBlock) for (int i = 0; i < invalidBytes.Length; i++) { // Don't test padding (byte 61 i.e. '='), which is tested in DecodingInvalidBytesPadding - if (invalidBytes[i] == Base64TestHelper.EncodingPad) + // Don't test chars to be ignored (spaces: 9, 10, 13, 32 i.e. '\t', '\n', '\r', ' ') + if (invalidBytes[i] == Base64TestHelper.EncodingPad || + Base64TestHelper.IsByteToBeIgnored(invalidBytes[i])) + { continue; + } // replace one byte with an invalid input source[j] = invalidBytes[i]; @@ -568,8 +575,12 @@ public void DecodeInPlaceInvalidBytes() Span buffer = "2222PPPP"u8.ToArray(); // valid input // Don't test padding (byte 61 i.e. '='), which is tested in DecodeInPlaceInvalidBytesPadding - if (invalidBytes[i] == Base64TestHelper.EncodingPad) + // Don't test chars to be ignored (spaces: 9, 10, 13, 32 i.e.
'\t', '\n', '\r', ' ') + if (invalidBytes[i] == Base64TestHelper.EncodingPad || + Base64TestHelper.IsByteToBeIgnored(invalidBytes[i])) + { continue; + } // replace one byte with an invalid input buffer[j] = invalidBytes[i]; @@ -594,7 +605,7 @@ public void DecodeInPlaceInvalidBytes() { Span buffer = "2222PPP"u8.ToArray(); // incomplete input Assert.Equal(OperationStatus.InvalidData, Base64.DecodeFromUtf8InPlace(buffer, out int bytesWritten)); - Assert.Equal(0, bytesWritten); + Assert.Equal(3, bytesWritten); } } @@ -667,5 +678,90 @@ public void DecodeInPlaceInvalidBytesPadding() } } + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + byte[] resultBytes = new byte[5]; + OperationStatus result = Base64.DecodeFromUtf8(utf8BytesWithByteToBeIgnored, resultBytes, out int bytesConsumed, out int bytesWritten); + + // Control value from Convert.FromBase64String + byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(utf8WithCharsToBeIgnored.Length, bytesConsumed); + Assert.Equal(expectedBytes.Length, bytesWritten); + Assert.True(expectedBytes.SequenceEqual(resultBytes)); + Assert.True(stringBytes.SequenceEqual(resultBytes)); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void DecodeInPlaceIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes) + { + Span utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + OperationStatus result = Base64.DecodeFromUtf8InPlace(utf8BytesWithByteToBeIgnored, out int bytesWritten); + Span bytesOverwritten = utf8BytesWithByteToBeIgnored.Slice(0, bytesWritten); + byte[] resultBytesArray =
bytesOverwritten.ToArray(); + + // Control value from Convert.FromBase64String + byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(expectedBytes.Length, bytesWritten); + Assert.True(expectedBytes.SequenceEqual(resultBytesArray)); + Assert.True(stringBytes.SequenceEqual(resultBytesArray)); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void BasicDecodingWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + byte[] resultBytes = new byte[5]; + OperationStatus result = Base64.DecodeFromUtf8(utf8BytesWithByteToBeIgnored, resultBytes, out int bytesConsumed, out int bytesWritten); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(0, bytesWritten); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void DecodingInPlaceWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgnored) + { + Span utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored); + OperationStatus result = Base64.DecodeFromUtf8InPlace(utf8BytesWithByteToBeIgnored, out int bytesWritten); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(0, bytesWritten); + } + + [Theory] + [MemberData(nameof(BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes_MemberData))] + public void BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes(string inputString, int expectedConsumed, int expectedWritten) + { + Span source = Encoding.ASCII.GetBytes(inputString); + Span decodedBytes = new byte[Base64.GetMaxDecodedFromUtf8Length(source.Length)]; + + Assert.Equal(OperationStatus.Done, Base64.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount)); + Assert.Equal(expectedConsumed, consumed); + Assert.Equal(expectedWritten, decodedByteCount); + 
Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes)); + } + + public static IEnumerable BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes_MemberData() + { + var r = new Random(42); + for (int i = 0; i < 5; i++) + { + yield return new object[] { "AQ==" + new string(r.GetItems(" \n\t\r", i)), 4 + i, 1 }; + } + + foreach (string s in new[] { "MTIz", "M TIz", "MT Iz", "MTI z", "MTIz ", "M TI z", "M T I Z " }) + { + yield return new object[] { s + s + s + s, s.Length * 4, 12 }; + } + } } } diff --git a/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs new file mode 100644 index 0000000000000..882db3026722e --- /dev/null +++ b/src/libraries/System.Memory/tests/Base64/Base64TestBase.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text; + +namespace System.Buffers.Text.Tests +{ + public class Base64TestBase + { + public static IEnumerable ValidBase64Strings_WithCharsThatMustBeIgnored() + { + // Create a Base64 string + string text = "a b c"; + byte[] utf8Bytes = Encoding.UTF8.GetBytes(text); + string base64Utf8String = Convert.ToBase64String(utf8Bytes); + + // Split the base64 string in half + int stringLength = base64Utf8String.Length / 2; + string firstSegment = base64Utf8String.Substring(0, stringLength); + string secondSegment = base64Utf8String.Substring(stringLength, stringLength); + + // Insert ignored chars between the base 64 string + // One will have 1 char, another will have 3 + + // Horizontal tab + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(9), 3), utf8Bytes }; + + //
Line feed + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(10), 3), utf8Bytes }; + + // Carriage return + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(13), 3), utf8Bytes }; + + // Space + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedInTheMiddle(Convert.ToChar(32), 3), utf8Bytes }; + + string GetBase64StringWithPassedCharInsertedInTheMiddle(char charToInsert, int numberOfTimesToInsert) => $"{firstSegment}{new string(charToInsert, numberOfTimesToInsert)}{secondSegment}"; + + // Insert ignored chars at the start of the base 64 string + // One will have 1 char, another will have 3 + + // Horizontal tab + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(9), 3), utf8Bytes }; + + // Line feed + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(10), 3), utf8Bytes }; + + // Carriage return + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(13), 3), utf8Bytes }; + + // Space + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 1), utf8Bytes }; + yield return new object[] {
GetBase64StringWithPassedCharInsertedAtTheStart(Convert.ToChar(32), 3), utf8Bytes }; + + string GetBase64StringWithPassedCharInsertedAtTheStart(char charToInsert, int numberOfTimesToInsert) => $"{new string(charToInsert, numberOfTimesToInsert)}{firstSegment}{secondSegment}"; + + // Insert ignored chars at the end of the base 64 string + // One will have 1 char, another will have 3 + // Trailing whitespace after end/padding is counted in consumed bytes + + // Horizontal tab + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(9), 3), utf8Bytes }; + + // Line feed + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(10), 3), utf8Bytes }; + + // Carriage return + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(13), 3), utf8Bytes }; + + // Space + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 1), utf8Bytes }; + yield return new object[] { GetBase64StringWithPassedCharInsertedAtTheEnd(Convert.ToChar(32), 3), utf8Bytes }; + + string GetBase64StringWithPassedCharInsertedAtTheEnd(char charToInsert, int numberOfTimesToInsert) => $"{firstSegment}{secondSegment}{new string(charToInsert, numberOfTimesToInsert)}"; + } + + public static IEnumerable StringsOnlyWithCharsToBeIgnored() + { + // One will have 1 char, another will have 3 + + // Horizontal tab + yield return new object[] { GetRepeatedChar(Convert.ToChar(9), 1) }; + yield return new object[] { GetRepeatedChar(Convert.ToChar(9), 3) }; + + // Line feed + yield return new object[] { GetRepeatedChar(Convert.ToChar(10), 1) }; + yield
return new object[] { GetRepeatedChar(Convert.ToChar(10), 3) }; + + // Carriage return + yield return new object[] { GetRepeatedChar(Convert.ToChar(13), 1) }; + yield return new object[] { GetRepeatedChar(Convert.ToChar(13), 3) }; + + // Space + yield return new object[] { GetRepeatedChar(Convert.ToChar(32), 1) }; + yield return new object[] { GetRepeatedChar(Convert.ToChar(32), 3) }; + + string GetRepeatedChar(char charToInsert, int numberOfTimesToInsert) => new string(charToInsert, numberOfTimesToInsert); + } + } +} diff --git a/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs b/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs index 7715f6b5d4bdf..1ccc8e0cb4289 100644 --- a/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs +++ b/src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs @@ -44,6 +44,8 @@ public static class Base64TestHelper -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, }; + public static bool IsByteToBeIgnored(byte charByte) => charByte is (byte)' ' or (byte)'\t' or (byte)'\r' or (byte)'\n'; + public const byte EncodingPad = (byte)'='; // '=', for padding public const sbyte InvalidByte = -1; // Designating -1 for invalid bytes in the decoding map diff --git a/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs new file mode 100644 index 0000000000000..c7f164ad9b7f5 --- /dev/null +++ b/src/libraries/System.Memory/tests/Base64/Base64ValidationUnitTests.cs @@ -0,0 +1,339 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.Linq; +using System.Text; +using Xunit; + +namespace System.Buffers.Text.Tests +{ + public class Base64ValidationUnitTests : Base64TestBase + { + [Fact] + public void BasicValidationBytes() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 0); // ensure we have a valid length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeDecodableBytes(source, numBytes); + + Assert.True(Base64.IsValid(source)); + Assert.True(Base64.IsValid(source, out int decodedLength)); + Assert.True(decodedLength > 0); + } + } + + [Fact] + public void BasicValidationChars() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 != 0); // ensure we have a valid length + + Span source = new byte[numBytes]; + Base64TestHelper.InitializeDecodableBytes(source, numBytes); + Span chars = source + .ToArray() + .Select(Convert.ToChar) + .ToArray() + .AsSpan(); + + Assert.True(Base64.IsValid(chars)); + Assert.True(Base64.IsValid(chars, out int decodedLength)); + Assert.True(decodedLength > 0); + } + } + + [Fact] + public void BasicValidationInvalidInputLengthBytes() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 == 0); // ensure we have an invalid length + + Span source = new byte[numBytes]; + + Assert.False(Base64.IsValid(source)); + Assert.False(Base64.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } + + [Fact] + public void BasicValidationInvalidInputLengthChars() + { + var rnd = new Random(42); + for (int i = 0; i < 10; i++) + { + int numBytes; + do + { + numBytes = rnd.Next(100, 1000 * 1000); + } while (numBytes % 4 == 0); // ensure we have an invalid length + + Span source = new char[numBytes]; + +
Assert.False(Base64.IsValid(source)); + Assert.False(Base64.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } + + [Fact] + public void ValidateEmptySpanBytes() + { + Span source = Span.Empty; + + Assert.True(Base64.IsValid(source)); + Assert.True(Base64.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Fact] + public void ValidateEmptySpanChars() + { + Span source = Span.Empty; + + Assert.True(Base64.IsValid(source)); + Assert.True(Base64.IsValid(source, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Fact] + public void ValidateGuidBytes() + { + Span source = new byte[24]; + Span decodedBytes = Guid.NewGuid().ToByteArray(); + Base64.EncodeToUtf8(decodedBytes, source, out int _, out int _); + + Assert.True(Base64.IsValid(source)); + Assert.True(Base64.IsValid(source, out int decodedLength)); + Assert.True(decodedLength > 0); + } + + [Fact] + public void ValidateGuidChars() + { + Span source = new byte[24]; + Span decodedBytes = Guid.NewGuid().ToByteArray(); + Base64.EncodeToUtf8(decodedBytes, source, out int _, out int _); + Span chars = source + .ToArray() + .Select(Convert.ToChar) + .ToArray() + .AsSpan(); + + Assert.True(Base64.IsValid(chars)); + Assert.True(Base64.IsValid(chars, out int decodedLength)); + Assert.True(decodedLength > 0); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void ValidateBytesIgnoresCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored, byte[] expectedBytes) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedBytes.Length, decodedLength); + } + + [Theory] + [MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))] + public void ValidateBytesIgnoresCharsToBeIgnoredChars(string 
utf8WithByteToBeIgnored, byte[] expectedBytes) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedBytes.Length, decodedLength); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void ValidateWithOnlyCharsToBeIgnoredBytes(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [MemberData(nameof(StringsOnlyWithCharsToBeIgnored))] + public void ValidateWithOnlyCharsToBeIgnoredChars(string utf8WithByteToBeIgnored) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YQ==", 1)] + [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + [InlineData(" YWI=", 2)] + [InlineData("Y WI=", 2)] + [InlineData("YW I=", 2)] + [InlineData("YWI =", 2)] + [InlineData("YWI= ", 2)] + [InlineData(" YQ==", 1)] + [InlineData("Y Q==", 1)] + [InlineData("YQ ==", 1)] + [InlineData("YQ= =", 1)] + [InlineData("YQ== ", 1)] + public void ValidateWithPaddingReturnsCorrectCountBytes(string utf8WithByteToBeIgnored, int expectedLength) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + + [Theory] + [InlineData("YQ==", 1)] 
+ [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + [InlineData(" YWI=", 2)] + [InlineData("Y WI=", 2)] + [InlineData("YW I=", 2)] + [InlineData("YWI =", 2)] + [InlineData("YWI= ", 2)] + [InlineData(" YQ==", 1)] + [InlineData("Y Q==", 1)] + [InlineData("YQ ==", 1)] + [InlineData("YQ= =", 1)] + [InlineData("YQ== ", 1)] + public void ValidateWithPaddingReturnsCorrectCountChars(string utf8WithByteToBeIgnored, int expectedLength) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + + [Theory] + [InlineData("YQ==", 1)] + [InlineData("YWI=", 2)] + [InlineData("YWJj", 3)] + public void DecodeEmptySpan(string utf8WithByteToBeIgnored, int expectedLength) + { + ReadOnlySpan utf8BytesWithByteToBeIgnored = utf8WithByteToBeIgnored.ToArray(); + + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.True(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(expectedLength, decodedLength); + } + + [Theory] + [InlineData("YWJ")] + [InlineData("YW")] + [InlineData("Y")] + public void InvalidSizeBytes(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YWJ")] + [InlineData("YW")] + [InlineData("Y")] + public void InvalidSizeChars(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, 
decodedLength); + } + + [Theory] + [InlineData("YQ===")] + [InlineData("YQ=a=")] + [InlineData("YWI=a")] + [InlineData(" aYWI=a")] + [InlineData("a YWI=a")] + [InlineData("aY WI=a")] + [InlineData("aYW I=a")] + [InlineData("aYWI =a")] + [InlineData("aYWI= a")] + [InlineData("a YQ==a")] + [InlineData("aY Q==a")] + [InlineData("aYQ ==a")] + [InlineData("aYQ= =a")] + [InlineData("aYQ== a")] + [InlineData("aYQ==a ")] + public void InvalidBase64Bytes(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + + [Theory] + [InlineData("YQ===")] + [InlineData("YQ=a=")] + [InlineData("YWI=a")] + [InlineData("a YWI=a")] + [InlineData("aY WI=a")] + [InlineData("aYW I=a")] + [InlineData("aYWI =a")] + [InlineData("aYWI= a")] + [InlineData("a YQ==a")] + [InlineData("aY Q==a")] + [InlineData("aYQ ==a")] + [InlineData("aYQ= =a")] + [InlineData("aYQ== a")] + [InlineData("aYQ==a ")] + [InlineData("a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData(" a")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData("a ")] + [InlineData(" a ")] + [InlineData(" a ")] + [InlineData(" a ")] + [InlineData(" a ")] + public void InvalidBase64Chars(string utf8WithByteToBeIgnored) + { + byte[] utf8BytesWithByteToBeIgnored = UTF8Encoding.UTF8.GetBytes(utf8WithByteToBeIgnored); + + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored)); + Assert.False(Base64.IsValid(utf8BytesWithByteToBeIgnored, out int decodedLength)); + Assert.Equal(0, decodedLength); + } + } +} diff --git a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj index ff1fe297af74e..ca10d25a3872c 100644 --- 
a/src/libraries/System.Memory/tests/System.Memory.Tests.csproj +++ b/src/libraries/System.Memory/tests/System.Memory.Tests.csproj @@ -272,6 +272,8 @@ + + + diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 2899c7d1021bf..5f49e544faf99 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -35,15 +35,66 @@ public static partial class Base64 /// - InvalidData - if the input contains bytes outside of the expected base64 range, or if it contains invalid/more than two padding characters, /// or if the input is incomplete (i.e. not a multiple of 4) and <paramref name="isFinalBlock"/> is <see langword="true"/>. /// - public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) + public static OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) { - if (utf8.IsEmpty) + OperationStatus status = OperationStatus.Done; + bytesConsumed = 0; + bytesWritten = 0; + + while (!utf8.IsEmpty) { - bytesConsumed = 0; - bytesWritten = 0; - return OperationStatus.Done; + status = DecodeFromUtf8Core(utf8, bytes, out int localConsumed, out int localWritten, isFinalBlock); + bytesConsumed += localConsumed; + bytesWritten += localWritten; + + if (status is not OperationStatus.InvalidData) + { + break; + } + + utf8 = utf8.Slice(localConsumed); + bytes = bytes.Slice(localWritten); + + if (utf8.IsEmpty) + { + break; + } + + localConsumed = IndexOfAnyExceptWhiteSpace(utf8); + if (localConsumed < 0) + { + // The remainder of the input is all whitespace. Mark it all as having been consumed, + // and mark the operation as being done.
+ bytesConsumed += utf8.Length; + status = OperationStatus.Done; + break; + } + + if (localConsumed == 0) + { + // Non-whitespace was found at the beginning of the input. Since it wasn't consumed + // by the previous call to DecodeFromUtf8Core, it must be part of a Base64 sequence + // that was interrupted by whitespace or something else considered invalid. + // Fall back to block-wise decoding. This is very slow, but it's also very non-standard + // formatting of the input; whitespace is typically only found between blocks, such as + // when Convert.ToBase64String inserts a line break every 76 output characters. + return DecodeWithWhiteSpaceBlockwise(utf8, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock); + } + + // Skip over the starting whitespace and continue. + bytesConsumed += localConsumed; + utf8 = utf8.Slice(localConsumed); } + return status; + } + + /// + /// Core logic for decoding UTF-8 encoded text in base 64 into binary data. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe OperationStatus DecodeFromUtf8Core(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) + { fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8)) fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes)) { @@ -72,7 +123,9 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa Avx2Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) + { goto DoneExit; + } } end = srcMax - 24; @@ -81,7 +134,9 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa Vector128Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes); if (src == srcEnd) + { goto DoneExit; + } } } @@ -109,7 +164,9 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa int result = Decode(src, ref decodingMap); if (result < 0) + { goto InvalidDataExit; + } WriteThreeLowOrderBytes(dest, result); 
src += 4; @@ -117,17 +174,23 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa } if (maxSrcLength != srcLength - skipLastChunk) + { goto DestinationTooSmallExit; + } // If input is less than 4 bytes, srcLength == sourceIndex == 0 // If input is not a multiple of 4, sourceIndex == srcLength != 0 if (src == srcEnd) { if (isFinalBlock) + { goto InvalidDataExit; + } if (src == srcBytes + utf8.Length) + { goto DoneExit; + } goto NeedMoreDataExit; } @@ -161,9 +224,13 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa i0 |= i2; if (i0 < 0) + { goto InvalidDataExit; + } if (dest + 3 > destMax) + { goto DestinationTooSmallExit; + } WriteThreeLowOrderBytes(dest, i0); dest += 3; @@ -177,9 +244,13 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa i0 |= i2; if (i0 < 0) + { goto InvalidDataExit; + } if (dest + 2 > destMax) + { goto DestinationTooSmallExit; + } dest[0] = (byte)(i0 >> 16); dest[1] = (byte)(i0 >> 8); @@ -188,9 +259,13 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa else { if (i0 < 0) + { goto InvalidDataExit; + } if (dest + 1 > destMax) + { goto DestinationTooSmallExit; + } dest[0] = (byte)(i0 >> 16); dest += 1; @@ -199,7 +274,9 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa src += 4; if (srcLength != utf8.Length) + { goto InvalidDataExit; + } DoneExit: bytesConsumed = (int)(src - srcBytes); @@ -208,7 +285,9 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa DestinationTooSmallExit: if (srcLength != utf8.Length && isFinalBlock) + { goto InvalidDataExit; // if input is not a multiple of 4, and there is no more data, return invalid data instead + } bytesConsumed = (int)(src - srcBytes); bytesWritten = (int)(dest - destBytes); @@ -227,7 +306,7 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa } /// - /// Returns the maximum length (in bytes) of the result if you were to 
deocde base 64 encoded text within a byte span of size "length". + /// Returns the maximum length (in bytes) of the result if you were to decode base 64 encoded text within a byte span of size "length". /// /// /// Thrown when the specified <paramref name="length"/> is less than 0. @@ -236,7 +315,9 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Spa public static int GetMaxDecodedFromUtf8Length(int length) { if (length < 0) + { ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.length); + } return (length >> 2) * 3; } @@ -256,25 +337,138 @@ public static int GetMaxDecodedFromUtf8Length(int length) /// It does not return NeedMoreData since this method tramples the data in the buffer and /// hence can only be called once with all the data in the buffer. /// - public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten) + public static OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten) { - if (buffer.IsEmpty) + OperationStatus status = DecodeFromUtf8InPlaceCore(buffer, out bytesWritten, out uint sourceIndex); + Debug.Assert(status is OperationStatus.Done or OperationStatus.InvalidData, "These are the only statuses the method is coded to return."); + + if (status != OperationStatus.Done) { - bytesWritten = 0; - return OperationStatus.Done; + // The input may have whitespace, attempt to decode while ignoring whitespace.
+ status = DecodeWithWhiteSpaceFromUtf8InPlace(buffer, ref bytesWritten, (int)sourceIndex); + } + + return status; + } + + private static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan utf8, Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock = true) + { + const int BlockSize = 4; + Span buffer = stackalloc byte[BlockSize]; + OperationStatus status = OperationStatus.Done; + + while (!utf8.IsEmpty) + { + int encodedIdx = 0; + int bufferIdx = 0; + int skipped = 0; + + for (; encodedIdx < utf8.Length && (uint)bufferIdx < (uint)buffer.Length; ++encodedIdx) + { + if (IsWhiteSpace(utf8[encodedIdx])) + { + skipped++; + } + else + { + buffer[bufferIdx] = utf8[encodedIdx]; + bufferIdx++; + } + } + + utf8 = utf8.Slice(encodedIdx); + bytesConsumed += skipped; + + if (bufferIdx == 0) + { + continue; + } + + bool hasAnotherBlock = utf8.Length >= BlockSize && bufferIdx == BlockSize; + bool localIsFinalBlock = !hasAnotherBlock; + + // If this block contains padding and there's another block, then only whitespace may follow for being valid. + if (hasAnotherBlock) + { + int paddingCount = GetPaddingCount(ref buffer[^1]); + if (paddingCount > 0) + { + hasAnotherBlock = false; + localIsFinalBlock = true; + } + } + + if (localIsFinalBlock && !isFinalBlock) + { + localIsFinalBlock = false; + } + + status = DecodeFromUtf8Core(buffer.Slice(0, bufferIdx), bytes, out int localConsumed, out int localWritten, localIsFinalBlock); + bytesConsumed += localConsumed; + bytesWritten += localWritten; + + if (status != OperationStatus.Done) + { + return status; + } + + // The remaining data must all be whitespace in order to be valid. + if (!hasAnotherBlock) + { + for (int i = 0; i < utf8.Length; ++i) + { + if (!IsWhiteSpace(utf8[i])) + { + // Revert previous dest increment, since an invalid state followed. 
+ bytesConsumed -= localConsumed; + bytesWritten -= localWritten; + + return OperationStatus.InvalidData; + } + + bytesConsumed++; + } + + break; + } + + bytes = bytes.Slice(localWritten); } + return status; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetPaddingCount(ref byte ptrToLastElement) + { + int padding = 0; + + if (ptrToLastElement == EncodingPad) padding++; + if (Unsafe.Subtract(ref ptrToLastElement, 1) == EncodingPad) padding++; + + return padding; + } + + /// + /// Core logic for decoding UTF-8 encoded text in base 64 into binary data in place. + /// + private static unsafe OperationStatus DecodeFromUtf8InPlaceCore(Span buffer, out int bytesWritten, out uint sourceIndex) + { fixed (byte* bufferBytes = &MemoryMarshal.GetReference(buffer)) { int bufferLength = buffer.Length; - uint sourceIndex = 0; + sourceIndex = 0; uint destIndex = 0; // only decode input if it is a multiple of 4 if (bufferLength != ((bufferLength >> 2) * 4)) + { goto InvalidExit; + } if (bufferLength == 0) + { goto DoneExit; + } ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); @@ -282,7 +476,10 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou { int result = Decode(bufferBytes + sourceIndex, ref decodingMap); if (result < 0) + { goto InvalidExit; + } + WriteThreeLowOrderBytes(bufferBytes + destIndex, result); destIndex += 3; sourceIndex += 4; @@ -312,7 +509,9 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou i0 |= i2; if (i0 < 0) + { goto InvalidExit; + } WriteThreeLowOrderBytes(bufferBytes + destIndex, i0); destIndex += 3; @@ -326,7 +525,9 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou i0 |= i2; if (i0 < 0) + { goto InvalidExit; + } bufferBytes[destIndex] = (byte)(i0 >> 16); bufferBytes[destIndex + 1] = (byte)(i0 >> 8); @@ -335,7 +536,9 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou else { if (i0 < 0) + { goto 
InvalidExit; + } bufferBytes[destIndex] = (byte)(i0 >> 16); destIndex += 1; @@ -351,6 +554,76 @@ public static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, ou } } + private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span utf8, ref int destIndex, int sourceIndex) + { + const int BlockSize = 4; + Span buffer = stackalloc byte[BlockSize]; + + OperationStatus status = OperationStatus.Done; + int localDestIndex = destIndex; + bool hasPaddingBeenProcessed = false; + int localBytesWritten = 0; + + while ((uint)sourceIndex < (uint)utf8.Length) + { + int bufferIdx = 0; + + while (bufferIdx < BlockSize) + { + if ((uint)sourceIndex >= (uint)utf8.Length) // TODO https://github.com/dotnet/runtime/issues/83349: move into the while condition once fixed + { + break; + } + + if (!IsWhiteSpace(utf8[sourceIndex])) + { + buffer[bufferIdx] = utf8[sourceIndex]; + bufferIdx++; + } + + sourceIndex++; + } + + if (bufferIdx == 0) + { + continue; + } + + if (bufferIdx != 4) + { + status = OperationStatus.InvalidData; + break; + } + + if (hasPaddingBeenProcessed) + { + // Padding has already been processed, a new valid block cannot be processed. + // Revert previous dest increment, since an invalid state followed. + localDestIndex -= localBytesWritten; + status = OperationStatus.InvalidData; + break; + } + + status = DecodeFromUtf8InPlaceCore(buffer, out localBytesWritten, out _); + localDestIndex += localBytesWritten; + hasPaddingBeenProcessed = localBytesWritten < 3; + + if (status != OperationStatus.Done) + { + break; + } + + // Write result to source span in place. 
+ for (int i = 0; i < localBytesWritten; i++) + { + utf8[localDestIndex - localBytesWritten + i] = buffer[i]; + } + } + + destIndex = localDestIndex; + return status; + } + [BypassReadyToRun] [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart) @@ -433,7 +706,9 @@ private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b Vector256 lo = Avx2.Shuffle(lutLo, loNibbles); if (!Avx.TestZ(lo, hi)) + { break; + } Vector256 eq2F = Avx2.CompareEqual(str, mask2F); Vector256 shift = Avx2.Shuffle(lutShift, Avx2.Add(eq2F, hiNibbles)); @@ -598,7 +873,9 @@ private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destByt // Check for invalid input: if any "and" values from lo and hi are not zero, // fall back on bytewise code to do error checking and reporting: if ((lo & hi) != Vector128.Zero) + { break; + } Vector128 eq2F = Vector128.Equals(str, mask2F); Vector128 shift = SimdShuffle(lutShift.AsByte(), (eq2F + hiNibbles), mask8F); @@ -692,16 +969,54 @@ private static unsafe void WriteThreeLowOrderBytes(byte* destination, int value) { destination[0] = (byte)(value >> 16); destination[1] = (byte)(value >> 8); - destination[2] = (byte)(value); + destination[2] = (byte)value; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int IndexOfAnyExceptWhiteSpace(ReadOnlySpan span) + { + for (int i = 0; i < span.Length; i++) + { + if (!IsWhiteSpace(span[i])) + { + return i; + } + } + + return -1; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static bool IsWhiteSpace(int value) + { + if (Environment.Is64BitProcess) + { + // For description see https://github.com/dotnet/runtime/blob/48e74187cb15386c29eedaa046a5ee2c7ddef161/src/libraries/Common/src/System/HexConverter.cs#L314-L330 + // Lookup bit mask for "\t\n\r ". 
+ const ulong MagicConstant = 0xC800010000000000UL; + ulong i = (uint)value - '\t'; + ulong shift = MagicConstant << (int)i; + ulong mask = i - 64; + return (long)(shift & mask) < 0; + } + + if (value < 32) + { + const int BitMask = (1 << (int)'\t') | (1 << (int)'\n') | (1 << (int)'\r'); + return ((1 << value) & BitMask) != 0; + } + + return value == 32; } // Pre-computing this table using a custom string(s_characters) and GenerateDecodingMapAndVerify (found in tests) - private static ReadOnlySpan DecodingMap => new sbyte[] { + private static ReadOnlySpan DecodingMap => new sbyte[] + { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, //62 is placed at index 43 (for +), 63 at index 47 (for /) 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, //52-61 are placed at index 48-57 (for 0-9), 64 at index 61 (for =) - -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1, //0-25 are placed at index 65-90 (for A-Z) -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, //26-51 are placed at index 97-122 (for a-z) diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs index 83fe75aef335d..e051bbe932cb0 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Encoder.cs @@ -585,10 +585,10 @@ private static unsafe uint EncodeAndPadTwo(byte* oneByte, ref byte encodingMap) } } - private const uint EncodingPad = '='; // '=', for padding + internal const uint EncodingPad = '='; // '=', for padding private const int 
MaximumEncodeLength = (int.MaxValue / 4) * 3; // 1610612733 - private static ReadOnlySpan EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8; + internal static ReadOnlySpan EncodingMap => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"u8; } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs new file mode 100644 index 0000000000000..403377fab99c5 --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Validator.cs @@ -0,0 +1,159 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Buffers.Text +{ + public static partial class Base64 + { + /// Validates that the specified span of text is comprised of valid base-64 encoded data. + /// A span of text to validate. + /// if contains a valid, decodable sequence of base-64 encoded data; otherwise, . + /// + /// If the method returns , the same text passed to and + /// would successfully decode (in the case + /// of assuming sufficient output space). Any amount of whitespace is allowed anywhere in the input, + /// where whitespace is defined as the characters ' ', '\t', '\r', or '\n'. + /// + public static bool IsValid(ReadOnlySpan base64Text) => + IsValid(base64Text, out _); + + /// Validates that the specified span of text is comprised of valid base-64 encoded data. + /// A span of text to validate. + /// If the method returns true, the number of decoded bytes that will result from decoding the input text. + /// if contains a valid, decodable sequence of base-64 encoded data; otherwise, . + /// + /// If the method returns , the same text passed to and + /// would successfully decode (in the case + /// of assuming sufficient output space). 
Any amount of whitespace is allowed anywhere in the input, + /// where whitespace is defined as the characters ' ', '\t', '\r', or '\n'. + /// + public static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) => + IsValid(base64Text, out decodedLength); + + /// Validates that the specified span of UTF8 text is comprised of valid base-64 encoded data. + /// A span of UTF8 text to validate. + /// if contains a valid, decodable sequence of base-64 encoded data; otherwise, . + /// + /// If the method returns , the same text passed to and + /// would successfully decode. Any amount of whitespace is allowed anywhere in the input, + /// where whitespace is defined as the characters ' ', '\t', '\r', or '\n' (as bytes). + /// + public static bool IsValid(ReadOnlySpan base64TextUtf8) => + IsValid(base64TextUtf8, out _); + + /// Validates that the specified span of UTF8 text is comprised of valid base-64 encoded data. + /// A span of UTF8 text to validate. + /// If the method returns true, the number of decoded bytes that will result from decoding the input UTF8 text. + /// if contains a valid, decodable sequence of base-64 encoded data; otherwise, . + /// + /// If the method returns , the same text passed to and + /// would successfully decode. Any amount of whitespace is allowed anywhere in the input, + /// where whitespace is defined as the characters ' ', '\t', '\r', or '\n' (as bytes). 
+ /// + public static bool IsValid(ReadOnlySpan base64TextUtf8, out int decodedLength) => + IsValid(base64TextUtf8, out decodedLength); + + private static bool IsValid(ReadOnlySpan base64Text, out int decodedLength) + where TBase64Validatable : IBase64Validatable + { + int length = 0, paddingCount = 0; + + if (!base64Text.IsEmpty) + { + while (true) + { + int index = TBase64Validatable.IndexOfAnyExcept(base64Text); + if ((uint)index >= (uint)base64Text.Length) + { + length += base64Text.Length; + break; + } + + length += index; + + T charToValidate = base64Text[index]; + base64Text = base64Text.Slice(index + 1); + + if (TBase64Validatable.IsWhiteSpace(charToValidate)) + { + // It's common if there's whitespace for there to be multiple whitespace characters in a row, + // e.g. \r\n. Optimize for that case by looping here. + while (!base64Text.IsEmpty && TBase64Validatable.IsWhiteSpace(base64Text[0])) + { + base64Text = base64Text.Slice(1); + } + continue; + } + + if (!TBase64Validatable.IsEncodingPad(charToValidate)) + { + // Invalid char was found. + goto Fail; + } + + // Encoding pad found. Determine if padding is valid, then stop processing. + paddingCount = 1; + foreach (T charToValidateInPadding in base64Text) + { + if (TBase64Validatable.IsEncodingPad(charToValidateInPadding)) + { + // There can be at most 2 padding chars. + if (paddingCount >= 2) + { + goto Fail; + } + + paddingCount++; + } + else if (!TBase64Validatable.IsWhiteSpace(charToValidateInPadding)) + { + // Invalid char was found. + goto Fail; + } + } + + length += paddingCount; + break; + } + + if (length % 4 != 0) + { + goto Fail; + } + } + + // Remove padding to get exact length. 
+ decodedLength = (int)((uint)length / 4 * 3) - paddingCount; + return true; + + Fail: + decodedLength = 0; + return false; + } + + private interface IBase64Validatable + { + static abstract int IndexOfAnyExcept(ReadOnlySpan span); + static abstract bool IsWhiteSpace(T value); + static abstract bool IsEncodingPad(T value); + } + + private readonly struct Base64CharValidatable : IBase64Validatable + { + private static readonly SearchValues s_validBase64Chars = SearchValues.Create("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"); + + public static int IndexOfAnyExcept(ReadOnlySpan span) => span.IndexOfAnyExcept(s_validBase64Chars); + public static bool IsWhiteSpace(char value) => Base64.IsWhiteSpace(value); + public static bool IsEncodingPad(char value) => value == EncodingPad; + } + + private readonly struct Base64ByteValidatable : IBase64Validatable + { + private static readonly SearchValues s_validBase64Chars = SearchValues.Create(EncodingMap); + + public static int IndexOfAnyExcept(ReadOnlySpan span) => span.IndexOfAnyExcept(s_validBase64Chars); + public static bool IsWhiteSpace(byte value) => Base64.IsWhiteSpace(value); + public static bool IsEncodingPad(byte value) => value == EncodingPad; + } + } +} diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index 5e3e26c5225f3..f5a8ae284b722 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -7342,6 +7342,10 @@ public static partial class Base64 public static System.Buffers.OperationStatus EncodeToUtf8InPlace(System.Span buffer, int dataLength, out int bytesWritten) { throw null; } public static int GetMaxDecodedFromUtf8Length(int length) { throw null; } public static int GetMaxEncodedToUtf8Length(int length) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64Text) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64Text, 
out int decodedLength) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64TextUtf8) { throw null; } + public static bool IsValid(System.ReadOnlySpan base64TextUtf8, out int decodedLength) { throw null; } } } namespace System.CodeDom.Compiler diff --git a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/Base64Transforms.cs b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/Base64Transforms.cs index 544eb3620782d..275a708b38fb6 100644 --- a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/Base64Transforms.cs +++ b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/Base64Transforms.cs @@ -108,9 +108,21 @@ protected virtual void Dispose(bool disposing) { } public class FromBase64Transform : ICryptoTransform { + /// Characters considered whitespace. + /// + /// We assume ASCII encoded data. If there is any non-ASCII char, it is invalid + /// Base64 and will be caught during decoding. 
+ /// SPACE 32 + /// TAB 9 + /// LF 10 + /// VTAB 11 + /// FORM FEED 12 + /// CR 13 + /// + private static readonly SearchValues s_whiteSpace = SearchValues.Create(" \t\n\v\f\r"u8); + private readonly FromBase64TransformMode _whitespaces; private byte[] _inputBuffer = new byte[4]; private int _inputIndex; - private readonly FromBase64TransformMode _whitespaces; public FromBase64Transform() : this(FromBase64TransformMode.IgnoreWhiteSpaces) { } public FromBase64Transform(FromBase64TransformMode whitespaces) @@ -223,41 +235,35 @@ public byte[] TransformFinalBlock(byte[] inputBuffer, int inputOffset, int input private Span AppendInputBuffers(ReadOnlySpan inputBuffer, Span transformBuffer) { - _inputBuffer.AsSpan(0, _inputIndex).CopyTo(transformBuffer); + int index = _inputIndex; + _inputBuffer.AsSpan(0, index).CopyTo(transformBuffer); if (_whitespaces == FromBase64TransformMode.DoNotIgnoreWhiteSpaces) { - inputBuffer.CopyTo(transformBuffer.Slice(_inputIndex)); - return transformBuffer.Slice(0, _inputIndex + inputBuffer.Length); + if (inputBuffer.IndexOfAny(s_whiteSpace) >= 0) + { + ThrowHelper.ThrowBase64FormatException(); + } } else { - int count = _inputIndex; - for (int i = 0; i < inputBuffer.Length; i++) + int whitespaceIndex; + while ((whitespaceIndex = inputBuffer.IndexOfAny(s_whiteSpace)) >= 0) { - if (!IsWhitespace(inputBuffer[i])) + inputBuffer.Slice(0, whitespaceIndex).CopyTo(transformBuffer.Slice(index)); + index += whitespaceIndex; + inputBuffer = inputBuffer.Slice(whitespaceIndex); + + do { - transformBuffer[count++] = inputBuffer[i]; + inputBuffer = inputBuffer.Slice(1); } + while (!inputBuffer.IsEmpty && s_whiteSpace.Contains(inputBuffer[0])); } - - return transformBuffer.Slice(0, count); } - } - - private static bool IsWhitespace(byte value) - { - // We assume ASCII encoded data. If there is any non-ASCII char, it is invalid - // Base64 and will be caught during decoding. 
- - // SPACE 32 - // TAB 9 - // LF 10 - // VTAB 11 - // FORM FEED 12 - // CR 13 - return value == 32 || ((uint)value - 9 <= (13 - 9)); + inputBuffer.CopyTo(transformBuffer.Slice(index)); + return transformBuffer.Slice(0, index + inputBuffer.Length); } [MethodImpl(MethodImplOptions.AggressiveInlining)]