Skip to content

Commit

Permalink
Introduce sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Aug 26, 2023
1 parent c9a7980 commit d5136e3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/lib/sample_multinomial.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ const std = @import("std");

const tolerance: comptime_float = std.math.sqrt(std.math.floatEps(f32));

pub fn sampleMultinomial(probability_threshold: f32, probability_distribution: []f32) usize {
pub fn sampleMultinomial(probability_threshold: f32, probability_distribution: []const f32) usize {
std.debug.assert(probability_distribution.len > 0);

var cumulative_probability: f32 = 0;
Expand Down
34 changes: 8 additions & 26 deletions src/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ const std = @import("std");
const lib = @import("lib.zig");
const Checkpoint = @import("checkpoint.zig");
const Cli = @import("cli.zig");
const Sampler = @import("sampler.zig");
const Tokenizer = @import("tokenizer.zig");
const Transformer = @import("transformer.zig");

Expand All @@ -24,6 +25,10 @@ fn generate(allocator: std.mem.Allocator, cli: *const Cli, writer: anytype) !voi

defer checkpoint.deinit();

var sampler = try Sampler.init(allocator, cli, checkpoint.vocab_size);

defer sampler.deinit();

const tokenizer = try Tokenizer.init(allocator, cli.tokenizer_path, checkpoint.vocab_size);

defer tokenizer.deinit();
Expand All @@ -32,26 +37,20 @@ fn generate(allocator: std.mem.Allocator, cli: *const Cli, writer: anytype) !voi

defer allocator.free(prompt_tokens);

std.debug.assert(prompt_tokens.len > 0);

const transformer = try Transformer.init(allocator, &checkpoint, cli.n_steps);

defer transformer.deinit();

std.debug.assert(prompt_tokens.len > 0);

var prompt_tokens_offset: usize = 0;
var current_token: usize = prompt_tokens[prompt_tokens_offset];

prompt_tokens_offset += 1;

var probability_index_pairs_buffer: []lib.ProbabilityIndexPair =
try allocator.alloc(lib.ProbabilityIndexPair, checkpoint.vocab_size);

defer allocator.free(probability_index_pairs_buffer);

var start_time: i64 = 0;
var total_time: i64 = 0;
var next_token: usize = 1;
var rng_state = cli.random_seed;
var n_steps: usize = 0;

for (0..cli.n_steps) |pos| {
Expand All @@ -68,25 +67,8 @@ fn generate(allocator: std.mem.Allocator, cli: *const Cli, writer: anytype) !voi
if (prompt_tokens_offset < prompt_tokens.len) {
next_token = prompt_tokens[prompt_tokens_offset];
prompt_tokens_offset += 1;
} else if (cli.temperature == 0) {
next_token = lib.argmax(transformer.logits);
} else {
for (transformer.logits) |*logit| {
logit.* /= cli.temperature;
}

lib.softmax(transformer.logits);

if (cli.top_p <= 0 or cli.top_p >= 1) {
next_token = lib.sampleMultinomial(lib.random(&rng_state), transformer.logits);
} else {
next_token = lib.sampleNucleus(
lib.random(&rng_state),
transformer.logits,
cli.top_p,
probability_index_pairs_buffer,
);
}
next_token = sampler.sample(transformer.logits);
}

n_steps += 1;
Expand Down
49 changes: 49 additions & 0 deletions src/sampler.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
const Self = @This();

const std = @import("std");
const lib = @import("lib.zig");
const Checkpoint = @import("checkpoint.zig");
const Cli = @import("cli.zig");

allocator: std.mem.Allocator,
probability_index_pairs_buffer: []lib.ProbabilityIndexPair,
temperature: f32,
top_p: f32,
rng_state: u64,

pub fn init(allocator: std.mem.Allocator, cli: *const Cli, vocab_size: usize) !Self {
return Self{
.allocator = allocator,
.probability_index_pairs_buffer = try allocator.alloc(lib.ProbabilityIndexPair, vocab_size),
.temperature = cli.temperature,
.top_p = cli.top_p,
.rng_state = cli.random_seed,
};
}

pub fn deinit(self: *const Self) void {
defer self.allocator.free(self.probability_index_pairs_buffer);
}

pub fn sample(self: *Self, probability_distribution: []f32) usize {
if (self.temperature == 0) {
return lib.argmax(probability_distribution);
}

for (probability_distribution) |*probability| {
probability.* /= self.temperature;
}

lib.softmax(probability_distribution);

if (self.top_p <= 0 or self.top_p >= 1) {
return lib.sampleMultinomial(lib.random(&self.rng_state), probability_distribution);
}

return lib.sampleNucleus(
lib.random(&self.rng_state),
probability_distribution,
self.top_p,
self.probability_index_pairs_buffer,
);
}

0 comments on commit d5136e3

Please sign in to comment.