diff --git a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs index 6a6c2f018d8c0..956562623785d 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Buffers/Text/Base64Decoder.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; -using System.Runtime; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; @@ -35,66 +34,18 @@ public static partial class Base64 /// - InvalidData - if the input contains bytes outside of the expected base64 range, or if it contains invalid/more than two padding characters, /// or if the input is incomplete (i.e. not a multiple of 4) and is . /// - public static OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) - { - OperationStatus status = OperationStatus.Done; - bytesConsumed = 0; - bytesWritten = 0; + public static OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) => + DecodeFromUtf8(utf8, bytes, out bytesConsumed, out bytesWritten, isFinalBlock, ignoreWhiteSpace: true); - while (!utf8.IsEmpty) + private static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock, bool ignoreWhiteSpace) + { + if (utf8.IsEmpty) { - status = DecodeFromUtf8Core(utf8, bytes, out int localConsumed, out int localWritten, isFinalBlock); - bytesConsumed += localConsumed; - bytesWritten += localWritten; - - if (status is not OperationStatus.InvalidData) - { - break; - } - - utf8 = utf8.Slice(localConsumed); - bytes = bytes.Slice(localWritten); - - if (utf8.IsEmpty) - { - break; - } - - localConsumed = IndexOfAnyExceptWhiteSpace(utf8); - if (localConsumed < 0) - { - // The remainder of the input is all whitespace. Mark it all as having been consumed, - // and mark the operation as being done. - bytesConsumed += utf8.Length; - status = OperationStatus.Done; - break; - } - - if (localConsumed == 0) - { - // Non-whitespace was found at the beginning of the input. Since it wasn't consumed - // by the previous call to DecodeFromUtf8Core, it must be part of a Base64 sequence - // that was interrupted by whitespace or something else considered invalid. - // Fall back to block-wise decoding. This is very slow, but it's also very non-standard - // formatting of the input; whitespace is typically only found between blocks, such as - // when Convert.ToBase64String inserts a line break every 76 output characters. - return DecodeWithWhiteSpaceBlockwise(utf8, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock); - } - - // Skip over the starting whitespace and continue. - bytesConsumed += localConsumed; - utf8 = utf8.Slice(localConsumed); + bytesConsumed = 0; + bytesWritten = 0; + return OperationStatus.Done; } - return status; - } - - /// - /// Core logic for decoding UTF-8 encoded text in base 64 into binary data. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static unsafe OperationStatus DecodeFromUtf8Core(ReadOnlySpan utf8, Span bytes, out int bytesConsumed, out int bytesWritten, bool isFinalBlock = true) - { fixed (byte* srcBytes = &MemoryMarshal.GetReference(utf8)) fixed (byte* destBytes = &MemoryMarshal.GetReference(bytes)) { @@ -301,7 +252,59 @@ private static unsafe OperationStatus DecodeFromUtf8Core(ReadOnlySpan utf8 InvalidDataExit: bytesConsumed = (int)(src - srcBytes); bytesWritten = (int)(dest - destBytes); - return OperationStatus.InvalidData; + return ignoreWhiteSpace ? + InvalidDataFallback(utf8, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock) : + OperationStatus.InvalidData; + } + + static OperationStatus InvalidDataFallback(ReadOnlySpan utf8, Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock) + { + utf8 = utf8.Slice(bytesConsumed); + bytes = bytes.Slice(bytesWritten); + + OperationStatus status; + do + { + int localConsumed = IndexOfAnyExceptWhiteSpace(utf8); + if (localConsumed < 0) + { + // The remainder of the input is all whitespace. Mark it all as having been consumed, + // and mark the operation as being done. + bytesConsumed += utf8.Length; + status = OperationStatus.Done; + break; + } + + if (localConsumed == 0) + { + // Non-whitespace was found at the beginning of the input. Since it wasn't consumed + // by the previous call to DecodeFromUtf8, it must be part of a Base64 sequence + // that was interrupted by whitespace or something else considered invalid. + // Fall back to block-wise decoding. This is very slow, but it's also very non-standard + // formatting of the input; whitespace is typically only found between blocks, such as + // when Convert.ToBase64String inserts a line break every 76 output characters. + return DecodeWithWhiteSpaceBlockwise(utf8, bytes, ref bytesConsumed, ref bytesWritten, isFinalBlock); + } + + // Skip over the starting whitespace and continue. + bytesConsumed += localConsumed; + utf8 = utf8.Slice(localConsumed); + + // Try again after consumed whitespace + status = DecodeFromUtf8(utf8, bytes, out localConsumed, out int localWritten, isFinalBlock, ignoreWhiteSpace: false); + bytesConsumed += localConsumed; + bytesWritten += localWritten; + if (status is not OperationStatus.InvalidData) + { + break; + } + + utf8 = utf8.Slice(localConsumed); + bytes = bytes.Slice(localWritten); + } + while (!utf8.IsEmpty); + + return status; } } @@ -337,18 +340,112 @@ public static int GetMaxDecodedFromUtf8Length(int length) /// It does not return NeedMoreData since this method tramples the data in the buffer and /// hence can only be called once with all the data in the buffer. /// - public static OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten) - { - OperationStatus status = DecodeFromUtf8InPlaceCore(buffer, out bytesWritten, out uint sourceIndex); - Debug.Assert(status is OperationStatus.Done or OperationStatus.InvalidData, "These are the only statuses the method is coded to return."); + public static OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten) => + DecodeFromUtf8InPlace(buffer, out bytesWritten, ignoreWhiteSpace: true); - if (status != OperationStatus.Done) + private static unsafe OperationStatus DecodeFromUtf8InPlace(Span buffer, out int bytesWritten, bool ignoreWhiteSpace) + { + if (buffer.IsEmpty) { - // The input may have whitespace, attempt to decode while ignoring whitespace. - status = DecodeWithWhiteSpaceFromUtf8InPlace(buffer, ref bytesWritten, (int)sourceIndex); + bytesWritten = 0; + return OperationStatus.Done; } - return status; + fixed (byte* bufferBytes = &MemoryMarshal.GetReference(buffer)) + { + uint bufferLength = (uint)buffer.Length; + uint sourceIndex = 0; + uint destIndex = 0; + + // only decode input if it is a multiple of 4 + if (bufferLength % 4 != 0) + { + goto InvalidExit; + } + + ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); + + while (sourceIndex < bufferLength - 4) + { + int result = Decode(bufferBytes + sourceIndex, ref decodingMap); + if (result < 0) + { + goto InvalidExit; + } + + WriteThreeLowOrderBytes(bufferBytes + destIndex, result); + destIndex += 3; + sourceIndex += 4; + } + + uint t0 = bufferBytes[bufferLength - 4]; + uint t1 = bufferBytes[bufferLength - 3]; + uint t2 = bufferBytes[bufferLength - 2]; + uint t3 = bufferBytes[bufferLength - 1]; + + int i0 = Unsafe.Add(ref decodingMap, t0); + int i1 = Unsafe.Add(ref decodingMap, t1); + + i0 <<= 18; + i1 <<= 12; + + i0 |= i1; + + if (t3 != EncodingPad) + { + int i2 = Unsafe.Add(ref decodingMap, t2); + int i3 = Unsafe.Add(ref decodingMap, t3); + + i2 <<= 6; + + i0 |= i3; + i0 |= i2; + + if (i0 < 0) + { + goto InvalidExit; + } + + WriteThreeLowOrderBytes(bufferBytes + destIndex, i0); + destIndex += 3; + } + else if (t2 != EncodingPad) + { + int i2 = Unsafe.Add(ref decodingMap, t2); + + i2 <<= 6; + + i0 |= i2; + + if (i0 < 0) + { + goto InvalidExit; + } + + bufferBytes[destIndex] = (byte)(i0 >> 16); + bufferBytes[destIndex + 1] = (byte)(i0 >> 8); + destIndex += 2; + } + else + { + if (i0 < 0) + { + goto InvalidExit; + } + + bufferBytes[destIndex] = (byte)(i0 >> 16); + destIndex += 1; + } + + bytesWritten = (int)destIndex; + return OperationStatus.Done; + + InvalidExit: + bytesWritten = (int)destIndex; + return ignoreWhiteSpace ? + DecodeWithWhiteSpaceFromUtf8InPlace(buffer, ref bytesWritten, sourceIndex) : // The input may have whitespace, attempt to decode while ignoring whitespace. + OperationStatus.InvalidData; + } } private static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan utf8, Span bytes, ref int bytesConsumed, ref int bytesWritten, bool isFinalBlock = true) @@ -403,7 +500,7 @@ private static OperationStatus DecodeWithWhiteSpaceBlockwise(ReadOnlySpan localIsFinalBlock = false; } - status = DecodeFromUtf8Core(buffer.Slice(0, bufferIdx), bytes, out int localConsumed, out int localWritten, localIsFinalBlock); + status = DecodeFromUtf8(buffer.Slice(0, bufferIdx), bytes, out int localConsumed, out int localWritten, localIsFinalBlock, ignoreWhiteSpace: false); bytesConsumed += localConsumed; bytesWritten += localWritten; @@ -449,112 +546,7 @@ private static int GetPaddingCount(ref byte ptrToLastElement) return padding; } - /// - /// Core logic for decoding UTF-8 encoded text in base 64 into binary data in place. - /// - private static unsafe OperationStatus DecodeFromUtf8InPlaceCore(Span buffer, out int bytesWritten, out uint sourceIndex) - { - fixed (byte* bufferBytes = &MemoryMarshal.GetReference(buffer)) - { - int bufferLength = buffer.Length; - sourceIndex = 0; - uint destIndex = 0; - - // only decode input if it is a multiple of 4 - if (bufferLength != ((bufferLength >> 2) * 4)) - { - goto InvalidExit; - } - if (bufferLength == 0) - { - goto DoneExit; - } - - ref sbyte decodingMap = ref MemoryMarshal.GetReference(DecodingMap); - - while (sourceIndex < bufferLength - 4) - { - int result = Decode(bufferBytes + sourceIndex, ref decodingMap); - if (result < 0) - { - goto InvalidExit; - } - - WriteThreeLowOrderBytes(bufferBytes + destIndex, result); - destIndex += 3; - sourceIndex += 4; - } - - uint t0 = bufferBytes[bufferLength - 4]; - uint t1 = bufferBytes[bufferLength - 3]; - uint t2 = bufferBytes[bufferLength - 2]; - uint t3 = bufferBytes[bufferLength - 1]; - - int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); - int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); - - i0 <<= 18; - i1 <<= 12; - - i0 |= i1; - - if (t3 != EncodingPad) - { - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); - int i3 = Unsafe.Add(ref decodingMap, (IntPtr)t3); - - i2 <<= 6; - - i0 |= i3; - i0 |= i2; - - if (i0 < 0) - { - goto InvalidExit; - } - - WriteThreeLowOrderBytes(bufferBytes + destIndex, i0); - destIndex += 3; - } - else if (t2 != EncodingPad) - { - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); - - i2 <<= 6; - - i0 |= i2; - - if (i0 < 0) - { - goto InvalidExit; - } - - bufferBytes[destIndex] = (byte)(i0 >> 16); - bufferBytes[destIndex + 1] = (byte)(i0 >> 8); - destIndex += 2; - } - else - { - if (i0 < 0) - { - goto InvalidExit; - } - - bufferBytes[destIndex] = (byte)(i0 >> 16); - destIndex += 1; - } - - DoneExit: - bytesWritten = (int)destIndex; - return OperationStatus.Done; - - InvalidExit: - bytesWritten = (int)destIndex; - return OperationStatus.InvalidData; - } - } - - private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span utf8, ref int destIndex, int sourceIndex) + private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span utf8, ref int destIndex, uint sourceIndex) { const int BlockSize = 4; Span buffer = stackalloc byte[BlockSize]; @@ -564,20 +556,20 @@ private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span ut bool hasPaddingBeenProcessed = false; int localBytesWritten = 0; - while ((uint)sourceIndex < (uint)utf8.Length) + while (sourceIndex < (uint)utf8.Length) { int bufferIdx = 0; while (bufferIdx < BlockSize) { - if ((uint)sourceIndex >= (uint)utf8.Length) // TODO https://github.com/dotnet/runtime/issues/83349: move into the while condition once fixed + if (sourceIndex >= (uint)utf8.Length) // TODO https://github.com/dotnet/runtime/issues/83349: move into the while condition once fixed { break; } - if (!IsWhiteSpace(utf8[sourceIndex])) + if (!IsWhiteSpace(utf8[(int)sourceIndex])) { - buffer[bufferIdx] = utf8[sourceIndex]; + buffer[bufferIdx] = utf8[(int)sourceIndex]; bufferIdx++; } @@ -604,7 +596,7 @@ private static OperationStatus DecodeWithWhiteSpaceFromUtf8InPlace(Span ut break; } - status = DecodeFromUtf8InPlaceCore(buffer, out localBytesWritten, out _); + status = DecodeFromUtf8InPlace(buffer, out localBytesWritten, ignoreWhiteSpace: false); localDestIndex += localBytesWritten; hasPaddingBeenProcessed = localBytesWritten < 3; @@ -952,10 +944,10 @@ private static unsafe int Decode(byte* encodedBytes, ref sbyte decodingMap) uint t2 = encodedBytes[2]; uint t3 = encodedBytes[3]; - int i0 = Unsafe.Add(ref decodingMap, (IntPtr)t0); - int i1 = Unsafe.Add(ref decodingMap, (IntPtr)t1); - int i2 = Unsafe.Add(ref decodingMap, (IntPtr)t2); - int i3 = Unsafe.Add(ref decodingMap, (IntPtr)t3); + int i0 = Unsafe.Add(ref decodingMap, t0); + int i1 = Unsafe.Add(ref decodingMap, t1); + int i2 = Unsafe.Add(ref decodingMap, t2); + int i3 = Unsafe.Add(ref decodingMap, t3); i0 <<= 18; i1 <<= 12;