Skip to content

Commit

Permalink
Vectorize ProbabilisticMap.LastIndexOfAny (dotnet#102331)
Browse files Browse the repository at this point in the history
* Vectorize ProbabilisticMap.LastIndexOfAny

* Fix loop in TryFindLastMatchOverlapped

* Use Avx512 name suffix with more helpers
  • Loading branch information
MihaZupan authored May 17, 2024
1 parent 0c64e66 commit 68ae561
Showing 1 changed file with 295 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,12 @@ internal static int IndexOfAny<TUseFastContains>(ref char searchSpace, int searc
internal static int LastIndexOfAny<TUseFastContains>(ref char searchSpace, int searchSpaceLength, ref ProbabilisticMapState state)
where TUseFastContains : struct, SearchValues.IRuntimeConst
{
// TODO: Implement vectorized LastIndexOfAny.
if ((Sse41.IsSupported || AdvSimd.Arm64.IsSupported) && searchSpaceLength >= 16)
{
return Vector512.IsHardwareAccelerated && Avx512Vbmi.VL.IsSupported
? LastIndexOfAnyVectorizedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, ref state)
: LastIndexOfAnyVectorized<TUseFastContains>(ref searchSpace, searchSpaceLength, ref state);
}

return ProbabilisticMapState.LastIndexOfAnySimpleLoop<TUseFastContains, IndexOfAnyAsciiSearcher.DontNegate>(ref searchSpace, searchSpaceLength, ref state);
}
Expand Down Expand Up @@ -419,7 +424,7 @@ private static int IndexOfAnyVectorizedAvx512<TUseFastContains>(ref char searchS

if (result != Vector512<byte>.Zero)
{
if (TryFindMatch<TUseFastContains>(ref cur, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), ref state, out int index))
if (TryFindMatchAvx512<TUseFastContains>(ref cur, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), ref state, out int index))
{
return MatchOffset(ref searchSpace, ref cur) + index;
}
Expand Down Expand Up @@ -449,7 +454,7 @@ private static int IndexOfAnyVectorizedAvx512<TUseFastContains>(ref char searchS

if (result != Vector512<byte>.Zero)
{
if (TryFindMatchOverlapped<TUseFastContains>(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), ref state, out int index))
if (TryFindMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), ref state, out int index))
{
return index;
}
Expand All @@ -466,7 +471,7 @@ private static int IndexOfAnyVectorizedAvx512<TUseFastContains>(ref char searchS

if (result != Vector256<byte>.Zero)
{
if (TryFindMatchOverlapped<TUseFastContains>(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector256Result(result).ExtractMostSignificantBits(), ref state, out int index))
if (TryFindMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector256Result(result).ExtractMostSignificantBits(), ref state, out int index))
{
return index;
}
Expand Down Expand Up @@ -568,6 +573,187 @@ private static int IndexOfAnyVectorized<TUseFastContains>(ref char searchSpace,
return -1;
}

[CompExactlyDependsOn(typeof(Avx512Vbmi.VL))]
private static int LastIndexOfAnyVectorizedAvx512<TUseFastContains>(ref char searchSpace, int searchSpaceLength, ref ProbabilisticMapState state)
where TUseFastContains : struct, SearchValues.IRuntimeConst
{
Debug.Assert(Avx512Vbmi.VL.IsSupported);
Debug.Assert(searchSpaceLength >= 16);

ref char cur = ref Unsafe.Add(ref searchSpace, searchSpaceLength);

Vector256<byte> charMap256 = Vector256.LoadUnsafe(ref Unsafe.As<ProbabilisticMap, byte>(ref state.Map));

if (searchSpaceLength > 32)
{
Vector512<byte> charMap512 = Vector512.Create(charMap256, charMap256);

if (searchSpaceLength > 64)
{
ref char lastStartVector = ref Unsafe.Add(ref searchSpace, 64);

while (true)
{
Debug.Assert(Unsafe.ByteOffset(ref searchSpace, ref cur) >= 64 * sizeof(char));

cur = ref Unsafe.Subtract(ref cur, 64);

Vector512<byte> result = ContainsMask64CharsAvx512(charMap512, ref cur, ref Unsafe.Add(ref cur, Vector512<ushort>.Count));

if (result != Vector512<byte>.Zero)
{
if (TryFindLastMatchAvx512<TUseFastContains>(ref cur, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), ref state, out int index))
{
return MatchOffset(ref searchSpace, ref cur) + index;
}
}

if (!Unsafe.IsAddressGreaterThan(ref cur, ref lastStartVector))
{
if (Unsafe.AreSame(ref cur, ref searchSpace))
{
break;
}

// Adjust the current vector and do one last iteration.
cur = ref lastStartVector;
}
}
}
else
{
Debug.Assert(searchSpaceLength is > 32 and <= 64);
Debug.Assert(Unsafe.ByteOffset(ref searchSpace, ref cur) >= 32 * sizeof(char));

// Process the first and last vector in the search space.
// They may overlap, but we'll handle that in the index calculation if we do get a match.
Vector512<byte> result = ContainsMask64CharsAvx512(charMap512, ref searchSpace, ref Unsafe.Subtract(ref cur, Vector512<ushort>.Count));

if (result != Vector512<byte>.Zero)
{
if (TryFindLastMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector512Result(result).ExtractMostSignificantBits(), ref state, out int index))
{
return index;
}
}
}
}
else
{
Debug.Assert(searchSpaceLength is >= 16 and <= 32);
Debug.Assert(Unsafe.ByteOffset(ref searchSpace, ref cur) >= 16 * sizeof(char));

// Process the first and last vector in the search space.
// They may overlap, but we'll handle that in the index calculation if we do get a match.
Vector256<byte> result = ContainsMask32CharsAvx512(charMap256, ref searchSpace, ref Unsafe.Subtract(ref cur, Vector256<ushort>.Count));

if (result != Vector256<byte>.Zero)
{
if (TryFindLastMatchOverlappedAvx512<TUseFastContains>(ref searchSpace, searchSpaceLength, PackedSpanHelpers.FixUpPackedVector256Result(result).ExtractMostSignificantBits(), ref state, out int index))
{
return index;
}
}
}

