Skip to content

Commit

Permalink
Add support for chat
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Aug 29, 2023
1 parent 720adb8 commit f1ce455
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 11 deletions.
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@
This is a Zig port of [llama2.c](https://github.com/karpathy/llama2.c).

The current code is based on:
https://github.com/karpathy/llama2.c/blob/c7a26264a233c32f396b1c67be4ac019d2d8a659/run.c
https://github.com/karpathy/llama2.c/blob/7325bab657406c427e7c1ca6575bace9a5982744/run.c

I have significantly diverged from the original in terms of architecture and implementation.
However, my goal is to continue porting the improvements and new features of Andrej's C version into
this codebase. At present, my Zig port produces the same output as the C version, which I verify
with the linked [tests](./test.sh).

## TODOs

- Add support for chat (https://github.com/karpathy/llama2.c/pull/343)

## Usage

```sh
Expand Down
157 changes: 157 additions & 0 deletions src/chat.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
const Self = @This();

const std = @import("std");
const lib = @import("lib.zig");
const Cli = @import("cli.zig");
const Sampler = @import("sampler.zig");
const Tokenizer = @import("tokenizer.zig");
const Transformer = @import("transformer.zig");

allocator: std.mem.Allocator,
transformer: Transformer, // model checkpoint + forward-pass state
tokenizer: Tokenizer, // text <-> token conversion
sampler: Sampler, // picks the next token from the logits; configured from CLI options
n_steps: usize, // maximum number of transformer steps for the whole chat session
user_prompt: []const u8, // initial user prompt from the CLI (may be empty)
system_prompt: []const u8, // system prompt from the CLI; empty means "prompt interactively"

/// Initializes the chat: loads the transformer checkpoint, then the tokenizer
/// and sampler (both sized by the checkpoint's vocabulary). On any failure,
/// everything acquired so far is released via the paired `errdefer`s.
/// Caller must release the result with `deinit`.
pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
    // The transformer must come first: it owns the checkpoint whose
    // vocabulary size the tokenizer and sampler both need.
    const transformer = try Transformer.init(allocator, cli);
    errdefer transformer.deinit();

    const tokenizer = try Tokenizer.init(
        allocator,
        cli.tokenizer_path,
        transformer.checkpoint.vocab_size,
    );
    errdefer tokenizer.deinit();

    const sampler = try Sampler.init(allocator, cli, transformer.checkpoint.vocab_size);
    errdefer sampler.deinit();

    return .{
        .allocator = allocator,
        .transformer = transformer,
        .tokenizer = tokenizer,
        .sampler = sampler,
        .n_steps = cli.n_steps,
        .user_prompt = cli.prompt,
        .system_prompt = cli.system_prompt,
    };
}

/// Releases the transformer, tokenizer, and sampler acquired in `init`.
pub fn deinit(self: *const Self) void {
    self.transformer.deinit();
    self.tokenizer.deinit();
    self.sampler.deinit();
}

// Llama-2 chat prompt delimiters: the user prompt is wrapped in
// [INST] ... [/INST]; an optional system prompt goes inside <<SYS>> ... <</SYS>>.
const user_prompt_template_start = "[INST] ";
const user_prompt_template_close = " [/INST]";
const system_prompt_template_start = "<<SYS>>\n";
const system_prompt_template_close = "\n<</SYS>>\n\n";

const bos_token = 1; // beginning of sequence
const eos_token = 2; // end of sequence (model signals end of its turn)

/// Runs an interactive chat loop on stdin/stdout for up to `self.n_steps`
/// transformer steps. Each user turn is rendered with the Llama-2 chat
/// template, force-fed token by token, and the assistant's sampled reply is
/// printed until an EOS token hands the turn back to the user.
pub fn start(self: *Self, allocator: std.mem.Allocator) !void {
    var stdin = std.io.getStdIn().reader();
    var stdout = std.io.getStdOut().writer();

    var token: usize = bos_token;
    var next_token: usize = 0; // 0 means "not forced": sample from the logits below
    var user_turn: bool = true;
    var user_prompt_tokens_index: usize = 0;

    // Encoded tokens of the current turn's rendered prompt; owned by `allocator`.
    var user_prompt_tokens: ?[]const usize = null;

    defer if (user_prompt_tokens) |prompt_tokens| {
        allocator.free(prompt_tokens);
    };

    for (0..self.n_steps) |pos| {
        try self.transformer.forward(token, pos);

        // FIX: gate on `user_turn` alone. `token` only equals `bos_token` on
        // the very first step, so the previous `token == bos_token and
        // user_turn` condition could never re-enter this branch after the
        // first EOS, dead-ending the chat after a single exchange. Upstream
        // run.c (karpathy/llama2.c PR #343) also keys the turn on the flag.
        if (user_turn) {
            var user_prompt = std.ArrayList(u8).init(allocator);

            defer user_prompt.deinit();

            try user_prompt.appendSlice(user_prompt_template_start);

            if (pos == 0) {
                // First turn only: include the system prompt, asking for one
                // interactively when the CLI did not provide it.
                if (self.system_prompt.len == 0) {
                    var system_prompt = std.ArrayList(u8).init(allocator);

                    defer system_prompt.deinit();

                    try stdout.print("Enter system prompt (optional): ", .{});
                    try stdin.streamUntilDelimiter(system_prompt.writer(), '\n', null);

                    if (system_prompt.items.len > 0) {
                        try user_prompt.appendSlice(system_prompt_template_start);
                        // FIX: append `items` directly. `toOwnedSlice()`
                        // allocated a copy that appendSlice copied again and
                        // nobody freed — a leak on every interactive system
                        // prompt.
                        try user_prompt.appendSlice(system_prompt.items);
                        try user_prompt.appendSlice(system_prompt_template_close);
                    }
                } else {
                    try user_prompt.appendSlice(system_prompt_template_start);
                    try user_prompt.appendSlice(self.system_prompt);
                    try user_prompt.appendSlice(system_prompt_template_close);
                }
            }

            if (pos == 0 and self.user_prompt.len > 0) {
                try user_prompt.appendSlice(self.user_prompt);
            } else {
                try stdout.print("User: ", .{});
                try stdin.streamUntilDelimiter(user_prompt.writer(), '\n', null);
            }

            try user_prompt.appendSlice(user_prompt_template_close);

            // Release the previous turn's tokens before encoding this turn.
            if (user_prompt_tokens) |prompt_tokens| {
                allocator.free(prompt_tokens);

                user_prompt_tokens = null;
            }

            user_turn = false;
            user_prompt_tokens_index = 0;
            user_prompt_tokens = try self.tokenizer.encode(allocator, user_prompt.items);

            try stdout.print("Assistant:", .{});
        }

        // While prompt tokens remain, force-feed them instead of sampling.
        if (user_prompt_tokens) |prompt_tokens| {
            if (user_prompt_tokens_index < prompt_tokens.len) {
                next_token = prompt_tokens[user_prompt_tokens_index];
            }
        }

        user_prompt_tokens_index += 1;

        if (next_token == 0) {
            next_token = self.sampler.sample(self.transformer.logits);
        }

        if (next_token == eos_token) {
            // Model finished its reply; hand the turn back to the user.
            user_turn = true;

            try stdout.print("\n", .{});
        } else if (user_prompt_tokens) |prompt_tokens| {
            // Print only tokens generated past the prompt; the first one
            // (index == len + 1) tells decode() to strip a leading space.
            if (next_token > 2 and user_prompt_tokens_index > prompt_tokens.len) {
                const word = self.tokenizer.decode(
                    next_token,
                    user_prompt_tokens_index == prompt_tokens.len + 1,
                );

                try lib.print(word, stdout);
            }
        }

        token = next_token;
        next_token = 0;
    }

    try stdout.print("\n", .{});
}
31 changes: 30 additions & 1 deletion src/cli.zig
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,22 @@ random_seed: u64,
n_steps: usize,
prompt: []const u8,
tokenizer_path: []const u8,
chat: bool,
system_prompt: []const u8,
mmap: bool,
timer: bool,
arg_iterator: std.process.ArgIterator,

