Skip to content

Commit

Permalink
Vectorize Aggregate2D
Browse files: browse the repository at this point in the history
  • Loading branch information
aalmada committed Jan 28, 2024
1 parent cdc5db2 commit e29fb96
Showing 1 changed file with 75 additions and 40 deletions.
115 changes: 75 additions & 40 deletions src/NetFabric.Numerics.Tensors/Aggregate2D.cs
Original file line number | Diff line number | Diff line change
@@ -1,54 +1,89 @@
namespace NetFabric.Numerics
namespace NetFabric.Numerics;

public static partial class Tensor
{
public static partial class Tensor
/// <summary>
/// Aggregates a span of interleaved value pairs where the source and result element types are the same.
/// Forwards to the overload taking separate source and result type parameters, passing
/// <typeparamref name="T"/> for both.
/// </summary>
/// <typeparam name="T">The element type of the source span and of both results.</typeparam>
/// <typeparam name="TOperator">The aggregation operator type.</typeparam>
/// <param name="source">The span to aggregate.</param>
/// <returns>The pair of aggregates produced by the forwarded-to overload.</returns>
public static ValueTuple<T, T> Aggregate2D<T, TOperator>(ReadOnlySpan<T> source)

Check warning on line 5 in src/NetFabric.Numerics.Tensors/Aggregate2D.cs

View workflow job for this annotation

GitHub Actions / build

Missing XML comment for publicly visible type or member 'Tensor.Aggregate2D<T, TOperator>(ReadOnlySpan<T>)'
where T : struct
where TOperator : struct, IAggregationOperator<T, T>
=> Aggregate2D<T, T, TOperator>(source);

/// <summary>
/// Aggregates the elements of <paramref name="source"/> into two results: elements at even
/// positions are folded into the first result and elements at odd positions into the second.
/// Both results start from <c>TOperator.Identity</c>.
/// </summary>
/// <typeparam name="TSource">The element type of the source span.</typeparam>
/// <typeparam name="TResult">The element type of each aggregate.</typeparam>
/// <typeparam name="TOperator">
/// The aggregation operator type; this method uses its <c>Identity</c>, <c>IsVectorizable</c>
/// and <c>Invoke</c> members.
/// </typeparam>
/// <param name="source">
/// The span to aggregate. Its length must be a multiple of 2; otherwise
/// <c>Throw.ArgumentException</c> is invoked.
/// </param>
/// <returns>A tuple holding the even-position aggregate followed by the odd-position aggregate.</returns>
public static ValueTuple<TResult, TResult> Aggregate2D<TSource, TResult, TOperator>(ReadOnlySpan<TSource> source)

Check warning on line 10 in src/NetFabric.Numerics.Tensors/Aggregate2D.cs

View workflow job for this annotation

GitHub Actions / build

Missing XML comment for publicly visible type or member 'Tensor.Aggregate2D<TSource, TResult, TOperator>(ReadOnlySpan<TSource>)'
where TSource : struct
where TResult : struct
where TOperator : struct, IAggregationOperator<TSource, TResult>
{
public static ValueTuple<T, T> Aggregate2D<T, TOperator>(ReadOnlySpan<T> source)
where T : struct
where TOperator : struct, IAggregationOperator<T, T>
=> Aggregate2D<T, T, TOperator>(source);

public static ValueTuple<TResult, TResult> Aggregate2D<TSource, TResult, TOperator>(ReadOnlySpan<TSource> source)
where TSource : struct
where TResult : struct
where TOperator : struct, IAggregationOperator<TSource, TResult>
// NOTE(review): Throw is declared outside this view; presumably it throws an ArgumentException — confirm
if (source.Length % 2 is not 0)
Throw.ArgumentException(nameof(source), "source span must have a size multiple of 2.");

// initialize aggregate
var aggregateX = TOperator.Identity;
var aggregateY = TOperator.Identity;
var sourceIndex = nint.Zero;

// aggregate using hardware acceleration if available
// the vector path is additionally gated on Vector<TSource> holding more than 2 elements
if (TOperator.IsVectorizable &&
Vector.IsHardwareAccelerated &&
Vector<TSource>.IsSupported &&
Vector<TResult>.IsSupported &&
Vector<TSource>.Count is >2)
{
if (source.Length % 2 is not 0)
Throw.ArgumentException(nameof(source), "source span must have a size multiple of 2.");

// initialize aggregate
var aggregateX = TOperator.Identity;
var aggregateY = TOperator.Identity;
var sourceIndex = nint.Zero;

// aggregate the remaining elements in the source
ref var sourceRef = ref MemoryMarshal.GetReference(source);
var remaining = source.Length;
if (remaining >= 4)
// convert source span to vector span without copies
var sourceVectors = MemoryMarshal.Cast<TSource, Vector<TSource>>(source);

// check if there are multiple vectors to aggregate
if (sourceVectors.Length is >1)
{
var partialX1 = TOperator.Identity;
var partialY1 = TOperator.Identity;
for (; sourceIndex + 3 < source.Length; sourceIndex += 4)
// initialize aggregate vector
var resultVector = new Vector<TResult>(TOperator.Identity);

// aggregate the source vectors into the aggregate vector
ref var sourceVectorsRef = ref MemoryMarshal.GetReference(sourceVectors);
for (var indexVector = nint.Zero; indexVector < sourceVectors.Length; indexVector++)
{
// NOTE(review): the running aggregate is a Vector<TResult> while the inputs are Vector<TSource>;
// how Invoke combines them is defined by TOperator, outside this view — confirm lane counts match
resultVector = TOperator.Invoke(ref resultVector, ref Unsafe.Add(ref sourceVectorsRef, indexVector));
}

// fold the lanes of the aggregate vector into the two scalar aggregates
ref var resultVectorRef = ref Unsafe.As<Vector<TResult>, TResult>(ref Unsafe.AsRef(in resultVector));
// lanes are consumed two at a time: the first of each pair goes to aggregateX, the second to aggregateY
for (var index = 0; index + 1 < Vector<TResult>.Count; index += 2)
{
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1));
partialX1 = TOperator.Invoke(partialX1, Unsafe.Add(ref sourceRef, sourceIndex + 2));
partialY1 = TOperator.Invoke(partialY1, Unsafe.Add(ref sourceRef, sourceIndex + 3));
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref resultVectorRef, index));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref resultVectorRef, index + 1));
}
aggregateX = TOperator.Invoke(aggregateX, partialX1);
aggregateY = TOperator.Invoke(aggregateY, partialY1);
remaining = source.Length - (int)sourceIndex;

