-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
75 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,89 @@ | ||
namespace NetFabric.Numerics | ||
namespace NetFabric.Numerics; | ||
|
||
public static partial class Tensor | ||
{ | ||
public static partial class Tensor | ||
public static ValueTuple<T, T> Aggregate2D<T, TOperator>(ReadOnlySpan<T> source) | ||
where T : struct | ||
where TOperator : struct, IAggregationOperator<T, T> | ||
=> Aggregate2D<T, T, TOperator>(source); | ||
|
||
public static ValueTuple<TResult, TResult> Aggregate2D<TSource, TResult, TOperator>(ReadOnlySpan<TSource> source) | ||
where TSource : struct | ||
where TResult : struct | ||
where TOperator : struct, IAggregationOperator<TSource, TResult> | ||
{ | ||
public static ValueTuple<T, T> Aggregate2D<T, TOperator>(ReadOnlySpan<T> source) | ||
where T : struct | ||
where TOperator : struct, IAggregationOperator<T, T> | ||
=> Aggregate2D<T, T, TOperator>(source); | ||
|
||
public static ValueTuple<TResult, TResult> Aggregate2D<TSource, TResult, TOperator>(ReadOnlySpan<TSource> source) | ||
where TSource : struct | ||
where TResult : struct | ||
where TOperator : struct, IAggregationOperator<TSource, TResult> | ||
if (source.Length % 2 is not 0) | ||
Throw.ArgumentException(nameof(source), "source span must have a size multiple of 2."); | ||
|
||
// initialize aggregate | ||
var aggregateX = TOperator.Identity; | ||
var aggregateY = TOperator.Identity; | ||
var sourceIndex = nint.Zero; | ||
|
||
// aggregate using hardware acceleration if available | ||
if (TOperator.IsVectorizable && | ||
Vector.IsHardwareAccelerated && | ||
Vector<TSource>.IsSupported && | ||
Vector<TResult>.IsSupported && | ||
Vector<TSource>.Count is >2) | ||
{ | ||
if (source.Length % 2 is not 0) | ||
Throw.ArgumentException(nameof(source), "source span must have a size multiple of 2."); | ||
|
||
// initialize aggregate | ||
var aggregateX = TOperator.Identity; | ||
var aggregateY = TOperator.Identity; | ||
var sourceIndex = nint.Zero; | ||
|
||
// aggregate the remaining elements in the source | ||
ref var sourceRef = ref MemoryMarshal.GetReference(source); | ||
var remaining = source.Length; | ||
if (remaining >= 4) | ||
// convert source span to vector span without copies | ||
var sourceVectors = MemoryMarshal.Cast<TSource, Vector<TSource>>(source); | ||
|
||
// check if there are multiple vectors to aggregate | ||
if (sourceVectors.Length is >1) | ||
{ | ||
var partialX1 = TOperator.Identity; | ||
var partialY1 = TOperator.Identity; | ||
for (; sourceIndex + 3 < source.Length; sourceIndex += 4) | ||
// initialize aggregate vector | ||
var resultVector = new Vector<TResult>(TOperator.Identity); | ||
|
||
// aggregate the source vectors into the aggregate vector | ||
ref var sourceVectorsRef = ref MemoryMarshal.GetReference(sourceVectors); | ||
for (var indexVector = nint.Zero; indexVector < sourceVectors.Length; indexVector++) | ||
{ | ||
resultVector = TOperator.Invoke(ref resultVector, ref Unsafe.Add(ref sourceVectorsRef, indexVector)); | ||
} | ||
|
||
// aggregate the aggregate vector into the aggregate | ||
ref var resultVectorRef = ref Unsafe.As<Vector<TResult>, TResult>(ref Unsafe.AsRef(in resultVector)); | ||
for (var index = 0; index + 1 < Vector<TResult>.Count; index += 2) | ||
{ | ||
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex)); | ||
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1)); | ||
partialX1 = TOperator.Invoke(partialX1, Unsafe.Add(ref sourceRef, sourceIndex + 2)); | ||
partialY1 = TOperator.Invoke(partialY1, Unsafe.Add(ref sourceRef, sourceIndex + 3)); | ||
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref resultVectorRef, index)); | ||
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref resultVectorRef, index + 1)); | ||
} | ||
aggregateX = TOperator.Invoke(aggregateX, partialX1); | ||
aggregateY = TOperator.Invoke(aggregateY, partialY1); | ||
remaining = source.Length - (int)sourceIndex; | ||
|
||
// skip the source elements already aggregated | ||
sourceIndex = source.Length - (source.Length % Vector<TSource>.Count); | ||
} | ||
} | ||
|
||
switch(remaining) | ||
// aggregate the remaining elements in the source | ||
ref var sourceRef = ref MemoryMarshal.GetReference(source); | ||
var remaining = source.Length; | ||
if (remaining is >=4) | ||
{ | ||
var partialX1 = TOperator.Identity; | ||
var partialY1 = TOperator.Identity; | ||
for (; sourceIndex + 3 < source.Length; sourceIndex += 4) | ||
{ | ||
case 2: | ||
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex)); | ||
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1)); | ||
break; | ||
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex)); | ||
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1)); | ||
partialX1 = TOperator.Invoke(partialX1, Unsafe.Add(ref sourceRef, sourceIndex + 2)); | ||
partialY1 = TOperator.Invoke(partialY1, Unsafe.Add(ref sourceRef, sourceIndex + 3)); | ||
} | ||
aggregateX = TOperator.Invoke(aggregateX, partialX1); | ||
aggregateY = TOperator.Invoke(aggregateY, partialY1); | ||
remaining = source.Length - (int)sourceIndex; | ||
} | ||
|
||
|
||
return (aggregateX, aggregateY); | ||
switch(remaining) | ||
{ | ||
case 2: | ||
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex)); | ||
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1)); | ||
break; | ||
} | ||
|
||
|
||
return (aggregateX, aggregateY); | ||
} | ||
} |