diff --git a/src/checkpoint.zig b/src/checkpoint.zig index 8115344..9b4b8f8 100644 --- a/src/checkpoint.zig +++ b/src/checkpoint.zig @@ -30,21 +30,23 @@ weights: struct { }, pub fn init(allocator: std.mem.Allocator, model_path: []const u8) !Self { - return try readLegacy(allocator, model_path); -} - -pub fn readLegacy(allocator: std.mem.Allocator, model_path: []const u8) !Self { - const path = try std.fs.path.join( + const legacy_path = try std.fs.path.join( allocator, &[_][]const u8{ model_path, "checkpoint_legacy.bin" }, ); - defer allocator.free(path); + defer allocator.free(legacy_path); + + const legacy_file = std.fs.cwd().openFile(legacy_path, .{}) catch null; - const file = try std.fs.cwd().openFile(path, .{}); + defer if (legacy_file) |file| file.close(); - defer file.close(); + if (legacy_file) |file| return try readLegacy(allocator, file); + + return error.CheckpointFileNotFound; +} +pub fn readLegacy(allocator: std.mem.Allocator, file: std.fs.File) !Self { const embedding_size: usize = @intCast(try file.reader().readIntLittle(i32)); const ffn_hidden_size: usize = @intCast(try file.reader().readIntLittle(i32)); const n_layers: usize = @intCast(try file.reader().readIntLittle(i32));