Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack Dermody committed May 18, 2024
1 parent 54bce96 commit be30e72
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 104 deletions.
7 changes: 3 additions & 4 deletions BrightData/BrightData.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2655,7 +2655,6 @@
<summary>
Converts the typed buffer to a buffer of objects
</summary>
<typeparam name="T"></typeparam>
<param name="buffer"></param>
<returns></returns>
</member>
Expand Down Expand Up @@ -8447,12 +8446,12 @@
</member>
<member name="P:BrightData.ITensor.Shape">
<summary>
Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc
Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc.
</summary>
</member>
<member name="T:BrightData.ITensor`1">
<summary>
Typed tensor interface - vector, matrix, 3D tensor etc
Typed tensor interface - vector, matrix, 3D tensor etc.
</summary>
</member>
<member name="M:BrightData.ITensor`1.Reshape">
Expand Down Expand Up @@ -8580,7 +8579,7 @@
</member>
<member name="T:BrightData.ITensorType`3">
<summary>
Typed tensor interface - vector, matrix, 3D tensor etc
Typed tensor interface - vector, matrix, 3D tensor etc.
</summary>
<typeparam name="T"></typeparam>
<typeparam name="TT"></typeparam>
Expand Down
80 changes: 0 additions & 80 deletions BrightData/Buffer/ReadOnly/Helper/BufferConcatenator.cs

This file was deleted.

1 change: 0 additions & 1 deletion BrightData/ExtensionMethods.Buffers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,6 @@ public static IReadOnlyBuffer ConvertUnmanagedTo(this IReadOnlyBuffer buffer, Ty
/// <summary>
/// Converts the typed buffer to a buffer of objects
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="buffer"></param>
/// <returns></returns>
public static IReadOnlyBuffer<object> ToObjectBuffer(this IReadOnlyBuffer buffer)
Expand Down
6 changes: 3 additions & 3 deletions BrightData/Interfaces.LinearAlgebra.cs
Original file line number Diff line number Diff line change
Expand Up @@ -799,13 +799,13 @@ public interface ITensor : IDisposable, IHaveBrightDataContext
uint TotalSize { get; }

/// <summary>
/// Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc
/// Tensor shape - for a vector the array will have a single element, for a matrix it will be [columns, rows], a 3D tensor will be [columns, rows, depth] etc.
/// </summary>
uint[] Shape { get; }
}

/// <summary>
/// Typed tensor interface - vector, matrix, 3D tensor etc
/// Typed tensor interface - vector, matrix, 3D tensor etc.
/// </summary>
public interface ITensor<T> : ITensor, IReadOnlyTensor<T>, IHaveLinearAlgebraProvider<T>, IHaveTensorSegment<T>
where T: unmanaged, IBinaryFloatingPointIeee754<T>, IMinMaxValue<T>
Expand Down Expand Up @@ -935,7 +935,7 @@ public interface ITensor<T> : ITensor, IReadOnlyTensor<T>, IHaveLinearAlgebraPro
}

