Skip to content

Commit

Permalink
Implement LFU sketch using arm64 intrinsics (redux) (#648)
Browse files Browse the repository at this point in the history
* basic impl

* run tests

* fix

* table lookup

* opt

* opt

* temp

* cleanup

* endif

* fix return

* cleanup

---------
  • Loading branch information
bitfaster authored Dec 1, 2024
1 parent d8ac6f4 commit 5669d38
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 13 deletions.
12 changes: 10 additions & 2 deletions BitFaster.Caching.Benchmarks/BitFaster.Caching.Benchmarks.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
<PropertyGroup>
<OutputType>Exe</OutputType>
<LangVersion>latest</LangVersion>
<TargetFrameworks>net48;net6.0;net8.0</TargetFrameworks>
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<!-- https://stackoverflow.com/a/59916801/131345 -->
<IsWindows Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::Windows)))' == 'true'">true</IsWindows>
<IsLinux Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::Linux)))' == 'true'">true</IsLinux>
<IsMacOS Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::OSX)))' == 'true'">true</IsMacOS>
<IsArm64 Condition="$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture) == Arm64">true</IsArm64>
<IsX64 Condition="$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture) == X64">true</IsX64>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
Expand Down Expand Up @@ -41,5 +43,11 @@
<PropertyGroup Condition="'$(IsMacOS)'=='true'">
<DefineConstants>MacOS</DefineConstants>
</PropertyGroup>
<PropertyGroup Condition="'$(IsArm64)'=='true'">
<DefineConstants>Arm64</DefineConstants>
</PropertyGroup>
<PropertyGroup Condition="'$(IsX64)'=='true'">
<DefineConstants>X64</DefineConstants>
</PropertyGroup>

</Project>
</Project>
104 changes: 104 additions & 0 deletions BitFaster.Caching.Benchmarks/Lfu/CmSketchNoPin.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;


#if NET6_0_OR_GREATER
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86;
#endif

Expand Down Expand Up @@ -61,6 +64,12 @@ public int EstimateFrequency(T value)
{
return EstimateFrequencyAvx(value);
}
#if NET6_0_OR_GREATER
else if (isa.IsArm64Supported)
{
return EstimateFrequencyArm(value);
}
#endif
else
{
return EstimateFrequencyStd(value);
Expand All @@ -84,6 +93,12 @@ public void Increment(T value)
{
IncrementAvx(value);
}
#if NET6_0_OR_GREATER
else if (isa.IsArm64Supported)
{
IncrementArm(value);
}
#endif
else
{
IncrementStd(value);
Expand Down Expand Up @@ -314,5 +329,94 @@ private unsafe void IncrementAvx(T value)
}
}
#endif

#if NET6_0_OR_GREATER
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe void IncrementArm(T value)
{
int blockHash = Spread(comparer.GetHashCode(value));
int counterHash = Rehash(blockHash);
int block = (blockHash & blockMask) << 3;

Vector128<int> h = AdvSimd.ShiftArithmetic(Vector128.Create(counterHash), Vector128.Create(0, -8, -16, -24));
Vector128<int> index = AdvSimd.And(AdvSimd.ShiftRightLogical(h, 1), Vector128.Create(0xf));
Vector128<int> blockOffset = AdvSimd.Add(AdvSimd.Add(Vector128.Create(block), AdvSimd.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));

fixed (long* tablePtr = table)
{
int t0 = AdvSimd.Extract(blockOffset, 0);
int t1 = AdvSimd.Extract(blockOffset, 1);
int t2 = AdvSimd.Extract(blockOffset, 2);
int t3 = AdvSimd.Extract(blockOffset, 3);

Vector128<long> tableVectorA = Vector128.Create(AdvSimd.LoadVector64(tablePtr + t0), AdvSimd.LoadVector64(tablePtr + t1));
Vector128<long> tableVectorB = Vector128.Create(AdvSimd.LoadVector64(tablePtr + t2), AdvSimd.LoadVector64(tablePtr + t3));

index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);

Vector128<int> longOffA = AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0), 2, index, 1);
Vector128<int> longOffB = AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2), 2, index, 3);

Vector128<long> fifteen = Vector128.Create(0xfL);
Vector128<long> maskA = AdvSimd.ShiftArithmetic(fifteen, longOffA.AsInt64());
Vector128<long> maskB = AdvSimd.ShiftArithmetic(fifteen, longOffB.AsInt64());

Vector128<long> maskedA = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorA, maskA), maskA));
Vector128<long> maskedB = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorB, maskB), maskB));

