Commit

Minor refactoring
clebert committed Oct 22, 2023
1 parent 37b8847 commit c88a2d2
Showing 4 changed files with 18 additions and 35 deletions.
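
All four changes apply one refactoring: instead of forwarding model_path, sequence_length, and thread_count as separate parameters, each createLeaky/readLeaky now receives the caller's args value whole. Below is a minimal sketch of the pattern in isolation; the ChatArgs definition is hypothetical, since the diff confirms only that it exposes model_path, sequence_length, and thread_count.

const std = @import("std");

// Hypothetical ChatArgs; the diff confirms only these three fields.
const ChatArgs = struct {
    model_path: []const u8,
    sequence_length: usize,
    thread_count: usize,
};

// `anytype` is resolved at compile time: any value whose fields match
// the accesses in the function body is accepted, so ChatArgs and
// GeneratorArgs can share a single implementation.
fn describe(args: anytype) void {
    std.debug.print("model={s} seq={d} threads={d}\n", .{
        args.model_path,
        args.sequence_length,
        args.thread_count,
    });
}

pub fn main() void {
    describe(ChatArgs{
        .model_path = "models/tinystories",
        .sequence_length = 0,
        .thread_count = 4,
    });
}
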
src/chat.zig (8 changes: 1 addition & 7 deletions)

@@ -14,13 +14,7 @@ system_prompt: []const u8,
 user_prompt: []const u8,
 
 pub fn createLeaky(allocator: std.mem.Allocator, args: ChatArgs) !Self {
-    const transformer = try Transformer.createLeaky(
-        allocator,
-        args.model_path,
-        args.sequence_length,
-        args.thread_count,
-    );
-
+    const transformer = try Transformer.createLeaky(allocator, args);
     const vocab_size = transformer.checkpoint.vocab_size;
 
     return .{
src/checkpoint.zig (22 changes: 11 additions & 11 deletions)

@@ -25,10 +25,10 @@ ffn_up_weights: []const Matrix,
 output_norm_weight: Vector,
 output_weight: Matrix,
 
-pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_count: usize) !Self {
+pub fn readLeaky(allocator: std.mem.Allocator, args: anytype) !Self {
     const path = try std.fs.path.join(
         allocator,
-        &[_][]const u8{ model_path, "checkpoint_v1.bin" },
+        &[_][]const u8{ args.model_path, "checkpoint_v1.bin" },
     );
 
     defer allocator.free(path);
@@ -85,7 +85,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co
         n_layers,
         embedding_size,
         embedding_size,
-        thread_count,
+        args.thread_count,
     );
 
     const attention_head_size: usize = embedding_size / n_attention_heads;
@@ -96,7 +96,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co
         n_layers,
         n_attention_query_groups * attention_head_size,
         embedding_size,
-        thread_count,
+        args.thread_count,
     );
 
     const attention_value_weights = try Matrix.readMultipleLeaky(
@@ -105,7 +105,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co
         n_layers,
         n_attention_query_groups * attention_head_size,
         embedding_size,
-        thread_count,
+        args.thread_count,
     );
 
     const attention_output_weights = try Matrix.readMultipleLeaky(
@@ -114,7 +114,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co
         n_layers,
         embedding_size,
         embedding_size,
-        thread_count,
+        args.thread_count,
     );
 
     const ffn_gate_weights = try Matrix.readMultipleLeaky(
@@ -123,7 +123,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co
         n_layers,
         ffn_hidden_size,
         embedding_size,
-        thread_count,
+        args.thread_count,
     );
 
     const ffn_down_weights = try Matrix.readMultipleLeaky(
@@ -132,7 +132,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co
         n_layers,
         embedding_size,
         ffn_hidden_size,
-        thread_count,
+        args.thread_count,
     );
 
     const ffn_up_weights = try Matrix.readMultipleLeaky(
@@ -141,13 +141,13 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_co
         n_layers,
         ffn_hidden_size,
         embedding_size,
-        thread_count,
+        args.thread_count,
     );
 
     const output_weight = if (shared_output_weight)
-        Matrix{ .rows = embedding_weights, .thread_count = thread_count }
+        Matrix{ .rows = embedding_weights, .thread_count = args.thread_count }
     else
-        try Matrix.readLeaky(allocator, file, vocab_size, embedding_size, thread_count);
+        try Matrix.readLeaky(allocator, file, vocab_size, embedding_size, args.thread_count);
 
     return .{
         .embedding_size = embedding_size,
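
Because readLeaky now reads only args.model_path and args.thread_count, any struct exposing those fields satisfies it at compile time; extra fields such as sequence_length are simply ignored, which is why Transformer.createLeaky (below) can forward its own args unchanged. A small illustrative sketch; the Opts struct and values are made up:

// Hypothetical caller-side struct: it need not be ChatArgs or
// GeneratorArgs, only expose the fields readLeaky actually accesses.
const Opts = struct {
    model_path: []const u8,
    thread_count: usize,
};

// const checkpoint = try Checkpoint.readLeaky(allocator, Opts{
//     .model_path = "models/tinystories",
//     .thread_count = 8,
// });
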
src/generator.zig (8 changes: 1 addition & 7 deletions)

@@ -14,13 +14,7 @@ prompt_tokens: []usize,
 verbose: bool,
 
 pub fn createLeaky(allocator: std.mem.Allocator, args: GeneratorArgs) !Self {
-    const transformer = try Transformer.createLeaky(
-        allocator,
-        args.model_path,
-        args.sequence_length,
-        args.thread_count,
-    );
-
+    const transformer = try Transformer.createLeaky(allocator, args);
     const vocab_size = transformer.checkpoint.vocab_size;
     const tokenizer = try Tokenizer.readLeaky(allocator, args.model_path, vocab_size);
src/transformer.zig (15 changes: 5 additions & 10 deletions)

@@ -13,18 +13,13 @@ ffn: FFN,
 hidden: Vector,
 output: Vector,
 
-pub fn createLeaky(
-    allocator: std.mem.Allocator,
-    model_path: []const u8,
-    custom_sequence_length: usize,
-    thread_count: usize,
-) !Self {
-    const checkpoint = try Checkpoint.readLeaky(allocator, model_path, thread_count);
-
-    const sequence_length = if (custom_sequence_length == 0)
+pub fn createLeaky(allocator: std.mem.Allocator, args: anytype) !Self {
+    const checkpoint = try Checkpoint.readLeaky(allocator, args);
+
+    const sequence_length = if (args.sequence_length == 0)
         checkpoint.max_sequence_length
     else
-        @min(custom_sequence_length, checkpoint.max_sequence_length);
+        @min(checkpoint.max_sequence_length, args.sequence_length);
 
     return .{
         .checkpoint = checkpoint,
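
The rewritten sequence-length selection is behavior-preserving: @min is commutative, so reordering its arguments changes nothing, and a sequence_length of 0 still falls back to the checkpoint's maximum. The same logic in isolation, as a minimal sketch (helper name and values are made up):

const std = @import("std");

// Hypothetical helper mirroring the selection above: 0 means "use the
// checkpoint's maximum"; anything else is clamped to that maximum.
fn effectiveSequenceLength(requested: usize, max: usize) usize {
    return if (requested == 0) max else @min(max, requested);
}

test "sequence length selection" {
    try std.testing.expectEqual(@as(usize, 2048), effectiveSequenceLength(0, 2048));
    try std.testing.expectEqual(@as(usize, 512), effectiveSequenceLength(512, 2048));
    try std.testing.expectEqual(@as(usize, 2048), effectiveSequenceLength(4096, 2048));
}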
