From d8ac6f40ea7050251dcde0f1ac0ebf2669b86083 Mon Sep 17 00:00:00 2001 From: Alex Peck Date: Thu, 28 Nov 2024 15:55:20 -0800 Subject: [PATCH] Optimize AVX intrinsics for .NET8 (#597) * pot * fix projs * test * opt * 512 * cleanup * params --------- --- .../Lfu/SketchFrequency.cs | 3 + .../Lfu/SketchIncrement.cs | 3 + BitFaster.Caching/Lfu/CmSketchCore.cs | 62 +++++-------------- 3 files changed, 21 insertions(+), 47 deletions(-) diff --git a/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs b/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs index b49e5bcf..137b9dcd 100644 --- a/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs +++ b/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs @@ -7,6 +7,9 @@ namespace BitFaster.Caching.Benchmarks.Lfu { +#if Windows + [DisassemblyDiagnoser(printSource: true, maxDepth: 4)] +#endif [SimpleJob(RuntimeMoniker.Net60)] [SimpleJob(RuntimeMoniker.Net80)] [SimpleJob(RuntimeMoniker.Net90)] diff --git a/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs b/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs index 6bcd0272..6f6ab1e7 100644 --- a/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs +++ b/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs @@ -7,6 +7,9 @@ namespace BitFaster.Caching.Benchmarks.Lfu { +#if Windows + [DisassemblyDiagnoser(printSource: true, maxDepth: 4)] +#endif [SimpleJob(RuntimeMoniker.Net60)] [SimpleJob(RuntimeMoniker.Net80)] [SimpleJob(RuntimeMoniker.Net90)] diff --git a/BitFaster.Caching/Lfu/CmSketchCore.cs b/BitFaster.Caching/Lfu/CmSketchCore.cs index fdb5d9f0..733b1ea0 100644 --- a/BitFaster.Caching/Lfu/CmSketchCore.cs +++ b/BitFaster.Caching/Lfu/CmSketchCore.cs @@ -255,20 +255,18 @@ private void Reset() } #if !NETSTANDARD2_0 + [MethodImpl(MethodImplOptions.AggressiveInlining)] private unsafe int EstimateFrequencyAvx(T value) { int blockHash = Spread(comparer.GetHashCode(value)); int counterHash = Rehash(blockHash); int block = (blockHash & blockMask) << 3; - Vector128 h = Vector128.Create(counterHash); - h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32(); + Vector128 h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32(); + Vector128 index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2); + Vector128 blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6)); - var index = Avx2.ShiftRightLogical(h, 1); - index = Avx2.And(index, Vector128.Create(15)); // j - counter index - Vector128 offset = Avx2.And(h, Vector128.Create(1)); - Vector128 blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index - blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1) + Vector256 indexLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64(); #if NET6_0_OR_GREATER long* tablePtr = tableAddr; @@ -276,18 +274,7 @@ private unsafe int EstimateFrequencyAvx(T value) fixed (long* tablePtr = table) #endif { - Vector256 tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8); - index = Avx2.ShiftLeftLogical(index, 2); - - // convert index from int to long via permute - Vector256 indexLong = Vector256.Create(index, Vector128.Zero).AsInt64(); - Vector256 permuteMask2 = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7); - indexLong = Avx2.PermuteVar8x32(indexLong.AsInt32(), permuteMask2).AsInt64(); - tableVector = Avx2.ShiftRightLogicalVariable(tableVector, indexLong.AsUInt64()); - tableVector = Avx2.And(tableVector, Vector256.Create(0xfL)); - - Vector256 permuteMask = Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7); - Vector128 count = Avx2.PermuteVar8x32(tableVector.AsInt32(), permuteMask) + Vector128 count = Avx2.PermuteVar8x32(Avx2.And(Avx2.ShiftRightLogicalVariable(Avx2.GatherVector256(tablePtr, blockOffset, 8), indexLong), Vector256.Create(0xfL)).AsInt32(), Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7)) .GetLower() .AsUInt16(); @@ -302,20 +289,19 @@ private unsafe int EstimateFrequencyAvx(T value) } } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private unsafe void IncrementAvx(T value) { int blockHash = Spread(comparer.GetHashCode(value)); int counterHash = Rehash(blockHash); int block = (blockHash & blockMask) << 3; - Vector128 h = Vector128.Create(counterHash); - h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32(); + Vector128 h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32(); + Vector128 index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2); + Vector128 blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6)); - Vector128 index = Avx2.ShiftRightLogical(h, 1); - index = Avx2.And(index, Vector128.Create(15)); // j - counter index - Vector128 offset = Avx2.And(h, Vector128.Create(1)); - Vector128 blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index - blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1) + Vector256 offsetLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64(); + Vector256 mask = Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), offsetLong); #if NET6_0_OR_GREATER long* tablePtr = tableAddr; @@ -323,31 +309,13 @@ private unsafe void IncrementAvx(T value) fixed (long* tablePtr = table) #endif { - Vector256 tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8); - - // j == index - index = Avx2.ShiftLeftLogical(index, 2); - Vector256 offsetLong = Vector256.Create(index, Vector128.Zero).AsInt64(); - - Vector256 permuteMask = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7); - offsetLong = Avx2.PermuteVar8x32(offsetLong.AsInt32(), permuteMask).AsInt64(); - - // mask = (0xfL << offset) - Vector256 fifteen = Vector256.Create(0xfL); - Vector256 mask = Avx2.ShiftLeftLogicalVariable(fifteen, offsetLong.AsUInt64()); - - // (table[i] & mask) != mask) // Note masked is 'equal' - therefore use AndNot below - Vector256 masked = Avx2.CompareEqual(Avx2.And(tableVector, mask), mask); - - // 1L << offset - Vector256 inc = Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong.AsUInt64()); + Vector256 masked = Avx2.CompareEqual(Avx2.And(Avx2.GatherVector256(tablePtr, blockOffset, 8), mask), mask); // Mask to zero out non matches (add zero below) - first operand is NOT then AND result (order matters) - inc = Avx2.AndNot(masked, inc); + Vector256 inc = Avx2.AndNot(masked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong)); - Vector256 result = Avx2.CompareEqual(masked.AsByte(), Vector256.Zero); - bool wasInc = Avx2.MoveMask(result.AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111)); + bool wasInc = Avx2.MoveMask(Avx2.CompareEqual(masked.AsByte(), Vector256.Zero).AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111)); tablePtr[blockOffset.GetElement(0)] += inc.GetElement(0); tablePtr[blockOffset.GetElement(1)] += inc.GetElement(1);