-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
92 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,71 @@ | ||
namespace NetFabric.Numerics | ||
namespace NetFabric.Numerics; | ||
|
||
public static partial class Tensor | ||
{ | ||
public static partial class Tensor | ||
public static ValueTuple<T, T, T> Aggregate3D<T, TOperator>(ReadOnlySpan<T> source) | ||
where T : struct | ||
where TOperator : struct, IAggregationOperator<T, T> | ||
=> Aggregate3D<T, T, TOperator>(source); | ||
|
||
public static ValueTuple<TResult, TResult, TResult> Aggregate3D<TSource, TResult, TOperator>(ReadOnlySpan<TSource> source) | ||
where TSource : struct | ||
where TResult : struct | ||
where TOperator : struct, IAggregationOperator<TSource, TResult> | ||
{ | ||
if (source.Length % 3 is not 0) | ||
Throw.ArgumentException(nameof(source), "source span must have a size multiple of 3."); | ||
|
||
// initialize aggregate | ||
var aggregateX = TOperator.Identity; | ||
var aggregateY = TOperator.Identity; | ||
var aggregateZ = TOperator.Identity; | ||
var sourceIndex = nint.Zero; | ||
|
||
// aggregate using hardware acceleration if available | ||
if (TOperator.IsVectorizable && | ||
Vector.IsHardwareAccelerated && | ||
Vector<TSource>.IsSupported && | ||
Vector<TResult>.IsSupported && | ||
Vector<TSource>.Count is >2) | ||
{ | ||
public static ValueTuple<T, T, T> Aggregate3D<T, TOperator>(ReadOnlySpan<T> source) | ||
where T : struct | ||
where TOperator : struct, IAggregationOperator<T, T> | ||
=> Aggregate3D<T, T, TOperator>(source); | ||
|
||
public static ValueTuple<TResult, TResult, TResult> Aggregate3D<TSource, TResult, TOperator>(ReadOnlySpan<TSource> source) | ||
where TSource : struct | ||
where TResult : struct | ||
where TOperator : struct, IAggregationOperator<TSource, TResult> | ||
// convert source span to vector span without copies | ||
var sourceVectors = MemoryMarshal.Cast<TSource, Vector<TSource>>(source); | ||
|
||
// check if there are multiple vectors to aggregate | ||
if (sourceVectors.Length is >1) | ||
{ | ||
if (source.Length % 3 is not 0) | ||
Throw.ArgumentException(nameof(source), "source span must have a size multiple of 3."); | ||
|
||
// initialize aggregate | ||
var aggregateX = TOperator.Identity; | ||
var aggregateY = TOperator.Identity; | ||
var aggregateZ = TOperator.Identity; | ||
var sourceIndex = nint.Zero; | ||
|
||
// aggregate the remaining elements in the source | ||
ref var sourceRef = ref MemoryMarshal.GetReference(source); | ||
for (; sourceIndex + 2 < source.Length; sourceIndex += 3) | ||
// initialize aggregate vector | ||
var resultVector = new Vector<TResult>(TOperator.Identity); | ||
|
||
// aggregate the source vectors into the aggregate vector | ||
ref var sourceVectorsRef = ref MemoryMarshal.GetReference(sourceVectors); | ||
for (var indexVector = nint.Zero; indexVector < sourceVectors.Length; indexVector++) | ||
{ | ||
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex)); | ||
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1)); | ||
aggregateZ = TOperator.Invoke(aggregateZ, Unsafe.Add(ref sourceRef, sourceIndex + 2)); | ||
resultVector = TOperator.Invoke(ref resultVector, ref Unsafe.Add(ref sourceVectorsRef, indexVector)); | ||
} | ||
|
||
return (aggregateX, aggregateY, aggregateZ); | ||
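// fold the lanes of the aggregate vector into the scalar aggregates | ||
// NOTE(review): lanes are folded in pairs into aggregateX/aggregateY only; aggregateZ is not updated here — confirm this matches the 3-component layout | ||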
ref var resultVectorRef = ref Unsafe.As<Vector<TResult>, TResult>(ref Unsafe.AsRef(in resultVector)); | ||
for (var index = 0; index + 1 < Vector<TResult>.Count; index += 2) | ||
{ | ||
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref resultVectorRef, index)); | ||
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref resultVectorRef, index + 1)); | ||
} | ||
|
||
// skip the source elements already aggregated | ||
sourceIndex = source.Length - (source.Length % Vector<TSource>.Count); | ||
} | ||
} | ||
|
||
// aggregate the remaining elements in the source | ||
ref var sourceRef = ref MemoryMarshal.GetReference(source); | ||
for (; sourceIndex + 2 < source.Length; sourceIndex += 3) | ||
{ | ||
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex)); | ||
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1)); | ||
aggregateZ = TOperator.Invoke(aggregateZ, Unsafe.Add(ref sourceRef, sourceIndex + 2)); | ||
} | ||
|
||
return (aggregateX, aggregateY, aggregateZ); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,37 @@ | ||
namespace NetFabric.Numerics | ||
{ | ||
public static partial class Tensor | ||
{ | ||
public static ValueTuple<T, T, T, T> Aggregate4D<T, TOperator>(ReadOnlySpan<T> source) | ||
where T : struct | ||
where TOperator : struct, IAggregationOperator<T, T> | ||
=> Aggregate4D<T, T, TOperator>(source); | ||
namespace NetFabric.Numerics; | ||
|
||
public static ValueTuple<TResult, TResult, TResult, TResult> Aggregate4D<TSource, TResult, TOperator>(ReadOnlySpan<TSource> source) | ||
where TSource : struct | ||
where TResult : struct | ||
where TOperator : struct, IAggregationOperator<TSource, TResult> | ||
{ | ||
if (source.Length % 4 is not 0) | ||
Throw.ArgumentException(nameof(source), "source span must have a size multiple of 4."); | ||
public static partial class Tensor | ||
{ | ||
public static ValueTuple<T, T, T, T> Aggregate4D<T, TOperator>(ReadOnlySpan<T> source) | ||
where T : struct | ||
where TOperator : struct, IAggregationOperator<T, T> | ||
=> Aggregate4D<T, T, TOperator>(source); | ||
|
||
// initialize aggregate | ||
var aggregateX = TOperator.Identity; | ||
var aggregateY = TOperator.Identity; | ||
var aggregateZ = TOperator.Identity; | ||
var aggregateW = TOperator.Identity; | ||
var sourceIndex = nint.Zero; | ||
public static ValueTuple<TResult, TResult, TResult, TResult> Aggregate4D<TSource, TResult, TOperator>(ReadOnlySpan<TSource> source) | ||
where TSource : struct | ||
where TResult : struct | ||
where TOperator : struct, IAggregationOperator<TSource, TResult> | ||
{ | ||
if (source.Length % 4 is not 0) | ||
Throw.ArgumentException(nameof(source), "source span must have a size multiple of 4."); | ||
|
||
// aggregate the remaining elements in the source | ||
ref var sourceRef = ref MemoryMarshal.GetReference(source); | ||
for (; sourceIndex + 3 < source.Length; sourceIndex += 4) | ||
{ | ||
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex)); | ||
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1)); | ||
aggregateZ = TOperator.Invoke(aggregateZ, Unsafe.Add(ref sourceRef, sourceIndex + 2)); | ||
aggregateW = TOperator.Invoke(aggregateW, Unsafe.Add(ref sourceRef, sourceIndex + 3)); | ||
} | ||
// initialize aggregate | ||
var aggregateX = TOperator.Identity; | ||
var aggregateY = TOperator.Identity; | ||
var aggregateZ = TOperator.Identity; | ||
var aggregateW = TOperator.Identity; | ||
var sourceIndex = nint.Zero; | ||
|
||
return (aggregateX, aggregateY, aggregateZ, aggregateW); | ||
// aggregate the source elements, four components (X, Y, Z, W) per iteration | ||
ref var sourceRef = ref MemoryMarshal.GetReference(source); | ||
for (; sourceIndex + 3 < source.Length; sourceIndex += 4) | ||
{ | ||
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex)); | ||
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1)); | ||
aggregateZ = TOperator.Invoke(aggregateZ, Unsafe.Add(ref sourceRef, sourceIndex + 2)); | ||
aggregateW = TOperator.Invoke(aggregateW, Unsafe.Add(ref sourceRef, sourceIndex + 3)); | ||
} | ||
|
||
return (aggregateX, aggregateY, aggregateZ, aggregateW); | ||
} | ||
} |