var one = Vector128.Create(1L);
Vector128<long> incA = AdvSimd.And(maskedA, AdvSimd.ShiftArithmetic(one, longOffA.AsInt64()));
Vector128<long> incB = AdvSimd.And(maskedB, AdvSimd.ShiftArithmetic(one, longOffB.AsInt64()));

tablePtr[t0] += AdvSimd.Extract(incA, 0);
tablePtr[t1] += AdvSimd.Extract(incA, 1);
tablePtr[t2] += AdvSimd.Extract(incB, 0);
tablePtr[t3] += AdvSimd.Extract(incB, 1);

var max = AdvSimd.Arm64.MaxAcross(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.MaxAcross(incA.AsInt32()), 1, AdvSimd.Arm64.MaxAcross(incB.AsInt32()), 0).AsInt16());

if (max.ToScalar() != 0 && (++size == sampleSize))
{
Reset();
}
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe int EstimateFrequencyArm(T value)
{
int blockHash = Spread(comparer.GetHashCode(value));
int counterHash = Rehash(blockHash);
int block = (blockHash & blockMask) << 3;

Vector128<int> h = AdvSimd.ShiftArithmetic(Vector128.Create(counterHash), Vector128.Create(0, -8, -16, -24));
Vector128<int> index = AdvSimd.And(AdvSimd.ShiftRightLogical(h, 1), Vector128.Create(0xf));
Vector128<int> blockOffset = AdvSimd.Add(AdvSimd.Add(Vector128.Create(block), AdvSimd.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));

fixed (long* tablePtr = table)
{
Vector128<long> tableVectorA = Vector128.Create(AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 0)), AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 1)));
Vector128<long> tableVectorB = Vector128.Create(AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 2)), AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 3)));

index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);

Vector128<int> indexA = AdvSimd.Negate(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0), 2, index, 1));
Vector128<int> indexB = AdvSimd.Negate(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2), 2, index, 3));

var fifteen = Vector128.Create(0xfL);
Vector128<long> a = AdvSimd.And(AdvSimd.ShiftArithmetic(tableVectorA, indexA.AsInt64()), fifteen);
Vector128<long> b = AdvSimd.And(AdvSimd.ShiftArithmetic(tableVectorB, indexB.AsInt64()), fifteen);

// Before: < 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F >
// After: < 0, 1, 2, 3, 8, 9, A, B, 4, 5, 6, 7, C, D, E, F >
var min = AdvSimd.Arm64.VectorTableLookup(a.AsByte(), Vector128.Create(0x0B0A090803020100, 0xFFFFFFFFFFFFFFFF).AsByte());
min = AdvSimd.Arm64.VectorTableLookupExtension(min, b.AsByte(), Vector128.Create(0xFFFFFFFFFFFFFFFF, 0x0B0A090803020100).AsByte());

var min32 = AdvSimd.Arm64.MinAcross(min.AsInt32());

return min32.ToScalar();
}
}
#endif
}
}
13 changes: 11 additions & 2 deletions BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public int FrequencyFlat()

return count;
}

#if X64
[Benchmark(OperationsPerInvoke = iterations)]
public int FrequencyFlatAvx()
{
Expand All @@ -61,7 +61,7 @@ public int FrequencyFlatAvx()

return count;
}

#endif
[Benchmark(OperationsPerInvoke = iterations)]
public int FrequencyBlock()
{
Expand All @@ -73,7 +73,11 @@ public int FrequencyBlock()
}

[Benchmark(OperationsPerInvoke = iterations)]
#if Arm64
public int FrequencyBlockNeonNotPinned()
#else
public int FrequencyBlockAvxNotPinned()
#endif
{
int count = 0;
for (int i = 0; i < iterations; i++)
Expand All @@ -83,7 +87,12 @@ public int FrequencyBlockAvxNotPinned()
}

[Benchmark(OperationsPerInvoke = iterations)]

#if Arm64
public int FrequencyBlockNeonPinned()
#else
public int FrequencyBlockAvxPinned()
#endif
{
int count = 0;
for (int i = 0; i < iterations; i++)
Expand Down
13 changes: 11 additions & 2 deletions BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class SketchIncrement
private CmSketchNoPin<int, DetectIsa> blockAvxNoPin;
private CmSketchCore<int, DetectIsa> blockAvx;


[Params(32_768, 524_288, 8_388_608, 134_217_728)]
public int Size { get; set; }

Expand All @@ -49,7 +50,7 @@ public void IncFlat()
flatStd.Increment(i);
}
}

#if X64
[Benchmark(OperationsPerInvoke = iterations)]
public void IncFlatAvx()
{
Expand All @@ -58,7 +59,7 @@ public void IncFlatAvx()
flatAvx.Increment(i);
}
}

