Skip to content

Commit

Permalink
Prepare support for more checkpoint file formats
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Oct 18, 2023
1 parent 5c60ca7 commit 5c2cd57
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/checkpoint.zig
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,23 @@ weights: struct {
},

pub fn init(allocator: std.mem.Allocator, model_path: []const u8) !Self {
return try readLegacy(allocator, model_path);
}

pub fn readLegacy(allocator: std.mem.Allocator, model_path: []const u8) !Self {
const path = try std.fs.path.join(
const legacy_path = try std.fs.path.join(
allocator,
&[_][]const u8{ model_path, "checkpoint_legacy.bin" },
);

defer allocator.free(path);
defer allocator.free(legacy_path);

const legacy_file = std.fs.cwd().openFile(legacy_path, .{}) catch null;

const file = try std.fs.cwd().openFile(path, .{});
defer if (legacy_file) |file| file.close();

defer file.close();
if (legacy_file) |file| return try readLegacy(allocator, file);

return error.CheckpointFileNotFound;
}

pub fn readLegacy(allocator: std.mem.Allocator, file: std.fs.File) !Self {
const embedding_size: usize = @intCast(try file.reader().readIntLittle(i32));
const ffn_hidden_size: usize = @intCast(try file.reader().readIntLittle(i32));
const n_layers: usize = @intCast(try file.reader().readIntLittle(i32));
Expand Down

0 comments on commit 5c2cd57

Please sign in to comment.