Skip to content

Commit

Permalink
Improve memory freeing in error case
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Aug 27, 2023
1 parent e7542cd commit 3950c0d
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 17 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ through the following linked [tests](./test.sh).
## TODOs

- Add support for chat (https://github.com/karpathy/llama2.c/pull/343)
- Use `errdefer` in all init methods (see generator)

## Usage

Expand Down
46 changes: 38 additions & 8 deletions src/attention.zig
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,48 @@ pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint, seq_len: usize
const kv_dim = checkpoint.kv_dim;
const kv_cache_dim = checkpoint.n_layers * seq_len * kv_dim;

const input_buffer = try allocator.alloc(f32, dim);

errdefer allocator.free(input_buffer);

const output_buffer = try allocator.alloc(f32, dim);

errdefer allocator.free(output_buffer);

const scores_buffer = try allocator.alloc(f32, checkpoint.n_heads * seq_len);

errdefer allocator.free(scores_buffer);

const queries_buffer = try allocator.alloc(f32, dim);

errdefer allocator.free(queries_buffer);

const keys_buffer = try allocator.alloc(f32, kv_dim);

errdefer allocator.free(keys_buffer);

const values_buffer = try allocator.alloc(f32, kv_dim);

errdefer allocator.free(values_buffer);

const key_cache = try allocator.alloc(f32, kv_cache_dim);

errdefer allocator.free(key_cache);

const value_cache = try allocator.alloc(f32, kv_cache_dim);

return Self{
.allocator = allocator,
.checkpoint = checkpoint,
.seq_len = seq_len,
.input_buffer = try allocator.alloc(f32, dim),
.output_buffer = try allocator.alloc(f32, dim),
.scores_buffer = try allocator.alloc(f32, checkpoint.n_heads * seq_len),
.queries_buffer = try allocator.alloc(f32, dim),
.keys_buffer = try allocator.alloc(f32, kv_dim),
.values_buffer = try allocator.alloc(f32, kv_dim),
.key_cache = try allocator.alloc(f32, kv_cache_dim),
.value_cache = try allocator.alloc(f32, kv_cache_dim),
.input_buffer = input_buffer,
.output_buffer = output_buffer,
.scores_buffer = scores_buffer,
.queries_buffer = queries_buffer,
.keys_buffer = keys_buffer,
.values_buffer = values_buffer,
.key_cache = key_cache,
.value_cache = value_cache,
};
}

Expand Down
3 changes: 3 additions & 0 deletions src/cli.zig
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ pub fn init(allocator: std.mem.Allocator) !Self {
var tokenizer_path: ?[]const u8 = null;
var mmap: bool = true;
var timer: bool = true;

var arg_iterator = try std.process.argsWithAllocator(allocator);

errdefer arg_iterator.deinit();

_ = arg_iterator.next().?;

const checkpoint_path = arg_iterator.next() orelse try exit();
Expand Down
22 changes: 18 additions & 4 deletions src/feed_forward.zig
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,27 @@ pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint) !Self {
// Allocates the four f32 work buffers and returns them in a `Self` together
// with the allocator and the checkpoint. Any allocation error is propagated
// to the caller via `try`.
//
// NOTE(review): this listing appears to come from a diff view whose +/-
// markers were lost. In the struct literal below, the `try allocator.alloc`
// initializers look like the removed lines and the plain-identifier
// initializers like the added ones — confirm against the repository.
const dim = checkpoint.dim;
const hidden_dim = checkpoint.hidden_dim;

const input_buffer = try allocator.alloc(f32, dim);

// Each errdefer releases the buffer allocated just above it if a later step
// in this function returns an error; on success nothing is freed here.
errdefer allocator.free(input_buffer);

const hidden_buffer = try allocator.alloc(f32, hidden_dim);

errdefer allocator.free(hidden_buffer);

const residual_buffer = try allocator.alloc(f32, hidden_dim);

errdefer allocator.free(residual_buffer);

// No errdefer is registered for this last buffer.
const output_buffer = try allocator.alloc(f32, dim);

return Self{
.allocator = allocator,
.checkpoint = checkpoint,
.input_buffer = try allocator.alloc(f32, dim),
.hidden_buffer = try allocator.alloc(f32, hidden_dim),
.residual_buffer = try allocator.alloc(f32, hidden_dim),
.output_buffer = try allocator.alloc(f32, dim),
.input_buffer = input_buffer,
.hidden_buffer = hidden_buffer,
.residual_buffer = residual_buffer,
.output_buffer = output_buffer,
};
}

Expand Down
9 changes: 9 additions & 0 deletions src/tokenizer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,17 @@ sorted_vocab: []const VocabEntry,

pub fn init(allocator: std.mem.Allocator, path: []const u8, vocab_size: usize) !Self {
var vocab = try allocator.alloc([]u8, vocab_size);

errdefer for (vocab) |word| {
allocator.free(word);
};

errdefer allocator.free(vocab);

var word_scores = try allocator.alloc(f32, vocab_size);

errdefer allocator.free(word_scores);

const file = try std.fs.cwd().openFile(path, .{});

defer file.close();
Expand Down
24 changes: 20 additions & 4 deletions src/transformer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,29 @@ logits: []f32,
/// Builds a transformer from the CLI settings: loads the checkpoint, creates
/// the attention and feed-forward components, and allocates the
/// `hidden_state` and `logits` buffers.
///
/// Errors from `Checkpoint.init`, `Attention.init`, `FeedForward.init`, and
/// `allocator.alloc` are propagated to the caller.
///
/// NOTE(review): this listing appears to come from a diff view whose +/-
/// markers were lost. In the struct literal, the `try ...` initializers look
/// like the removed lines and the plain-identifier initializers like the added
/// ones — confirm against the repository.
pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
// The checkpoint is given `null` instead of the allocator when `cli.mmap` is set.
const checkpoint = try Checkpoint.init(if (cli.mmap) null else allocator, cli.checkpoint_path);

// Each errdefer below undoes the step just above it if a later step in this
// function returns an error; on success nothing is released here.
errdefer checkpoint.deinit();

// `cli.n_steps` is passed as the third argument of `Attention.init`.
const attention = try Attention.init(allocator, checkpoint, cli.n_steps);

errdefer attention.deinit();

const feed_forward = try FeedForward.init(allocator, checkpoint);

errdefer feed_forward.deinit();

const hidden_state = try allocator.alloc(f32, checkpoint.dim);

errdefer allocator.free(hidden_state);

// No errdefer is registered for this last buffer.
const logits = try allocator.alloc(f32, checkpoint.vocab_size);

return Self{
.allocator = allocator,
.checkpoint = checkpoint,
.attention = try Attention.init(allocator, checkpoint, cli.n_steps),
.feed_forward = try FeedForward.init(allocator, checkpoint),
.hidden_state = try allocator.alloc(f32, checkpoint.dim),
.logits = try allocator.alloc(f32, checkpoint.vocab_size),
.attention = attention,
.feed_forward = feed_forward,
.hidden_state = hidden_state,
.logits = logits,
};
}

Expand Down

0 comments on commit 3950c0d

Please sign in to comment.