From fc70f8b7b9d477b5a1188f4698ad525090672fb2 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Fri, 21 Jun 2024 20:11:05 +0200 Subject: [PATCH] Improve Span.Count (#103728) * Improve Span.Count * Fix indentation --- .../src/System/SpanHelpers.T.cs | 79 ++++++------------- 1 file changed, 23 insertions(+), 56 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs index ced611ec012a9..65b17b29156a0 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs @@ -3786,92 +3786,59 @@ public static int CountValueType<T>(ref T current, T value, int length) where T { Vector512<T> targetVector = Vector512.Create(value); ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector512<T>.Count); - do + while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEnd)) { count += BitOperations.PopCount(Vector512.Equals(Vector512.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits()); current = ref Unsafe.Add(ref current, Vector512<T>.Count); } - while (!Unsafe.IsAddressGreaterThan(ref current, ref oneVectorAwayFromEnd)); - - // If there are just a few elements remaining, then processing these elements by the scalar loop - // is cheaper than doing bitmask + popcount on the full last vector. To avoid complicated type - // based checks, other remainder-count based logic to determine the correct cut-off, for simplicity - // a half-vector size is chosen (based on benchmarks). - uint remaining = (uint)Unsafe.ByteOffset(ref current, ref end) / (uint)Unsafe.SizeOf<T>(); - if (remaining > Vector512<T>.Count / 2) - { - ulong mask = Vector512.Equals(Vector512.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits(); - // The mask contains some elements that may be double-checked, so shift them away in order to get the correct pop-count. 
- uint overlaps = (uint)Vector512<T>.Count - remaining; - mask >>= (int)overlaps; - count += BitOperations.PopCount(mask); - - return count; - } + // Count the last vector and mask off the elements that were already counted (number of elements between oneVectorAwayFromEnd and current). + ulong mask = Vector512.Equals(Vector512.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits(); + mask >>= (int)((nuint)Unsafe.ByteOffset(ref oneVectorAwayFromEnd, ref current) / (uint)Unsafe.SizeOf<T>()); + count += BitOperations.PopCount(mask); } else if (Vector256.IsHardwareAccelerated && length >= Vector256<T>.Count) { Vector256<T> targetVector = Vector256.Create(value); ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector256<T>.Count); - do + while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEnd)) { count += BitOperations.PopCount(Vector256.Equals(Vector256.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits()); current = ref Unsafe.Add(ref current, Vector256<T>.Count); } - while (!Unsafe.IsAddressGreaterThan(ref current, ref oneVectorAwayFromEnd)); - // If there are just a few elements remaining, then processing these elements by the scalar loop - // is cheaper than doing bitmask + popcount on the full last vector. To avoid complicated type - // based checks, other remainder-count based logic to determine the correct cut-off, for simplicity - // a half-vector size is chosen (based on benchmarks). - uint remaining = (uint)Unsafe.ByteOffset(ref current, ref end) / (uint)Unsafe.SizeOf<T>(); - if (remaining > Vector256<T>.Count / 2) - { - uint mask = Vector256.Equals(Vector256.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits(); - - // The mask contains some elements that may be double-checked, so shift them away in order to get the correct pop-count. 
- uint overlaps = (uint)Vector256<T>.Count - remaining; - mask >>= (int)overlaps; - count += BitOperations.PopCount(mask); - - return count; - } + // Count the last vector and mask off the elements that were already counted (number of elements between oneVectorAwayFromEnd and current). + uint mask = Vector256.Equals(Vector256.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits(); + mask >>= (int)((nuint)Unsafe.ByteOffset(ref oneVectorAwayFromEnd, ref current) / (uint)Unsafe.SizeOf<T>()); + count += BitOperations.PopCount(mask); } else { Vector128<T> targetVector = Vector128.Create(value); ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector128<T>.Count); - do + while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEnd)) { count += BitOperations.PopCount(Vector128.Equals(Vector128.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits()); current = ref Unsafe.Add(ref current, Vector128<T>.Count); } - while (!Unsafe.IsAddressGreaterThan(ref current, ref oneVectorAwayFromEnd)); - - uint remaining = (uint)Unsafe.ByteOffset(ref current, ref end) / (uint)Unsafe.SizeOf<T>(); - if (remaining > Vector128<T>.Count / 2) - { - uint mask = Vector128.Equals(Vector128.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits(); - - // The mask contains some elements that may be double-checked, so shift them away in order to get the correct pop-count. - uint overlaps = (uint)Vector128<T>.Count - remaining; - mask >>= (int)overlaps; - count += BitOperations.PopCount(mask); - return count; - } + // Count the last vector and mask off the elements that were already counted (number of elements between oneVectorAwayFromEnd and current). 
+ uint mask = Vector128.Equals(Vector128.LoadUnsafe(ref oneVectorAwayFromEnd), targetVector).ExtractMostSignificantBits(); + mask >>= (int)((nuint)Unsafe.ByteOffset(ref oneVectorAwayFromEnd, ref current) / (uint)Unsafe.SizeOf<T>()); + count += BitOperations.PopCount(mask); } } - - while (Unsafe.IsAddressLessThan(ref current, ref end)) + else { - if (current.Equals(value)) + while (Unsafe.IsAddressLessThan(ref current, ref end)) { - count++; - } + if (current.Equals(value)) + { + count++; + } - current = ref Unsafe.Add(ref current, 1); + current = ref Unsafe.Add(ref current, 1); + } } return count;