Skip to content

Commit

Permalink
Rename --n_steps to --sequence_length
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Oct 19, 2023
1 parent 389afbe commit 6406661
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 49 deletions.
22 changes: 11 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ zig build -Doptimize=ReleaseFast run-generator -- models/llama2_7b_hf --prompt "
Usage: llama2-generator <model_path> [options]
Options:
--temperature <float> = 1.0
--top_p <float> = 0.9
--random_seed <int> = <milli_timestamp>
--n_steps <int> = <max_sequence_length>
--prompt <string> = ""
--temperature <float> = 1.0
--top_p <float> = 0.9
--random_seed <int> = <milli_timestamp>
--sequence_length <int> = <max_sequence_length>
--prompt <string> = ""
--verbose
--help
```
Expand All @@ -53,12 +53,12 @@ Options:
Usage: llama2-chat <model_path> [options]
Options:
--temperature <float> = 1.0
--top_p <float> = 0.9
--random_seed <int> = <milli_timestamp>
--n_steps <int> = <max_sequence_length>
--system_prompt <string> = ""
--user_prompt <string> = ""
--temperature <float> = 1.0
--top_p <float> = 0.9
--random_seed <int> = <milli_timestamp>
--sequence_length <int> = <max_sequence_length>
--system_prompt <string> = ""
--user_prompt <string> = ""
--help
```

Expand Down
2 changes: 1 addition & 1 deletion src/chat.zig
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ system_prompt: []const u8,
user_prompt: []const u8,