return -1;
}

[CompExactlyDependsOn(typeof(AdvSimd.Arm64))]
[CompExactlyDependsOn(typeof(Sse41))]
private static int LastIndexOfAnyVectorized<TUseFastContains>(ref char searchSpace, int searchSpaceLength, ref ProbabilisticMapState state)
where TUseFastContains : struct, SearchValues.IRuntimeConst
{
Debug.Assert(Sse41.IsSupported || AdvSimd.Arm64.IsSupported);
Debug.Assert(searchSpaceLength >= 16);

ref char cur = ref Unsafe.Add(ref searchSpace, searchSpaceLength);

Vector128<byte> charMapLower = Vector128.LoadUnsafe(ref Unsafe.As<ProbabilisticMap, byte>(ref state.Map));
Vector128<byte> charMapUpper = Vector128.LoadUnsafe(ref Unsafe.As<ProbabilisticMap, byte>(ref state.Map), (nuint)Vector128<byte>.Count);

#pragma warning disable IntrinsicsInSystemPrivateCoreLibAttributeNotSpecificEnough // In this case, we have an else clause which has the same semantic meaning whether or not Avx2 is considered supported or unsupported
if (Avx2.IsSupported && searchSpaceLength >= 32)
#pragma warning restore IntrinsicsInSystemPrivateCoreLibAttributeNotSpecificEnough
{
Vector256<byte> charMapLower256 = Vector256.Create(charMapLower, charMapLower);
Vector256<byte> charMapUpper256 = Vector256.Create(charMapUpper, charMapUpper);

ref char lastStartVectorAvx2 = ref Unsafe.Add(ref searchSpace, 32);

while (true)
{
Debug.Assert(Unsafe.ByteOffset(ref searchSpace, ref cur) >= 32 * sizeof(char));

cur = ref Unsafe.Subtract(ref cur, 32);

Vector256<byte> result = ContainsMask32CharsAvx2(charMapLower256, charMapUpper256, ref cur);

if (result != Vector256<byte>.Zero)
{
if (TryFindLastMatch<TUseFastContains>(ref cur, PackedSpanHelpers.FixUpPackedVector256Result(result).ExtractMostSignificantBits(), ref state, out int index))
{
return MatchOffset(ref searchSpace, ref cur) + index;
}
}

if (!Unsafe.IsAddressGreaterThan(ref cur, ref lastStartVectorAvx2))
{
if (Unsafe.AreSame(ref cur, ref searchSpace))
{
return -1;
}

if (Unsafe.ByteOffset(ref searchSpace, ref cur) > 16 * sizeof(char))
{
// If we have more than 16 characters left to process, we can
// adjust the current vector and do one last iteration of Avx2.
cur = ref lastStartVectorAvx2;
}
else
{
// Otherwise adjust the vector such that we'll only need to do a single
// iteration of ContainsMask16Chars below.
cur = ref Unsafe.Add(ref searchSpace, 16);
break;
}
}
}
}

ref char lastStartVector = ref Unsafe.Add(ref searchSpace, 16);

while (true)
{
Debug.Assert(Unsafe.ByteOffset(ref searchSpace, ref cur) >= 16 * sizeof(char));

cur = ref Unsafe.Subtract(ref cur, 16);

Vector128<byte> result = ContainsMask16Chars(charMapLower, charMapUpper, ref cur);

if (result != Vector128<byte>.Zero)
{
if (TryFindLastMatch<TUseFastContains>(ref cur, result.ExtractMostSignificantBits(), ref state, out int index))
{
return MatchOffset(ref searchSpace, ref cur) + index;
}
}

if (!Unsafe.IsAddressGreaterThan(ref cur, ref lastStartVector))
{
if (Unsafe.AreSame(ref cur, ref searchSpace))
{
break;
}

// Adjust the current vector and do one last iteration.
cur = ref lastStartVector;
}
}

return -1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int MatchOffset(ref char searchSpace, ref char cur) =>
(int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref cur) / sizeof(char));
Expand All @@ -594,7 +780,7 @@ private static bool TryFindMatch<TUseFastContains>(ref char cur, uint mask, ref
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool TryFindMatchOverlapped<TUseFastContains>(ref char cur, int searchSpaceLength, uint mask, ref ProbabilisticMapState state, out int index)
private static bool TryFindMatchOverlappedAvx512<TUseFastContains>(ref char cur, int searchSpaceLength, uint mask, ref ProbabilisticMapState state, out int index)
where TUseFastContains : struct, SearchValues.IRuntimeConst
{
do
Expand Down Expand Up @@ -622,7 +808,7 @@ private static bool TryFindMatchOverlapped<TUseFastContains>(ref char cur, int s
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool TryFindMatch<TUseFastContains>(ref char cur, ulong mask, ref ProbabilisticMapState state, out int index)
private static bool TryFindMatchAvx512<TUseFastContains>(ref char cur, ulong mask, ref ProbabilisticMapState state, out int index)
where TUseFastContains : struct, SearchValues.IRuntimeConst
{
do
Expand All @@ -643,7 +829,7 @@ private static bool TryFindMatch<TUseFastContains>(ref char cur, ulong mask, ref
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool TryFindMatchOverlapped<TUseFastContains>(ref char cur, int searchSpaceLength, ulong mask, ref ProbabilisticMapState state, out int index)
private static bool TryFindMatchOverlappedAvx512<TUseFastContains>(ref char cur, int searchSpaceLength, ulong mask, ref ProbabilisticMapState state, out int index)
where TUseFastContains : struct, SearchValues.IRuntimeConst
{
do
Expand All @@ -670,6 +856,108 @@ private static bool TryFindMatchOverlapped<TUseFastContains>(ref char cur, int s
return false;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool TryFindLastMatch<TUseFastContains>(ref char cur, uint mask, ref ProbabilisticMapState state, out int index)
where TUseFastContains : struct, SearchValues.IRuntimeConst
{
do
{
index = 31 - BitOperations.LeadingZeroCount(mask);

if (state.ConfirmProbabilisticMatch<TUseFastContains>(Unsafe.Add(ref cur, index)))
{
return true;
}

// Clear the highest set bit
mask = BitOperations.FlipBit(mask, index);
}
while (mask != 0);

index = 0;
return false;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool TryFindLastMatchOverlappedAvx512<TUseFastContains>(ref char cur, int searchSpaceLength, uint mask, ref ProbabilisticMapState state, out int index)
where TUseFastContains : struct, SearchValues.IRuntimeConst
{
do
{
index = 31 - BitOperations.LeadingZeroCount(mask);

// Clear the highest set bit
mask = BitOperations.FlipBit(mask, index);

if (index >= Vector256<ushort>.Count)
{
// The potential match is in the second vector.
// Fixup the index to account for how we loaded the second overlapped vector.
index += searchSpaceLength - (2 * Vector256<ushort>.Count);
}

if (state.ConfirmProbabilisticMatch<TUseFastContains>(Unsafe.Add(ref cur, index)))
{
return true;
}
}
while (mask != 0);

index = 0;
return false;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool TryFindLastMatchAvx512<TUseFastContains>(ref char cur, ulong mask, ref ProbabilisticMapState state, out int index)
where TUseFastContains : struct, SearchValues.IRuntimeConst
{
do
{
index = 63 - BitOperations.LeadingZeroCount(mask);

if (state.ConfirmProbabilisticMatch<TUseFastContains>(Unsafe.Add(ref cur, index)))
{
return true;
}

// Clear the highest set bit
mask = BitOperations.FlipBit(mask, index);
}
while (mask != 0);

index = 0;
return false;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool TryFindLastMatchOverlappedAvx512<TUseFastContains>(ref char cur, int searchSpaceLength, ulong mask, ref ProbabilisticMapState state, out int index)
where TUseFastContains : struct, SearchValues.IRuntimeConst
{
do
{
index = 63 - BitOperations.LeadingZeroCount(mask);

// Clear the highest set bit
mask = BitOperations.FlipBit(mask, index);

if (index >= Vector512<ushort>.Count)
{
// The potential match is in the second vector.
// Fixup the index to account for how we loaded the second overlapped vector.
index += searchSpaceLength - (2 * Vector512<ushort>.Count);
}

if (state.ConfirmProbabilisticMatch<TUseFastContains>(Unsafe.Add(ref cur, index)))
{
return true;
}
}
while (mask != 0);

index = 0;
return false;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static int IndexOfAnySimpleLoop<TNegator>(ref char searchSpace, int searchSpaceLength, ReadOnlySpan<char> values)
where TNegator : struct, IndexOfAnyAsciiSearcher.INegator
Expand Down

0 comments on commit 68ae561

Please sign in to comment.