From 3950c0d3fcf8fc4a7501ae89ece21b55e07216a3 Mon Sep 17 00:00:00 2001 From: Clemens Akens Date: Sun, 27 Aug 2023 13:43:58 +0200 Subject: [PATCH] Improve memory freeing in error case --- README.md | 1 - src/attention.zig | 46 ++++++++++++++++++++++++++++++++++++-------- src/cli.zig | 3 +++ src/feed_forward.zig | 22 +++++++++++++++++---- src/tokenizer.zig | 9 +++++++++ src/transformer.zig | 24 +++++++++++++++++++---- 6 files changed, 88 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 8eecb72..9456fbc 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,6 @@ through the following linked [tests](./test.sh). ## TODOs - Add support for chat (https://github.com/karpathy/llama2.c/pull/343) -- Use `errdefer` in all init methods (see generator) ## Usage diff --git a/src/attention.zig b/src/attention.zig index 49129c9..a19eec1 100644 --- a/src/attention.zig +++ b/src/attention.zig @@ -21,18 +21,48 @@ pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint, seq_len: usize const kv_dim = checkpoint.kv_dim; const kv_cache_dim = checkpoint.n_layers * seq_len * kv_dim; + const input_buffer = try allocator.alloc(f32, dim); + + errdefer allocator.free(input_buffer); + + const output_buffer = try allocator.alloc(f32, dim); + + errdefer allocator.free(output_buffer); + + const scores_buffer = try allocator.alloc(f32, checkpoint.n_heads * seq_len); + + errdefer allocator.free(scores_buffer); + + const queries_buffer = try allocator.alloc(f32, dim); + + errdefer allocator.free(queries_buffer); + + const keys_buffer = try allocator.alloc(f32, kv_dim); + + errdefer allocator.free(keys_buffer); + + const values_buffer = try allocator.alloc(f32, kv_dim); + + errdefer allocator.free(values_buffer); + + const key_cache = try allocator.alloc(f32, kv_cache_dim); + + errdefer allocator.free(key_cache); + + const value_cache = try allocator.alloc(f32, kv_cache_dim); + return Self{ .allocator = allocator, .checkpoint = checkpoint, .seq_len = seq_len, - .input_buffer = try allocator.alloc(f32, dim), - .output_buffer = try allocator.alloc(f32, dim), - .scores_buffer = try allocator.alloc(f32, checkpoint.n_heads * seq_len), - .queries_buffer = try allocator.alloc(f32, dim), - .keys_buffer = try allocator.alloc(f32, kv_dim), - .values_buffer = try allocator.alloc(f32, kv_dim), - .key_cache = try allocator.alloc(f32, kv_cache_dim), - .value_cache = try allocator.alloc(f32, kv_cache_dim), + .input_buffer = input_buffer, + .output_buffer = output_buffer, + .scores_buffer = scores_buffer, + .queries_buffer = queries_buffer, + .keys_buffer = keys_buffer, + .values_buffer = values_buffer, + .key_cache = key_cache, + .value_cache = value_cache, }; } diff --git a/src/cli.zig b/src/cli.zig index 443da3c..2713f4f 100644 --- a/src/cli.zig +++ b/src/cli.zig @@ -25,8 +25,11 @@ pub fn init(allocator: std.mem.Allocator) !Self { var tokenizer_path: ?[]const u8 = null; var mmap: bool = true; var timer: bool = true; + var arg_iterator = try std.process.argsWithAllocator(allocator); + errdefer arg_iterator.deinit(); + _ = arg_iterator.next().?; const checkpoint_path = arg_iterator.next() orelse try exit(); diff --git a/src/feed_forward.zig b/src/feed_forward.zig index 80a6350..40b9829 100644 --- a/src/feed_forward.zig +++ b/src/feed_forward.zig @@ -15,13 +15,27 @@ pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint) !Self { const dim = checkpoint.dim; const hidden_dim = checkpoint.hidden_dim; + const input_buffer = try allocator.alloc(f32, dim); + + errdefer allocator.free(input_buffer); + + const hidden_buffer = try allocator.alloc(f32, hidden_dim); + + errdefer allocator.free(hidden_buffer); + + const residual_buffer = try allocator.alloc(f32, hidden_dim); + + errdefer allocator.free(residual_buffer); + + const output_buffer = try allocator.alloc(f32, dim); + return Self{ .allocator = allocator, .checkpoint = checkpoint, - .input_buffer = try allocator.alloc(f32, dim), - .hidden_buffer = try allocator.alloc(f32, hidden_dim), - .residual_buffer = try allocator.alloc(f32, hidden_dim), - .output_buffer = try allocator.alloc(f32, dim), + .input_buffer = input_buffer, + .hidden_buffer = hidden_buffer, + .residual_buffer = residual_buffer, + .output_buffer = output_buffer, }; } diff --git a/src/tokenizer.zig b/src/tokenizer.zig index 17a7f9c..b9f4865 100644 --- a/src/tokenizer.zig +++ b/src/tokenizer.zig @@ -10,8 +10,17 @@ sorted_vocab: []const VocabEntry, pub fn init(allocator: std.mem.Allocator, path: []const u8, vocab_size: usize) !Self { var vocab = try allocator.alloc([]u8, vocab_size); + + errdefer for (vocab) |word| { + allocator.free(word); + }; + + errdefer allocator.free(vocab); + var word_scores = try allocator.alloc(f32, vocab_size); + errdefer allocator.free(word_scores); + const file = try std.fs.cwd().openFile(path, .{}); defer file.close(); diff --git a/src/transformer.zig b/src/transformer.zig index 5757f83..70c691d 100644 --- a/src/transformer.zig +++ b/src/transformer.zig @@ -17,13 +17,29 @@ logits: []f32, pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self { const checkpoint = try Checkpoint.init(if (cli.mmap) null else allocator, cli.checkpoint_path); + errdefer checkpoint.deinit(); + + const attention = try Attention.init(allocator, checkpoint, cli.n_steps); + + errdefer attention.deinit(); + + const feed_forward = try FeedForward.init(allocator, checkpoint); + + errdefer feed_forward.deinit(); + + const hidden_state = try allocator.alloc(f32, checkpoint.dim); + + errdefer allocator.free(hidden_state); + + const logits = try allocator.alloc(f32, checkpoint.vocab_size); + return Self{ .allocator = allocator, .checkpoint = checkpoint, - .attention = try Attention.init(allocator, checkpoint, cli.n_steps), - .feed_forward = try FeedForward.init(allocator, checkpoint), - .hidden_state = try allocator.alloc(f32, checkpoint.dim), - .logits = try allocator.alloc(f32, checkpoint.vocab_size), + .attention = attention, + .feed_forward = feed_forward, + .hidden_state = hidden_state, + .logits = logits, }; }