Skip to content

Commit

Permalink
Simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
aalmada committed Jan 21, 2024
1 parent 25606e4 commit e7ec730
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 66 deletions.
32 changes: 1 addition & 31 deletions src/NetFabric.Numerics.Tensors/Aggregate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static T Aggregate<T, TOperator>(ReadOnlySpan<T> source)
// aggregate the remaining elements in the source
ref var sourceRef = ref MemoryMarshal.GetReference(source);
var remaining = source.Length - (int)sourceIndex;
if (remaining >= 8)
if (remaining >= 4)
{
var partial1 = TOperator.Identity;
var partial2 = TOperator.Identity;
Expand All @@ -60,36 +60,6 @@ public static T Aggregate<T, TOperator>(ReadOnlySpan<T> source)

switch(remaining)
{
case 7:
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 1));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 2));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 3));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 4));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 5));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 6));
break;
case 6:
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 1));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 2));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 3));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 4));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 5));
break;
case 5:
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 1));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 2));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 3));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 4));
break;
case 4:
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 1));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 2));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 3));
break;
case 3:
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex));
aggregate = TOperator.Invoke(aggregate, Unsafe.Add(ref sourceRef, sourceIndex + 1));
Expand Down
16 changes: 1 addition & 15 deletions src/NetFabric.Numerics.Tensors/Aggregate2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static ValueTuple<T, T> Aggregate2D<T, TOperator>(ReadOnlySpan<T> source)
// aggregate the remaining elements in the source
ref var sourceRef = ref MemoryMarshal.GetReference(source);
var remaining = source.Length;
if (remaining >= 8)
if (remaining >= 4)
{
var partialX1 = TOperator.Identity;
var partialY1 = TOperator.Identity;
Expand All @@ -35,20 +35,6 @@ public static ValueTuple<T, T> Aggregate2D<T, TOperator>(ReadOnlySpan<T> source)

switch(remaining)
{
case 6:
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1));
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex + 2));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 3));
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex + 4));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 5));
break;
case 4:
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1));
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex + 2));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 3));
break;
case 2:
aggregateX = TOperator.Invoke(aggregateX, Unsafe.Add(ref sourceRef, sourceIndex));
aggregateY = TOperator.Invoke(aggregateY, Unsafe.Add(ref sourceRef, sourceIndex + 1));
Expand Down
8 changes: 3 additions & 5 deletions src/NetFabric.Numerics.Tensors/ApplyBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,10 @@ public static void Apply<T, TOperator>(ReadOnlySpan<T> x, ValueTuple<T, T> y, Sp
Unsafe.Add(ref destinationRef, index + 3) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 3), y.Item2);
}

switch(x.Length - (int)index)
if(x.Length > (int)index)
{
case 2:
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke(Unsafe.Add(ref xRef, index), y.Item1);
Unsafe.Add(ref destinationRef, index + 1) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 1), y.Item2);
break;
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke(Unsafe.Add(ref xRef, index), y.Item1);
Unsafe.Add(ref destinationRef, index + 1) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 1), y.Item2);
}
}

Expand Down
24 changes: 9 additions & 15 deletions src/NetFabric.Numerics.Tensors/ApplyTernary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,10 @@ public static void Apply<T, TOperator>(ReadOnlySpan<T> x, ValueTuple<T, T> y, Re
Unsafe.Add(ref destinationRef, index + 3) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 3), y.Item2, Unsafe.Add(ref zRef, index + 3));
}

switch(x.Length - (int)index)
if(x.Length > (int)index)
{
case 2:
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke(Unsafe.Add(ref xRef, index), y.Item1, Unsafe.Add(ref zRef, index));
Unsafe.Add(ref destinationRef, index + 1) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 1), y.Item2, Unsafe.Add(ref zRef, index + 1));
break;
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke(Unsafe.Add(ref xRef, index), y.Item1, Unsafe.Add(ref zRef, index));
Unsafe.Add(ref destinationRef, index + 1) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 1), y.Item2, Unsafe.Add(ref zRef, index + 1));
}
}

Expand Down Expand Up @@ -386,12 +384,10 @@ public static void Apply<T, TOperator>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Val
Unsafe.Add(ref destinationRef, index + 3) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 3), Unsafe.Add(ref yRef, index + 3), z.Item2);
}

switch(x.Length - (int)index)
if(x.Length > (int)index)
{
case 2:
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke(Unsafe.Add(ref xRef, index), Unsafe.Add(ref yRef, index), z.Item1);
Unsafe.Add(ref destinationRef, index + 1) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 1), Unsafe.Add(ref yRef, index + 1), z.Item2);
break;
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke(Unsafe.Add(ref xRef, index), Unsafe.Add(ref yRef, index), z.Item1);
Unsafe.Add(ref destinationRef, index + 1) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 1), Unsafe.Add(ref yRef, index + 1), z.Item2);
}
}

Expand Down Expand Up @@ -544,12 +540,10 @@ public static void Apply<T, TOperator>(ReadOnlySpan<T> x, ValueTuple<T, T> y, Va
Unsafe.Add(ref destinationRef, index + 3) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 3), y.Item2, z.Item2);
}

switch(x.Length - (int)index)
if(x.Length > (int)index)
{
case 2:
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke(Unsafe.Add(ref xRef, index), y.Item1, z.Item1);
Unsafe.Add(ref destinationRef, index + 1) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 1), y.Item2, z.Item2);
break;
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke(Unsafe.Add(ref xRef, index), y.Item1, z.Item1);
Unsafe.Add(ref destinationRef, index + 1) = TOperator.Invoke(Unsafe.Add(ref xRef, index + 1), y.Item2, z.Item2);
}
}

Expand Down

0 comments on commit e7ec730

Please sign in to comment.