From c88a2d2abbddeb201c1f4a47d57e5c73c4511605 Mon Sep 17 00:00:00 2001 From: Clemens Akens Date: Sun, 22 Oct 2023 23:49:42 +0200 Subject: [PATCH] Minor refactoring --- src/chat.zig | 8 +------- src/checkpoint.zig | 22 +++++++++++----------- src/generator.zig | 8 +------- src/transformer.zig | 15 +++++---------- 4 files changed, 18 insertions(+), 35 deletions(-) diff --git a/src/chat.zig b/src/chat.zig index a37f2ce..cec0fed 100644 --- a/src/chat.zig +++ b/src/chat.zig @@ -14,13 +14,7 @@ system_prompt: []const u8, user_prompt: []const u8, pub fn createLeaky(allocator: std.mem.Allocator, args: ChatArgs) !Self { - const transformer = try Transformer.createLeaky( - allocator, - args.model_path, - args.sequence_length, - args.thread_count, - ); - + const transformer = try Transformer.createLeaky(allocator, args); const vocab_size = transformer.checkpoint.vocab_size; return .{ diff --git a/src/checkpoint.zig b/src/checkpoint.zig index a09ef7d..dc9a543 100644 --- a/src/checkpoint.zig +++ b/src/checkpoint.zig @@ -25,10 +25,10 @@ ffn_up_weights: []const Matrix, output_norm_weight: Vector, output_weight: Matrix, -pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_count: usize) !Self { +pub fn readLeaky(allocator: std.mem.Allocator, args: anytype) !Self { const path = try std.fs.path.join( allocator, - &[_][]const u8{ model_path, "checkpoint_v1.bin" }, + &[_][]const u8{ args.model_path, "checkpoint_v1.bin" }, ); defer allocator.free(path); @@ -85,7 +85,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co n_layers, embedding_size, embedding_size, - thread_count, + args.thread_count, ); const attention_head_size: usize = embedding_size / n_attention_heads; @@ -96,7 +96,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co n_layers, n_attention_query_groups * attention_head_size, embedding_size, - thread_count, + args.thread_count, ); const attention_value_weights = try Matrix.readMultipleLeaky( @@ -105,7 +105,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co n_layers, n_attention_query_groups * attention_head_size, embedding_size, - thread_count, + args.thread_count, ); const attention_output_weights = try Matrix.readMultipleLeaky( @@ -114,7 +114,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co n_layers, embedding_size, embedding_size, - thread_count, + args.thread_count, ); const ffn_gate_weights = try Matrix.readMultipleLeaky( @@ -123,7 +123,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co n_layers, ffn_hidden_size, embedding_size, - thread_count, + args.thread_count, ); const ffn_down_weights = try Matrix.readMultipleLeaky( @@ -132,7 +132,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co n_layers, embedding_size, ffn_hidden_size, - thread_count, + args.thread_count, ); const ffn_up_weights = try Matrix.readMultipleLeaky( @@ -141,13 +141,13 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co n_layers, ffn_hidden_size, embedding_size, - thread_count, + args.thread_count, ); const output_weight = if (shared_output_weight) - Matrix{ .rows = embedding_weights, .thread_count = thread_count } + Matrix{ .rows = embedding_weights, .thread_count = args.thread_count } else - try Matrix.readLeaky(allocator, file, vocab_size, embedding_size, thread_count); + try Matrix.readLeaky(allocator, file, vocab_size, embedding_size, args.thread_count); return .{ .embedding_size = embedding_size, diff --git a/src/generator.zig b/src/generator.zig index 9cfd451..f95aef9 100644 --- a/src/generator.zig +++ b/src/generator.zig @@ -14,13 +14,7 @@ prompt_tokens: []usize, verbose: bool, pub fn createLeaky(allocator: std.mem.Allocator, args: GeneratorArgs) !Self { - const transformer = try Transformer.createLeaky( - allocator, - args.model_path, - args.sequence_length, - args.thread_count, - ); - + const transformer = try Transformer.createLeaky(allocator, args); const vocab_size = transformer.checkpoint.vocab_size; const tokenizer = try Tokenizer.readLeaky(allocator, args.model_path, vocab_size); diff --git a/src/transformer.zig b/src/transformer.zig index dd3e167..3bcdadb 100644 --- a/src/transformer.zig +++ b/src/transformer.zig @@ -13,18 +13,13 @@ ffn: FFN, hidden: Vector, output: Vector, -pub fn createLeaky( - allocator: std.mem.Allocator, - model_path: []const u8, - custom_sequence_length: usize, - thread_count: usize, -) !Self { - const checkpoint = try Checkpoint.readLeaky(allocator, model_path, thread_count); - - const sequence_length = if (custom_sequence_length == 0) +pub fn createLeaky(allocator: std.mem.Allocator, args: anytype) !Self { + const checkpoint = try Checkpoint.readLeaky(allocator, args); + + const sequence_length = if (args.sequence_length == 0) checkpoint.max_sequence_length else - @min(custom_sequence_length, checkpoint.max_sequence_length); + @min(checkpoint.max_sequence_length, args.sequence_length); return .{ .checkpoint = checkpoint,