Skip to content

Commit

Permalink
Optimize AVX intrinsics for .NET8 (#597)
Browse files Browse the repository at this point in the history
* pot

* fix projs

* test

* opt

* 512

* cleanup

* params

---------
  • Loading branch information
bitfaster authored Nov 28, 2024
1 parent 1c2f221 commit d8ac6f4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 47 deletions.
3 changes: 3 additions & 0 deletions BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

namespace BitFaster.Caching.Benchmarks.Lfu
{
#if Windows
[DisassemblyDiagnoser(printSource: true, maxDepth: 4)]
#endif
[SimpleJob(RuntimeMoniker.Net60)]
[SimpleJob(RuntimeMoniker.Net80)]
[SimpleJob(RuntimeMoniker.Net90)]
Expand Down
3 changes: 3 additions & 0 deletions BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

namespace BitFaster.Caching.Benchmarks.Lfu
{
#if Windows
[DisassemblyDiagnoser(printSource: true, maxDepth: 4)]
#endif
[SimpleJob(RuntimeMoniker.Net60)]
[SimpleJob(RuntimeMoniker.Net80)]
[SimpleJob(RuntimeMoniker.Net90)]
Expand Down
62 changes: 15 additions & 47 deletions BitFaster.Caching/Lfu/CmSketchCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -255,39 +255,26 @@ private void Reset()
}

#if !NETSTANDARD2_0
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe int EstimateFrequencyAvx(T value)
{
int blockHash = Spread(comparer.GetHashCode(value));
int counterHash = Rehash(blockHash);
int block = (blockHash & blockMask) << 3;

Vector128<int> h = Vector128.Create(counterHash);
h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
Vector128<int> h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
Vector128<int> index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2);
Vector128<int> blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));

var index = Avx2.ShiftRightLogical(h, 1);
index = Avx2.And(index, Vector128.Create(15)); // j - counter index
Vector128<int> offset = Avx2.And(h, Vector128.Create(1));
Vector128<int> blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index
blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)
Vector256<ulong> indexLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128<int>.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64();

#if NET6_0_OR_GREATER
long* tablePtr = tableAddr;
#else
fixed (long* tablePtr = table)
#endif
{
Vector256<long> tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8);
index = Avx2.ShiftLeftLogical(index, 2);

// convert index from int to long via permute
Vector256<long> indexLong = Vector256.Create(index, Vector128<int>.Zero).AsInt64();
Vector256<int> permuteMask2 = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7);
indexLong = Avx2.PermuteVar8x32(indexLong.AsInt32(), permuteMask2).AsInt64();
tableVector = Avx2.ShiftRightLogicalVariable(tableVector, indexLong.AsUInt64());
tableVector = Avx2.And(tableVector, Vector256.Create(0xfL));

Vector256<int> permuteMask = Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7);
Vector128<ushort> count = Avx2.PermuteVar8x32(tableVector.AsInt32(), permuteMask)
Vector128<ushort> count = Avx2.PermuteVar8x32(Avx2.And(Avx2.ShiftRightLogicalVariable(Avx2.GatherVector256(tablePtr, blockOffset, 8), indexLong), Vector256.Create(0xfL)).AsInt32(), Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7))
.GetLower()
.AsUInt16();

Expand All @@ -302,52 +289,33 @@ private unsafe int EstimateFrequencyAvx(T value)
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe void IncrementAvx(T value)
{
int blockHash = Spread(comparer.GetHashCode(value));
int counterHash = Rehash(blockHash);
int block = (blockHash & blockMask) << 3;

Vector128<int> h = Vector128.Create(counterHash);
h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
Vector128<int> h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
Vector128<int> index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2);
Vector128<int> blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));

Vector128<int> index = Avx2.ShiftRightLogical(h, 1);
index = Avx2.And(index, Vector128.Create(15)); // j - counter index
Vector128<int> offset = Avx2.And(h, Vector128.Create(1));
Vector128<int> blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index
blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)
Vector256<ulong> offsetLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128<int>.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64();
Vector256<long> mask = Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), offsetLong);

#if NET6_0_OR_GREATER
long* tablePtr = tableAddr;
#else
fixed (long* tablePtr = table)
#endif
{
Vector256<long> tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8);

// j == index
index = Avx2.ShiftLeftLogical(index, 2);
Vector256<long> offsetLong = Vector256.Create(index, Vector128<int>.Zero).AsInt64();

Vector256<int> permuteMask = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7);
offsetLong = Avx2.PermuteVar8x32(offsetLong.AsInt32(), permuteMask).AsInt64();

// mask = (0xfL << offset)
Vector256<long> fifteen = Vector256.Create(0xfL);
Vector256<long> mask = Avx2.ShiftLeftLogicalVariable(fifteen, offsetLong.AsUInt64());

// (table[i] & mask) != mask)
// Note masked is 'equal' - therefore use AndNot below
Vector256<long> masked = Avx2.CompareEqual(Avx2.And(tableVector, mask), mask);

// 1L << offset
Vector256<long> inc = Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong.AsUInt64());
Vector256<long> masked = Avx2.CompareEqual(Avx2.And(Avx2.GatherVector256(tablePtr, blockOffset, 8), mask), mask);

// Mask to zero out non matches (add zero below) - first operand is NOT then AND result (order matters)
inc = Avx2.AndNot(masked, inc);
Vector256<long> inc = Avx2.AndNot(masked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong));

Vector256<byte> result = Avx2.CompareEqual(masked.AsByte(), Vector256<byte>.Zero);
bool wasInc = Avx2.MoveMask(result.AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));
bool wasInc = Avx2.MoveMask(Avx2.CompareEqual(masked.AsByte(), Vector256<byte>.Zero).AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));

tablePtr[blockOffset.GetElement(0)] += inc.GetElement(0);
tablePtr[blockOffset.GetElement(1)] += inc.GetElement(1);
Expand Down

0 comments on commit d8ac6f4

Please sign in to comment.