const Option = enum { temperature, top_p, random_seed, n_steps, prompt, tokenizer_path };
// Value-taking CLI options; when one of these flags is seen, the next
// argument is consumed as its value (see `init`).
const Option = enum {
    temperature,
    top_p,
    random_seed,
    n_steps,
    prompt,
    tokenizer_path,
    mode, // -m: "generate" (default) or "chat"
    system_prompt, // -y: system prompt used in chat mode
};

pub fn init(allocator: std.mem.Allocator) !Self {
var current_option: ?Option = null;
Expand All @@ -23,6 +34,8 @@ pub fn init(allocator: std.mem.Allocator) !Self {
var n_steps: ?usize = null;
var prompt: ?[]const u8 = null;
var tokenizer_path: ?[]const u8 = null;
var mode: ?[]const u8 = null;
var system_prompt: ?[]const u8 = null;
var mmap: bool = true;
var timer: bool = true;

Expand All @@ -48,6 +61,14 @@ pub fn init(allocator: std.mem.Allocator) !Self {
prompt = arg;
} else if (option == .tokenizer_path and tokenizer_path == null) {
tokenizer_path = arg;
} else if (option == .mode and mode == null) {
if (std.mem.eql(u8, arg, "generate") or std.mem.eql(u8, arg, "chat")) {
mode = arg;
} else {
try exit();
}
} else if (option == .system_prompt and system_prompt == null) {
system_prompt = arg;
} else {
try exit();
}
Expand All @@ -65,6 +86,10 @@ pub fn init(allocator: std.mem.Allocator) !Self {
current_option = .prompt;
} else if (std.mem.eql(u8, arg, "-z")) {
current_option = .tokenizer_path;
} else if (std.mem.eql(u8, arg, "-m")) {
current_option = .mode;
} else if (std.mem.eql(u8, arg, "-y")) {
current_option = .system_prompt;
} else if (std.mem.eql(u8, arg, "--no-mmap") and mmap) {
mmap = false;
} else if (std.mem.eql(u8, arg, "--no-timer") and timer) {
Expand All @@ -86,6 +111,8 @@ pub fn init(allocator: std.mem.Allocator) !Self {
.n_steps = @max(n_steps orelse 256, 1),
.prompt = prompt orelse "",
.tokenizer_path = tokenizer_path orelse "tokenizer.bin",
.chat = if (mode) |arg| std.mem.eql(u8, arg, "chat") else false,
.system_prompt = system_prompt orelse "",
.mmap = mmap,
.timer = timer,
.arg_iterator = arg_iterator,
Expand All @@ -108,6 +135,8 @@ fn exit() !noreturn {
try stderr.print(" -n <int> n_steps = 256\n", .{});
try stderr.print(" -i <string> prompt = \"\"\n", .{});
try stderr.print(" -z <string> tokenizer_path = \"tokenizer.bin\"\n", .{});
try stderr.print(" -m <string> mode = \"generate\"; (alt. \"chat\")\n", .{});
try stderr.print(" -y <string> system_prompt = \"\"\n", .{});
try stderr.print(" --no-mmap\n", .{});
try stderr.print(" --no-timer\n\n", .{});

Expand Down
4 changes: 3 additions & 1 deletion src/generator.zig
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ const eos_token = 2; // end of sequence

pub fn generate(self: *Self, writer: anytype) !void {
var token: usize = bos_token;
var next_token: usize = eos_token;
var next_token: usize = 0;
var prompt_tokens_index: usize = 0;
var n_timed_steps: usize = 0;
var start_time: i64 = 0;
Expand Down Expand Up @@ -117,6 +117,8 @@ test "generate tiny story" {
.n_steps = 10,
.prompt = "There was",
.tokenizer_path = "tok512.bin",
.chat = false,
.system_prompt = "",
.mmap = false,
.timer = false,
.arg_iterator = arg_iterator,
Expand Down
15 changes: 12 additions & 3 deletions src/main.zig
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
const std = @import("std");
const Chat = @import("chat.zig");
const Cli = @import("cli.zig");
const Generator = @import("generator.zig");

Expand All @@ -7,11 +8,19 @@ pub fn main() !void {

defer cli.deinit();

var generator = try Generator.init(std.heap.page_allocator, &cli);
if (cli.chat) {
var chat = try Chat.init(std.heap.page_allocator, &cli);

defer generator.deinit();
defer chat.deinit();

try generator.generate(std.io.getStdOut().writer());
try chat.start(std.heap.page_allocator);
} else {
var generator = try Generator.init(std.heap.page_allocator, &cli);

defer generator.deinit();

try generator.generate(std.io.getStdOut().writer());
}
}

test {
Expand Down
2 changes: 1 addition & 1 deletion src/tokenizer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub fn decode(self: *const Self, token: usize, bos: bool) []const u8 {
const word = self.vocab[token];

// https://github.com/karpathy/llama2.c/blob/7ac65cb2c2b169050747be92011b7bebdd1b4544/run.c#L425
return if (bos and word[0] == ' ') word[1..] else word;
return if (bos and std.ascii.isWhitespace(word[0])) word[1..] else word;
}

fn encodeCodepoints(self: *const Self, allocator: std.mem.Allocator, text: []const u8) ![]usize {
Expand Down

0 comments on commit f1ce455

Please sign in to comment.