Skip to content

Commit

Permalink
Improve Span.Count (#103728)
Browse files Browse the repository at this point in the history
* Improve Span.Count

* Fix indentation
  • Loading branch information
MihaZupan authored Jun 21, 2024
1 parent ff0c538 commit fc70f8b
Showing 1 changed file with 23 additions and 56 deletions.
79 changes: 23 additions & 56 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3786,92 +3786,59 @@ public static int CountValueType<T>(ref T current, T value, int length) where T
{
Vector512<T> targetVector = Vector512.Create(value);
ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector512<T>.Count);
do
while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEnd))
{
count += BitOperations.PopCount(Vector512.Equals(Vector512.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits());
current = ref Unsafe.Add(ref current, Vector512<T>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref current, ref oneVectorAwayFromEnd));

// If there are just a few elements remaining, then processing these elements by the scalar loop
// is cheaper than doing bitmask + popcount on the full last vector. To avoid complicated type
// based checks, other remainder-count based logic to determine the correct cut-off, for simplicity
// a half-vector size is chosen (based on benchmarks).
uint remaining = (uint)Unsafe.ByteOffset(ref current, ref end) / (uint)Unsafe.SizeOf<T>();
if (remaining > Vector512<T>.Count / 2)
{
ulong mask = Vector512.Equals(Vector512.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits();

// The mask contains some elements that may be double-checked, so shift them away in order to get the correct pop-count.
uint overlaps = (uint)Vector512<T>.Count - remaining;
mask >>= (int)overlaps;
count += BitOperations.PopCount(mask);

return count;
}
// Count the last vector and mask off the elements that were already counted (number of elements between oneVectorAwayFromEnd and current).
ulong mask = Vector512.Equals(Vector512.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits();
mask >>= (int)((nuint)Unsafe.ByteOffset(ref oneVectorAwayFromEnd, ref current) / (uint)Unsafe.SizeOf<T>());
count += BitOperations.PopCount(mask);
}
else if (Vector256.IsHardwareAccelerated && length >= Vector256<T>.Count)
{
Vector256<T> targetVector = Vector256.Create(value);
ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector256<T>.Count);
do
while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEnd))
{
count += BitOperations.PopCount(Vector256.Equals(Vector256.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits());
current = ref Unsafe.Add(ref current, Vector256<T>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref current, ref oneVectorAwayFromEnd));

// If there are just a few elements remaining, then processing these elements by the scalar loop
// is cheaper than doing bitmask + popcount on the full last vector. To avoid complicated type
// based checks, other remainder-count based logic to determine the correct cut-off, for simplicity
// a half-vector size is chosen (based on benchmarks).
uint remaining = (uint)Unsafe.ByteOffset(ref current, ref end) / (uint)Unsafe.SizeOf<T>();
if (remaining > Vector256<T>.Count / 2)
{
uint mask = Vector256.Equals(Vector256.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits();

// The mask contains some elements that may be double-checked, so shift them away in order to get the correct pop-count.
uint overlaps = (uint)Vector256<T>.Count - remaining;
mask >>= (int)overlaps;
count += BitOperations.PopCount(mask);

return count;
}
// Count the last vector and mask off the elements that were already counted (number of elements between oneVectorAwayFromEnd and current).
uint mask = Vector256.Equals(Vector256.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits();
mask >>= (int)((nuint)Unsafe.ByteOffset(ref oneVectorAwayFromEnd, ref current) / (uint)Unsafe.SizeOf<T>());
count += BitOperations.PopCount(mask);
}
else
{
Vector128<T> targetVector = Vector128.Create(value);
ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector128<T>.Count);
do
while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEnd))
{
count += BitOperations.PopCount(Vector128.Equals(Vector128.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits());
current = ref Unsafe.Add(ref current, Vector128<T>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref current, ref oneVectorAwayFromEnd));

uint remaining = (uint)Unsafe.ByteOffset(ref current, ref end) / (uint)Unsafe.SizeOf<T>();
if (remaining > Vector128<T>.Count / 2)
{
uint mask = Vector128.Equals(Vector128.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits();

// The mask contains some elements that may be double-checked, so shift them away in order to get the correct pop-count.
uint overlaps = (uint)Vector128<T>.Count - remaining;
mask >>= (int)overlaps;
count += BitOperations.PopCount(mask);

return count;
}
// Count the last vector and mask off the elements that were already counted (number of elements between oneVectorAwayFromEnd and current).
uint mask = Vector128.Equals(Vector128.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits();
mask >>= (int)((nuint)Unsafe.ByteOffset(ref oneVectorAwayFromEnd, ref current) / (uint)Unsafe.SizeOf<T>());
count += BitOperations.PopCount(mask);
}
}

while (Unsafe.IsAddressLessThan(ref current, ref end))
else
{
if (current.Equals(value))
while (Unsafe.IsAddressLessThan(ref current, ref end))
{
count++;
}
if (current.Equals(value))
{
count++;
}

current = ref Unsafe.Add(ref current, 1);
current = ref Unsafe.Add(ref current, 1);
}
}

return count;
Expand Down

0 comments on commit fc70f8b

Please sign in to comment.