diff --git a/BrightData/BrightData.xml b/BrightData/BrightData.xml index 403ac508..eb98933f 100644 --- a/BrightData/BrightData.xml +++ b/BrightData/BrightData.xml @@ -2655,7 +2655,6 @@ Converts the typed buffer to a buffer of objects - @@ -8447,12 +8446,12 @@ - Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc + Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc. - Typed tensor interface - vector, matrix, 3D tensor etc + Typed tensor interface - vector, matrix, 3D tensor etc. @@ -8580,7 +8579,7 @@ - Typed tensor interface - vector, matrix, 3D tensor etc + Typed tensor interface - vector, matrix, 3D tensor etc. diff --git a/BrightData/Buffer/ReadOnly/Helper/BufferConcatenator.cs b/BrightData/Buffer/ReadOnly/Helper/BufferConcatenator.cs deleted file mode 100644 index 970cb0a1..00000000 --- a/BrightData/Buffer/ReadOnly/Helper/BufferConcatenator.cs +++ /dev/null @@ -1,80 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using BrightData.Types; - -namespace BrightData.Buffer.ReadOnly.Helper -{ - /// - /// Concatenates multiple buffers into one single buffer - /// - /// - //internal class BufferConcatenator : TypedBufferBase, IReadOnlyBufferWithMetaData where T : notnull - //{ - // readonly IReadOnlyBufferWithMetaData[] _buffers; - - // public BufferConcatenator(params IReadOnlyBufferWithMetaData[] buffers) - // { - // _buffers = buffers; - // var first = buffers.First(); - // var size = first.Size; - // var blockCount = first.BlockCount; - // MetaData = first.MetaData; - // foreach (var buffer in buffers.Skip(1)) - // { - // if (first.BlockSize != buffer.BlockSize) - // throw new ArgumentException("All buffer block sizes must be the same"); - // size += buffer.Size; - // blockCount 
+= buffer.BlockCount; - // } - // Size = size; - // BlockSize = first.BlockSize; - // BlockCount = blockCount; - // DataType = typeof(T); - // } - - // public uint Size { get; } - // public uint BlockSize { get; } - // public uint BlockCount { get; } - // public Type DataType { get; } - // public MetaData MetaData { get; } - - // public override async IAsyncEnumerable EnumerateAll() - // { - // foreach (var buffer in _buffers) - // { - // await foreach (var item in buffer.EnumerateAll()) - // yield return item; - // } - // } - - // public async Task ForEachBlock(BlockCallback callback, INotifyOperationProgress? notify = null, string? message = null, CancellationToken ct = default) - // { - // foreach (var buffer in _buffers) - // await buffer.ForEachBlock(callback, notify, message, ct); - // } - - // public override Task> GetTypedBlock(uint blockIndex) - // { - // uint curr = 0; - // foreach (var buffer in _buffers) - // { - // if (blockIndex < curr + buffer.BlockCount) - // return buffer.GetTypedBlock(curr - blockIndex); - // curr += buffer.BlockCount; - // } - // throw new Exception("Block not found"); - // } - - // public override async IAsyncEnumerable EnumerateAllTyped() - // { - // foreach (var buffer in _buffers) - // { - // await foreach (var item in buffer.EnumerateAllTyped()) - // yield return item; - // } - // } - //} -} diff --git a/BrightData/ExtensionMethods.Buffers.cs b/BrightData/ExtensionMethods.Buffers.cs index f3dc8dfb..89508f79 100644 --- a/BrightData/ExtensionMethods.Buffers.cs +++ b/BrightData/ExtensionMethods.Buffers.cs @@ -1065,7 +1065,6 @@ public static IReadOnlyBuffer ConvertUnmanagedTo(this IReadOnlyBuffer buffer, Ty /// /// Converts the typed buffer to a buffer of objects /// - /// /// /// public static IReadOnlyBuffer ToObjectBuffer(this IReadOnlyBuffer buffer) diff --git a/BrightData/Interfaces.LinearAlgebra.cs b/BrightData/Interfaces.LinearAlgebra.cs index a23cbf6c..5f91ec1a 100644 --- a/BrightData/Interfaces.LinearAlgebra.cs +++ 
b/BrightData/Interfaces.LinearAlgebra.cs @@ -799,13 +799,13 @@ public interface ITensor : IDisposable, IHaveBrightDataContext uint TotalSize { get; } /// - /// Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc + /// Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc. /// uint[] Shape { get; } } /// - /// Typed tensor interface - vector, matrix, 3D tensor etc + /// Typed tensor interface - vector, matrix, 3D tensor etc. /// public interface ITensor : ITensor, IReadOnlyTensor, IHaveLinearAlgebraProvider, IHaveTensorSegment where T: unmanaged, IBinaryFloatingPointIeee754, IMinMaxValue @@ -935,7 +935,7 @@ public interface ITensor : ITensor, IReadOnlyTensor, IHaveLinearAlgebraPro } /// - /// Typed tensor interface - vector, matrix, 3D tensor etc + /// Typed tensor interface - vector, matrix, 3D tensor etc. 
/// /// /// diff --git a/BrightData/LinearAlgebra/MutableMatrix.cs b/BrightData/LinearAlgebra/MutableMatrix.cs index 8e4f3761..1547624c 100644 --- a/BrightData/LinearAlgebra/MutableMatrix.cs +++ b/BrightData/LinearAlgebra/MutableMatrix.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Numerics; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Threading.Tasks; using BrightData.LinearAlgebra.ReadOnly; using BrightData.LinearAlgebra.Segments; @@ -24,8 +25,8 @@ namespace BrightData.LinearAlgebra /// Number of columns /// Linear algebra provider public class MutableMatrix(INumericSegment data, uint rows, uint columns, LAP lap) : MutableTensorBase, IMatrix, LAP>(data, lap), IMatrix - where T: unmanaged, IBinaryFloatingPointIeee754, IMinMaxValue - where LAP: LinearAlgebraProvider + where T : unmanaged, IBinaryFloatingPointIeee754, IMinMaxValue + where LAP : LinearAlgebraProvider { /// public uint RowCount { get; private set; } = rows; @@ -79,7 +80,7 @@ protected set /// public INumericSegment GetRow(uint index) { - if(index > RowCount) + if (index >= RowCount) throw new ArgumentOutOfRangeException(nameof(index), $"Number of rows is {RowCount} but index {index} was requested"); return new MutableTensorSegmentWrapper(Segment, index, RowCount, ColumnCount); } @@ -87,7 +88,7 @@ public INumericSegment GetRow(uint index) /// public virtual INumericSegment GetColumn(uint index) { - if(index > ColumnCount) + if (index >= ColumnCount) throw new ArgumentOutOfRangeException(nameof(index), $"Number of columns is {ColumnCount} but index {index} was requested"); return new MutableTensorSegmentWrapper(Segment, index * RowCount, 1, RowCount); } @@ -95,7 +96,7 @@ public virtual INumericSegment GetColumn(uint index) /// public virtual IReadOnlyNumericSegment GetReadOnlyRow(uint index) { - if(index > RowCount) + if (index >= RowCount) throw new ArgumentOutOfRangeException(nameof(index), $"Number of rows is {RowCount} but index {index} was requested"); 
return new ReadOnlyTensorSegmentWrapper(Segment, index, RowCount, ColumnCount); } @@ -103,7 +104,7 @@ public virtual IReadOnlyNumericSegment GetReadOnlyRow(uint index) /// public virtual IReadOnlyNumericSegment GetReadOnlyColumn(uint index) { - if(index > ColumnCount) + if (index >= ColumnCount) throw new ArgumentOutOfRangeException(nameof(index), $"Number of columns is {ColumnCount} but index {index} was requested"); return new ReadOnlyTensorSegmentWrapper(Segment, index * RowCount, 1, RowCount); } @@ -317,13 +318,14 @@ static unsafe IMatrix MultiplyWithThisTransposed(LinearAlgebraProvider lap fixed (T* matrixPtr = matrixSpan) fixed (T* otherPtr = otherSpan) fixed (T* retPtr = retSpan) { - MatrixMultiplyChunked(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr); + //MatrixMultiplyChunked(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr); + MatrixMultiplyTiled2(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr); } } finally { - if(wasMatrixTempUsed) + if (wasMatrixTempUsed) matrixTemp.Dispose(); - if(wasOtherTempUsed) + if (wasOtherTempUsed) otherTemp.Dispose(); } @@ -348,9 +350,9 @@ static unsafe void MatrixMultiplyChunked(T* a, T* b, int size, uint rows, uint c return; - [MethodImpl(MethodImplOptions.AggressiveInlining)]void Multiply(long startIndex) + [MethodImpl(MethodImplOptions.AggressiveInlining)] void Multiply(long startIndex) { - for(long index = startIndex, len = Math.Min(startIndex + ChunkSize, totalSize); index < len; index++) { + for (long index = startIndex, len = Math.Min(startIndex + ChunkSize, totalSize); index < len; index++) { var i = (uint)(index % rows); var j = (uint)(index / rows); @@ -371,6 +373,111 @@ static unsafe void MatrixMultiplyChunked(T* a, T* b, int size, uint rows, uint c } } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static unsafe void MatrixMultiplyTiled(T* a, T* b, int size, uint rows, uint cols, T* ret) + { + const int TileSize = 32; // Size of the tile, should be adjusted based on hardware cache 
sizes. + var vectorSize = Vector.Count; + var numVectors = size / vectorSize; + var ceiling = numVectors * vectorSize; + var totalSize = rows * cols; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void MultiplyTile(uint rowStart, uint colStart) + { + for (uint i = rowStart; i < rowStart + TileSize && i < rows; i++) { + for (uint j = colStart; j < colStart + TileSize && j < cols; j++) { + var xPtr = &a[i * size]; + var xSpan = new ReadOnlySpan(xPtr, size); + var xVectors = MemoryMarshal.Cast>(xSpan); + + var yPtr = &b[j * size]; + var ySpan = new ReadOnlySpan(yPtr, size); + var yVectors = MemoryMarshal.Cast>(ySpan); + + var vSum = Vector.Zero; + for (var z = 0; z < numVectors; z++) + vSum += xVectors[z] * yVectors[z]; + + var sum = Vector.Dot(vSum, Vector.One); + for (var z = ceiling; z < size; z++) + sum += xPtr[z] * yPtr[z]; + ret[j * rows + i] = sum; + } + } + } + + if (totalSize >= Consts.MinimumSizeForParallel) { + Parallel.For(0, (int)Math.Ceiling((double)rows / TileSize), rowTile => { + for (uint colTile = 0; colTile < cols; colTile += TileSize) { + MultiplyTile((uint)rowTile * TileSize, colTile); + } + }); + } + else { + for (uint rowTile = 0; rowTile < rows; rowTile += TileSize) { + for (uint colTile = 0; colTile < cols; colTile += TileSize) { + MultiplyTile(rowTile, colTile); + } + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static unsafe void MatrixMultiplyTiled2(T* a, T* b, int size, uint rows, uint cols, T* ret) + { + const int L1BlockSize = 32; + const int L2BlockSize = 64; + var vectorSize = Vector.Count; + var numVectors = size / vectorSize; + var ceiling = numVectors * vectorSize; + var totalSize = rows * cols; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void MultiplyBlock(uint rowStart, uint colStart, uint rowEnd, uint colEnd) + { + for (uint i = rowStart; i < rowEnd && i < rows; i += L1BlockSize) { + for (uint j = colStart; j < colEnd && j < cols; j += L1BlockSize) { + for (uint ii = i; ii < i + 
L1BlockSize && ii < rowEnd && ii < rows; ii++) { + for (uint jj = j; jj < j + L1BlockSize && jj < colEnd && jj < cols; jj++) { + var xPtr = &a[ii * size]; + var xSpan = new ReadOnlySpan(xPtr, size); + var xVectors = MemoryMarshal.Cast>(xSpan); + + var yPtr = &b[jj * size]; + var ySpan = new ReadOnlySpan(yPtr, size); + var yVectors = MemoryMarshal.Cast>(ySpan); + + var vSum = Vector.Zero; + for (var z = 0; z < numVectors; z++) + vSum += xVectors[z] * yVectors[z]; + + var sum = Vector.Dot(vSum, Vector.One); + for (var z = ceiling; z < size; z++) + sum += xPtr[z] * yPtr[z]; + ret[jj * rows + ii] = sum; + } + } + } + } + } + + if (totalSize >= Consts.MinimumSizeForParallel) { + Parallel.For(0, (int)Math.Ceiling((double)rows / L2BlockSize), rowTile => { + for (uint colTile = 0; colTile < cols; colTile += L2BlockSize) { + MultiplyBlock((uint)rowTile * L2BlockSize, colTile, (uint)((rowTile + 1) * L2BlockSize), colTile + L2BlockSize); + } + }); + } + else { + for (uint rowTile = 0; rowTile < rows; rowTile += L2BlockSize) { + for (uint colTile = 0; colTile < cols; colTile += L2BlockSize) { + MultiplyBlock(rowTile, colTile, rowTile + L2BlockSize, colTile + L2BlockSize); + } + } + } + } + /// public override string ToString() { diff --git a/BrightWire/BrightWire.xml b/BrightWire/BrightWire.xml index beaaf830..6011257f 100644 --- a/BrightWire/BrightWire.xml +++ b/BrightWire/BrightWire.xml @@ -1992,6 +1992,16 @@ Base class for graph nodes + + + Callback method when the node has executed + + + + + Called when the node is executed + + Constructor diff --git a/BrightWire/ExecutionGraph/Node/NodeBase.cs b/BrightWire/ExecutionGraph/Node/NodeBase.cs index 88186481..d0429d29 100644 --- a/BrightWire/ExecutionGraph/Node/NodeBase.cs +++ b/BrightWire/ExecutionGraph/Node/NodeBase.cs @@ -20,7 +20,14 @@ public abstract class NodeBase : ICanInitialiseNode, IDisposable, ICanSerialise string? 
_name; List _output = []; + /// + /// Callback method when the node has executed + /// public delegate void ForwardDelegate(NodeBase previous, NodeBase current, IGraphData input, IGraphData? output); + + /// + /// Called when the node is executed + /// public event ForwardDelegate? OnForward; /// diff --git a/BrightWire/Models/StringTable.cs b/BrightWire/Models/StringTable.cs index d3798fd8..9164568d 100644 --- a/BrightWire/Models/StringTable.cs +++ b/BrightWire/Models/StringTable.cs @@ -1,6 +1,4 @@ -using System; - -namespace BrightWire.Models +namespace BrightWire.Models { /// /// An array of indexed strings diff --git a/ExampleCode/Program.cs b/ExampleCode/Program.cs index baec6c08..4a6c3f3d 100644 --- a/ExampleCode/Program.cs +++ b/ExampleCode/Program.cs @@ -8,7 +8,6 @@ using BrightData.Cuda; using BrightData.LinearAlgebra; using BrightData.MKL; -using BrightData.Parquet; using BrightWire; using ExampleCode.DataSet; using ExampleCode.DataTableTrainers; @@ -100,7 +99,7 @@ static async Task IrisClassification(BrightDataContext context, bool useMkl) Start(context, useMkl); using var iris = await context.Iris(); await iris.TrainNaiveBayes(); - iris.TrainDecisionTree(); + await iris.TrainDecisionTree(); await iris.TrainRandomForest(500, 7); await iris.TrainKNearestNeighbours(10); //iris.TrainMultinomialLogisticRegression(500, 0.3f, 0.1f);