Skip to content

Commit

Permalink
Improve multithreading
Browse files (browse the repository at this point in the history)
  • Loading branch information
clebert committed Oct 22, 2023
1 parent 88123ac commit f635606
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/matrix.zig
Original file line number | Diff line number | Diff line change
Expand Up @@ -36,7 +36,7 @@ pub fn readMultipleLeaky(
return matrices;
}

const max_thread_count = 8;
const max_thread_count = 24;

pub fn multiplyVector(self: Self, input: Vector, output: Vector) !void {
if (self.thread_count == 0) {
Expand All @@ -45,27 +45,30 @@ pub fn multiplyVector(self: Self, input: Vector, output: Vector) !void {
return;
}

const n_threads = @min(try std.Thread.getCpuCount(), max_thread_count, self.thread_count);

if (output.values.len % n_threads != 0) {
return error.UnsupportedThreadCount;
}

const partial_length = output.values.len / n_threads;
const n_threads = @min(max_thread_count, self.thread_count);
const thread_chunk_size = output.values.len / n_threads;

var threads: [max_thread_count]std.Thread = undefined;

for (threads[0..n_threads], 0..) |*thread, index| {
thread.* = try std.Thread.spawn(.{}, computeMatrixVectorMultiplication, .{
self.rows[index * partial_length .. (index + 1) * partial_length],
self.rows[index * thread_chunk_size ..][0..thread_chunk_size],
input,
output.values[index * partial_length .. (index + 1) * partial_length],
output.values[index * thread_chunk_size ..][0..thread_chunk_size],
});
}

for (threads[0..n_threads]) |thread| {
thread.join();
}

if (output.values.len % n_threads > 0) {
try computeMatrixVectorMultiplication(
self.rows[n_threads * thread_chunk_size ..],
input,
output.values[n_threads * thread_chunk_size ..],
);
}
}

fn computeMatrixVectorMultiplication(
Expand Down

0 comments on commit f635606

Please sign in to comment.