#endif
[Benchmark(OperationsPerInvoke = iterations)]
public void IncBlock()
{
Expand All @@ -69,7 +70,11 @@ public void IncBlock()
}

[Benchmark(OperationsPerInvoke = iterations)]
#if Arm64
public void IncBlockNeonNotPinned()
#else
public void IncBlockAvxNotPinned()
#endif
{
for (int i = 0; i < iterations; i++)
{
Expand All @@ -78,7 +83,11 @@ public void IncBlockAvxNotPinned()
}

[Benchmark(OperationsPerInvoke = iterations)]
#if Arm64
public void IncBlockNeonPinned()
#else
public void IncBlockAvxPinned()
#endif
{
for (int i = 0; i < iterations; i++)
{
Expand Down
10 changes: 10 additions & 0 deletions BitFaster.Caching.UnitTests/Intrinsics.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#if NETCOREAPP3_1_OR_GREATER
using System.Runtime.Intrinsics.X86;
#endif
#if NET6_0_OR_GREATER
using System.Runtime.Intrinsics.Arm;
#endif

using Xunit;

namespace BitFaster.Caching.UnitTests
Expand All @@ -10,8 +14,14 @@ public static class Intrinsics
public static void SkipAvxIfNotSupported<I>()
{
#if NETCOREAPP3_1_OR_GREATER
#if NET6_0_OR_GREATER
// when we are trying to test Avx2/Arm64, skip the test if it's not supported
Skip.If(typeof(I) == typeof(DetectIsa) && !(Avx2.IsSupported || AdvSimd.Arm64.IsSupported));
#else
// when we are trying to test Avx2, skip the test if it's not supported
Skip.If(typeof(I) == typeof(DetectIsa) && !Avx2.IsSupported);
#endif

#else
Skip.If(true);
#endif
Expand Down
15 changes: 12 additions & 3 deletions BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

namespace BitFaster.Caching.UnitTests.Lfu
{
// Test with AVX2 if it is supported
public class CMSketchAvx2Tests : CmSketchTestBase<DetectIsa>
// Test with AVX2/ARM64 if it is supported
public class CMSketchIntrinsicsTests : CmSketchTestBase<DetectIsa>
{
}

// Test with AVX2 disabled
// Test with AVX2/ARM64 disabled
public class CmSketchTests : CmSketchTestBase<DisableHardwareIntrinsics>
{
}
Expand All @@ -29,14 +29,23 @@ public CmSketchTestBase()
public void Repro()
{
sketch = new CmSketchCore<int, I>(1_048_576, EqualityComparer<int>.Default);
var baseline = new CmSketchCore<int, DisableHardwareIntrinsics>(1_048_576, EqualityComparer<int>.Default);

for (int i = 0; i < 1_048_576; i++)
{
if (i % 3 == 0)
{
sketch.Increment(i);
baseline.Increment(i);
}
}

baseline.Size.Should().Be(sketch.Size);

for (int i = 0; i < 1_048_576; i++)
{
sketch.EstimateFrequency(i).Should().Be(baseline.EstimateFrequency(i));
}
}


Expand Down
28 changes: 25 additions & 3 deletions BitFaster.Caching/Intrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
using System.Runtime.Intrinsics.X86;
#endif

#if NET6_0
using System.Runtime.Intrinsics.Arm;
#endif

namespace BitFaster.Caching
{
/// <summary>
Expand All @@ -12,7 +16,14 @@ public interface IsaProbe
/// <summary>
/// Gets a value indicating whether AVX2 is supported.
/// </summary>
bool IsAvx2Supported { get; }
bool IsAvx2Supported { get; }

#if NET6_0_OR_GREATER
/// <summary>
/// Gets a value indicating whether Arm64 is supported.
/// </summary>
bool IsArm64Supported { get => false; }
#endif
}

/// <summary>
Expand All @@ -25,7 +36,15 @@ public interface IsaProbe
public bool IsAvx2Supported => false;
#else
/// <inheritdoc/>
public bool IsAvx2Supported => Avx2.IsSupported;
public bool IsAvx2Supported => Avx2.IsSupported;
#endif

#if NET6_0_OR_GREATER
/// <inheritdoc/>
public bool IsArm64Supported => AdvSimd.Arm64.IsSupported;
#else
/// <inheritdoc/>
public bool IsArm64Supported => false;
#endif
}

Expand All @@ -35,6 +54,9 @@ public interface IsaProbe
public readonly struct DisableHardwareIntrinsics : IsaProbe
{
/// <inheritdoc/>
public bool IsAvx2Supported => false;
public bool IsAvx2Supported => false;

/// <inheritdoc/>
public bool IsArm64Supported => false;
}
}
Loading

0 comments on commit 5669d38

Please sign in to comment.