pub fn init(allocator: std.mem.Allocator, args: ChatArgs) !Self {
const transformer = try Transformer.init(allocator, args.model_path, args.n_steps);
const transformer = try Transformer.init(allocator, args.model_path, args.sequence_length);

errdefer transformer.deinit();

Expand Down
35 changes: 21 additions & 14 deletions src/chat_args.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@ model_path: []const u8,
temperature: f32,
top_p: f32,
random_seed: u64,
n_steps: usize,
sequence_length: usize,
system_prompt: []const u8,
user_prompt: []const u8,

const Option = enum { temperature, top_p, random_seed, n_steps, system_prompt, user_prompt };
/// Tags for the command-line options that carry a value (`--<name> <value>`).
/// `init` sets the pending option to the matching tag when it sees the flag
/// and later stores the parsed value into the corresponding local.
/// `--help` takes no value and therefore has no tag here.
/// NOTE(review): the full argument loop is outside the visible hunks; confirm
/// that the value is always taken from the argument right after the flag.
const Option = enum {
temperature,
top_p,
random_seed,
sequence_length,
system_prompt,
user_prompt,
};

pub fn init(allocator: std.mem.Allocator) !Self {
var arg_iterator = try std.process.argsWithAllocator(allocator);
Expand All @@ -26,7 +33,7 @@ pub fn init(allocator: std.mem.Allocator) !Self {
var temperature: ?f32 = null;
var top_p: ?f32 = null;
var random_seed: ?u64 = null;
var n_steps: ?usize = null;
var sequence_length: ?usize = null;
var system_prompt: ?[]const u8 = null;
var user_prompt: ?[]const u8 = null;

Expand All @@ -38,8 +45,8 @@ pub fn init(allocator: std.mem.Allocator) !Self {
top_p = try std.fmt.parseFloat(f32, arg);
} else if (option == .random_seed and random_seed == null) {
random_seed = try std.fmt.parseInt(u64, arg, 10);
} else if (option == .n_steps and n_steps == null) {
n_steps = try std.fmt.parseInt(usize, arg, 10);
} else if (option == .sequence_length and sequence_length == null) {
sequence_length = try std.fmt.parseInt(usize, arg, 10);
} else if (option == .system_prompt and system_prompt == null) {
system_prompt = arg;
} else if (option == .user_prompt and user_prompt == null) {
Expand All @@ -55,8 +62,8 @@ pub fn init(allocator: std.mem.Allocator) !Self {
current_option = .top_p;
} else if (std.mem.eql(u8, arg, "--random_seed")) {
current_option = .random_seed;
} else if (std.mem.eql(u8, arg, "--n_steps")) {
current_option = .n_steps;
} else if (std.mem.eql(u8, arg, "--sequence_length")) {
current_option = .sequence_length;
} else if (std.mem.eql(u8, arg, "--system_prompt")) {
current_option = .system_prompt;
} else if (std.mem.eql(u8, arg, "--user_prompt")) {
Expand All @@ -76,7 +83,7 @@ pub fn init(allocator: std.mem.Allocator) !Self {
.temperature = @max(@min(temperature orelse 1, 1), 0),
.top_p = @max(@min(top_p orelse 0.9, 1), 0),
.random_seed = random_seed orelse @intCast(std.time.milliTimestamp()),
.n_steps = n_steps orelse 0,
.sequence_length = sequence_length orelse 0,
.system_prompt = system_prompt orelse "",
.user_prompt = user_prompt orelse "",
};
Expand All @@ -95,12 +102,12 @@ fn help(exit_status: u8) !noreturn {
try console.print("Usage: llama2-chat <model_path> [options]\n\n", .{});

try console.print("Options:\n", .{});
try console.print(" --temperature <float> = 1.0\n", .{});
try console.print(" --top_p <float> = 0.9\n", .{});
try console.print(" --random_seed <int> = <milli_timestamp>\n", .{});
try console.print(" --n_steps <int> = <max_sequence_length>\n", .{});
try console.print(" --system_prompt <string> = \"\"\n", .{});
try console.print(" --user_prompt <string> = \"\"\n", .{});
try console.print(" --temperature <float> = 1.0\n", .{});
try console.print(" --top_p <float> = 0.9\n", .{});
try console.print(" --random_seed <int> = <milli_timestamp>\n", .{});
try console.print(" --sequence_length <int> = <max_sequence_length>\n", .{});
try console.print(" --system_prompt <string> = \"\"\n", .{});
try console.print(" --user_prompt <string> = \"\"\n", .{});
try console.print(" --help\n", .{});

std.process.exit(exit_status);
Expand Down
2 changes: 0 additions & 2 deletions src/converter_args.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ const std = @import("std");
arg_iterator: std.process.ArgIterator,
model_path: []const u8,

const Option = enum { temperature, top_p, random_seed, n_steps, prompt };

pub fn init(allocator: std.mem.Allocator) !Self {
var arg_iterator = try std.process.argsWithAllocator(allocator);

Expand Down
4 changes: 2 additions & 2 deletions src/generator.zig
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ prompt_tokens: []usize,
verbose: bool,

pub fn init(allocator: std.mem.Allocator, args: GeneratorArgs) !Self {
const transformer = try Transformer.init(allocator, args.model_path, args.n_steps);
const transformer = try Transformer.init(allocator, args.model_path, args.sequence_length);

errdefer transformer.deinit();

Expand Down Expand Up @@ -113,7 +113,7 @@ test "generate tiny story" {
.temperature = 1,
.top_p = 0.9,
.random_seed = 42,
.n_steps = 10,
.sequence_length = 10,
.prompt = "There was",
.verbose = false,
};
Expand Down
26 changes: 13 additions & 13 deletions src/generator_args.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ model_path: []const u8,
temperature: f32,
top_p: f32,
random_seed: u64,
n_steps: usize,
sequence_length: usize,
prompt: []const u8,
verbose: bool,

const Option = enum { temperature, top_p, random_seed, n_steps, prompt };
/// Tags for the generator options that carry a value (`--<name> <value>`).
/// `--verbose` is tracked as a plain boolean in `init` and `--help` takes no
/// value, so neither has a tag here.
const Option = enum { temperature, top_p, random_seed, sequence_length, prompt };

pub fn init(allocator: std.mem.Allocator) !Self {
var arg_iterator = try std.process.argsWithAllocator(allocator);
Expand All @@ -26,7 +26,7 @@ pub fn init(allocator: std.mem.Allocator) !Self {
var temperature: ?f32 = null;
var top_p: ?f32 = null;
var random_seed: ?u64 = null;
var n_steps: ?usize = null;
var sequence_length: ?usize = null;
var prompt: ?[]const u8 = null;
var verbose: bool = false;

Expand All @@ -38,8 +38,8 @@ pub fn init(allocator: std.mem.Allocator) !Self {
top_p = try std.fmt.parseFloat(f32, arg);
} else if (option == .random_seed and random_seed == null) {
random_seed = try std.fmt.parseInt(u64, arg, 10);
} else if (option == .n_steps and n_steps == null) {
n_steps = try std.fmt.parseInt(usize, arg, 10);
} else if (option == .sequence_length and sequence_length == null) {
sequence_length = try std.fmt.parseInt(usize, arg, 10);
} else if (option == .prompt and prompt == null) {
prompt = arg;
} else {
Expand All @@ -53,8 +53,8 @@ pub fn init(allocator: std.mem.Allocator) !Self {
current_option = .top_p;
} else if (std.mem.eql(u8, arg, "--random_seed")) {
current_option = .random_seed;
} else if (std.mem.eql(u8, arg, "--n_steps")) {
current_option = .n_steps;
} else if (std.mem.eql(u8, arg, "--sequence_length")) {
current_option = .sequence_length;
} else if (std.mem.eql(u8, arg, "--prompt")) {
current_option = .prompt;
} else if (std.mem.eql(u8, arg, "--verbose") and !verbose) {
Expand All @@ -74,7 +74,7 @@ pub fn init(allocator: std.mem.Allocator) !Self {
.temperature = @max(@min(temperature orelse 1, 1), 0),
.top_p = @max(@min(top_p orelse 0.9, 1), 0),
.random_seed = random_seed orelse @intCast(std.time.milliTimestamp()),
.n_steps = n_steps orelse 0,
.sequence_length = sequence_length orelse 0,
.prompt = prompt orelse "",
.verbose = verbose,
};
Expand All @@ -93,11 +93,11 @@ fn help(exit_status: u8) !noreturn {
try console.print("Usage: llama2-generator <model_path> [options]\n\n", .{});

try console.print("Options:\n", .{});
try console.print(" --temperature <float> = 1.0\n", .{});
try console.print(" --top_p <float> = 0.9\n", .{});
try console.print(" --random_seed <int> = <milli_timestamp>\n", .{});
try console.print(" --n_steps <int> = <max_sequence_length>\n", .{});
try console.print(" --prompt <string> = \"\"\n", .{});
try console.print(" --temperature <float> = 1.0\n", .{});
try console.print(" --top_p <float> = 0.9\n", .{});
try console.print(" --random_seed <int> = <milli_timestamp>\n", .{});
try console.print(" --sequence_length <int> = <max_sequence_length>\n", .{});
try console.print(" --prompt <string> = \"\"\n", .{});
try console.print(" --verbose\n", .{});
try console.print(" --help\n", .{});

Expand Down
12 changes: 10 additions & 2 deletions src/transformer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,20 @@ ffn: FFN,
hidden_buffer: Tensor(1),
output_buffer: Tensor(1),

pub fn init(allocator: std.mem.Allocator, model_path: []const u8, n_steps: usize) !Self {
pub fn init(
allocator: std.mem.Allocator,
model_path: []const u8,
custom_sequence_length: usize,
) !Self {
const checkpoint = try Checkpoint.init(allocator, model_path);

errdefer checkpoint.deinit();

const sequence_length = if (n_steps == 0) checkpoint.max_sequence_length else n_steps;
const sequence_length = if (custom_sequence_length == 0)
checkpoint.max_sequence_length
else
custom_sequence_length;

const attention = try Attention.init(allocator, checkpoint, sequence_length);

errdefer attention.deinit();
Expand Down
8 changes: 4 additions & 4 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ zig build

model_path="models/tinystories_260k"

actual_output=$(./zig-out/bin/llama2-generator $model_path --temperature 0 --n_steps 200)
actual_output=$(./zig-out/bin/llama2-generator $model_path --temperature 0 --sequence_length 200)

# Generated with llama2.c (https://github.com/karpathy/llama2.c/tree/7ac65cb2c2b169050747be92011b7bebdd1b4544)
expected_output="Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big, red ball. She wanted to play with it, but it was too high.
Expand All @@ -22,7 +22,7 @@ if [ "$actual_output" != "$expected_output" ]; then
exit 1
fi

actual_output=$(./zig-out/bin/llama2-generator $model_path --top_p 1 --random_seed 42 --n_steps 200)
actual_output=$(./zig-out/bin/llama2-generator $model_path --top_p 1 --random_seed 42 --sequence_length 200)

# Generated with llama2.c (https://github.com/karpathy/llama2.c/tree/7ac65cb2c2b169050747be92011b7bebdd1b4544)
expected_output="Once upon a time, there was a big roof. The fox was ready to look for people inside. He saw a big rock near a big tree. The roof was very small and fun! He ate the roof too. He got a shiny stool, so he sicked the roof with his friend, the girl named Mia.
Expand All @@ -35,7 +35,7 @@ if [ "$actual_output" != "$expected_output" ]; then
exit 1
fi

actual_output=$(./zig-out/bin/llama2-generator $model_path --top_p 0.95 --random_seed 42 --n_steps 200)
actual_output=$(./zig-out/bin/llama2-generator $model_path --top_p 0.95 --random_seed 42 --sequence_length 200)

# Generated with llama2.c (https://github.com/karpathy/llama2.c/tree/7ac65cb2c2b169050747be92011b7bebdd1b4544)
expected_output="Once upon a time, there was a little boy named Timmy. Timmy loved going to the park with his mom. One day, Lily went outside to play outside in her pocket. He was scared and didn't know where to buy some colorful animals.
Expand All @@ -46,7 +46,7 @@ if [ "$actual_output" != "$expected_output" ]; then
exit 1
fi

actual_output=$(./zig-out/bin/llama2-generator $model_path --top_p 0.95 --random_seed 42 --n_steps 200 --prompt "There was a big")
actual_output=$(./zig-out/bin/llama2-generator $model_path --top_p 0.95 --random_seed 42 --sequence_length 200 --prompt "There was a big")

# Generated with llama2.c (https://github.com/karpathy/llama2.c/tree/7ac65cb2c2b169050747be92011b7bebdd1b4544)
expected_output="There was a big pretty grass. It was a long elephant. The cars wanted to tell him that as they spin before the amazing doll, just like it she was always okay.
Expand Down

0 comments on commit 6406661

Please sign in to comment.