/// <summary>
/// Typed tensor interface - vector, matrix, 3D tensor etc
/// Typed tensor interface - vector, matrix, 3D tensor etc.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <typeparam name="TT"></typeparam>
Expand Down
129 changes: 118 additions & 11 deletions BrightData/LinearAlgebra/MutableMatrix.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Linq;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using BrightData.LinearAlgebra.ReadOnly;
using BrightData.LinearAlgebra.Segments;
Expand All @@ -24,8 +25,8 @@ namespace BrightData.LinearAlgebra
/// <param name="columns">Number of columns</param>
/// <param name="lap">Linear algebra provider</param>
public class MutableMatrix<T, LAP>(INumericSegment<T> data, uint rows, uint columns, LAP lap) : MutableTensorBase<T, IReadOnlyMatrix<T>, IMatrix<T>, LAP>(data, lap), IMatrix<T>
where T: unmanaged, IBinaryFloatingPointIeee754<T>, IMinMaxValue<T>
where LAP: LinearAlgebraProvider<T>
where T : unmanaged, IBinaryFloatingPointIeee754<T>, IMinMaxValue<T>
where LAP : LinearAlgebraProvider<T>
{
/// <inheritdoc />
public uint RowCount { get; private set; } = rows;
Expand Down Expand Up @@ -79,31 +80,31 @@ protected set
/// <inheritdoc />
public INumericSegment<T> GetRow(uint index)
{
if(index > RowCount)
if (index > RowCount)
throw new ArgumentOutOfRangeException(nameof(index), $"Number of rows is {RowCount} but index {index} was requested");
return new MutableTensorSegmentWrapper<T>(Segment, index, RowCount, ColumnCount);
}

/// <inheritdoc />
public virtual INumericSegment<T> GetColumn(uint index)
{
if(index > ColumnCount)
if (index > ColumnCount)
throw new ArgumentOutOfRangeException(nameof(index), $"Number of columns is {ColumnCount} but index {index} was requested");
return new MutableTensorSegmentWrapper<T>(Segment, index * RowCount, 1, RowCount);
}

/// <inheritdoc />
public virtual IReadOnlyNumericSegment<T> GetReadOnlyRow(uint index)
{
if(index > RowCount)
if (index > RowCount)
throw new ArgumentOutOfRangeException(nameof(index), $"Number of rows is {RowCount} but index {index} was requested");
return new ReadOnlyTensorSegmentWrapper<T>(Segment, index, RowCount, ColumnCount);
}

/// <inheritdoc />
public virtual IReadOnlyNumericSegment<T> GetReadOnlyColumn(uint index)
{
if(index > ColumnCount)
if (index > ColumnCount)
throw new ArgumentOutOfRangeException(nameof(index), $"Number of columns is {ColumnCount} but index {index} was requested");
return new ReadOnlyTensorSegmentWrapper<T>(Segment, index * RowCount, 1, RowCount);
}
Expand Down Expand Up @@ -317,13 +318,14 @@ static unsafe IMatrix<T> MultiplyWithThisTransposed(LinearAlgebraProvider<T> lap
fixed (T* matrixPtr = matrixSpan)
fixed (T* otherPtr = otherSpan)
fixed (T* retPtr = retSpan) {
MatrixMultiplyChunked(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr);
//MatrixMultiplyChunked(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr);
MatrixMultiplyTiled2(matrixPtr, otherPtr, lda, rowCount, columnCount, retPtr);
}
}
finally {
if(wasMatrixTempUsed)
if (wasMatrixTempUsed)
matrixTemp.Dispose();
if(wasOtherTempUsed)
if (wasOtherTempUsed)
otherTemp.Dispose();
}

Expand All @@ -348,9 +350,9 @@ static unsafe void MatrixMultiplyChunked(T* a, T* b, int size, uint rows, uint c

return;

[MethodImpl(MethodImplOptions.AggressiveInlining)]void Multiply(long startIndex)
[MethodImpl(MethodImplOptions.AggressiveInlining)] void Multiply(long startIndex)
{
for(long index = startIndex, len = Math.Min(startIndex + ChunkSize, totalSize); index < len; index++) {
for (long index = startIndex, len = Math.Min(startIndex + ChunkSize, totalSize); index < len; index++) {
var i = (uint)(index % rows);
var j = (uint)(index / rows);

Expand All @@ -371,6 +373,111 @@ static unsafe void MatrixMultiplyChunked(T* a, T* b, int size, uint rows, uint c
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static unsafe void MatrixMultiplyTiled(T* a, T* b, int size, uint rows, uint cols, T* ret)
{
const int TileSize = 32; // Size of the tile, should be adjusted based on hardware cache sizes.
var vectorSize = Vector<T>.Count;
var numVectors = size / vectorSize;
var ceiling = numVectors * vectorSize;
var totalSize = rows * cols;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
void MultiplyTile(uint rowStart, uint colStart)
{
for (uint i = rowStart; i < rowStart + TileSize && i < rows; i++) {
for (uint j = colStart; j < colStart + TileSize && j < cols; j++) {
var xPtr = &a[i * size];
var xSpan = new ReadOnlySpan<T>(xPtr, size);
var xVectors = MemoryMarshal.Cast<T, Vector<T>>(xSpan);

var yPtr = &b[j * size];
var ySpan = new ReadOnlySpan<T>(yPtr, size);
var yVectors = MemoryMarshal.Cast<T, Vector<T>>(ySpan);

var vSum = Vector<T>.Zero;
for (var z = 0; z < numVectors; z++)
vSum += xVectors[z] * yVectors[z];

var sum = Vector.Dot(vSum, Vector<T>.One);
for (var z = ceiling; z < size; z++)
sum += xPtr[z] * yPtr[z];
ret[j * rows + i] = sum;
}
}
}

if (totalSize >= Consts.MinimumSizeForParallel) {
Parallel.For(0, (int)Math.Ceiling((double)rows / TileSize), rowTile => {
for (uint colTile = 0; colTile < cols; colTile += TileSize) {
MultiplyTile((uint)rowTile * TileSize, colTile);
}
});
}
else {
for (uint rowTile = 0; rowTile < rows; rowTile += TileSize) {
for (uint colTile = 0; colTile < cols; colTile += TileSize) {
MultiplyTile(rowTile, colTile);
}
}
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static unsafe void MatrixMultiplyTiled2(T* a, T* b, int size, uint rows, uint cols, T* ret)
{
const int L1BlockSize = 32;
const int L2BlockSize = 64;
var vectorSize = Vector<T>.Count;
var numVectors = size / vectorSize;
var ceiling = numVectors * vectorSize;
var totalSize = rows * cols;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
void MultiplyBlock(uint rowStart, uint colStart, uint rowEnd, uint colEnd)
{
for (uint i = rowStart; i < rowEnd && i < rows; i += L1BlockSize) {
for (uint j = colStart; j < colEnd && j < cols; j += L1BlockSize) {
for (uint ii = i; ii < i + L1BlockSize && ii < rowEnd && ii < rows; ii++) {
for (uint jj = j; jj < j + L1BlockSize && jj < colEnd && jj < cols; jj++) {
var xPtr = &a[ii * size];
var xSpan = new ReadOnlySpan<T>(xPtr, size);
var xVectors = MemoryMarshal.Cast<T, Vector<T>>(xSpan);

var yPtr = &b[jj * size];
var ySpan = new ReadOnlySpan<T>(yPtr, size);
var yVectors = MemoryMarshal.Cast<T, Vector<T>>(ySpan);

var vSum = Vector<T>.Zero;
for (var z = 0; z < numVectors; z++)
vSum += xVectors[z] * yVectors[z];

var sum = Vector.Dot(vSum, Vector<T>.One);
for (var z = ceiling; z < size; z++)
sum += xPtr[z] * yPtr[z];
ret[jj * rows + ii] = sum;
}
}
}
}
}

if (totalSize >= Consts.MinimumSizeForParallel) {
Parallel.For(0, (int)Math.Ceiling((double)rows / L2BlockSize), rowTile => {
for (uint colTile = 0; colTile < cols; colTile += L2BlockSize) {
MultiplyBlock((uint)rowTile * L2BlockSize, colTile, (uint)((rowTile + 1) * L2BlockSize), colTile + L2BlockSize);
}
});
}
else {
for (uint rowTile = 0; rowTile < rows; rowTile += L2BlockSize) {
for (uint colTile = 0; colTile < cols; colTile += L2BlockSize) {
MultiplyBlock(rowTile, colTile, rowTile + L2BlockSize, colTile + L2BlockSize);
}
}
}
}

/// <inheritdoc />
public override string ToString()
{
Expand Down
10 changes: 10 additions & 0 deletions BrightWire/BrightWire.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions BrightWire/ExecutionGraph/Node/NodeBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@ public abstract class NodeBase : ICanInitialiseNode, IDisposable, ICanSerialise
string? _name;
List<WireToNode> _output = [];

/// <summary>
/// Callback method when the node has executed
/// </summary>
public delegate void ForwardDelegate(NodeBase previous, NodeBase current, IGraphData input, IGraphData? output);

/// <summary>
/// Called when the node is executed
/// </summary>
public event ForwardDelegate? OnForward;

/// <summary>
Expand Down
4 changes: 1 addition & 3 deletions BrightWire/Models/StringTable.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using System;

namespace BrightWire.Models
namespace BrightWire.Models
{
/// <summary>
/// An array of indexed strings
Expand Down
3 changes: 1 addition & 2 deletions ExampleCode/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using BrightData.Cuda;
using BrightData.LinearAlgebra;
using BrightData.MKL;
using BrightData.Parquet;
using BrightWire;
using ExampleCode.DataSet;
using ExampleCode.DataTableTrainers;
Expand Down Expand Up @@ -100,7 +99,7 @@ static async Task IrisClassification(BrightDataContext context, bool useMkl)
Start(context, useMkl);
using var iris = await context.Iris();
await iris.TrainNaiveBayes();
iris.TrainDecisionTree();
await iris.TrainDecisionTree();
await iris.TrainRandomForest(500, 7);
await iris.TrainKNearestNeighbours(10);
//iris.TrainMultinomialLogisticRegression(500, 0.3f, 0.1f);
Expand Down

0 comments on commit be30e72

Please sign in to comment.