// skip the source elements already aggregated
// i.e. advance to the source length rounded down to a multiple of Vector<TSource>.Count
sourceIndex = source.Length - (source.Length % Vector<TSource>.Count);
}
}

switch(remaining)
// aggregate the remaining elements in the source
ref var sourceRef = ref MemoryMarshal.GetReference(source);
// NOTE(review): remaining starts as the full span length, not the length minus sourceIndex;
// the loop condition below (based on sourceIndex) is what bounds the scalar pass
var remaining = source.Length;
if (remaining is >=4)
{
// second pair of accumulators so each iteration can process two (x, y) pairs
var partialX1 = TOperator.Identity;
var partialY1 = TOperator.Identity;
for (; sourceIndex + 3 < source.Length; sourceIndex += 4)
{
case 2:
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1));
break;
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1));
partialX1 = TOperator.Invoke(partialX1, Unsafe.Add(ref sourceRef, sourceIndex + 2));
partialY1 = TOperator.Invoke(partialY1, Unsafe.Add(ref sourceRef, sourceIndex + 3));
}
// merge the second accumulator pair into the main aggregates
aggregateX = TOperator.Invoke(aggregateX, partialX1);
aggregateY = TOperator.Invoke(aggregateY, partialY1);
// recompute how many elements are left after the 4-at-a-time loop
remaining = source.Length - (int)sourceIndex;
}


return (aggregateX, aggregateY);
// handle one final leftover (x, y) pair; no other leftover count is handled here
switch(remaining)
{
case 2:
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1));
break;
}


return (aggregateX, aggregateY);
}
}

0 comments on commit e29fb96

Please sign in to comment.