always train in bulk
a10y committed Aug 23, 2024
1 parent f1702e7 commit 9f4448b
Showing 9 changed files with 102 additions and 207 deletions.
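In short, the commit collapses the old single-buffer `Compressor::train(impl AsRef<[u8]>)` and the separate `Compressor::train_bulk` into one bulk-only `Compressor::train(&Vec<&[u8]>)`. A minimal caller-side sketch of the new call shape, assuming the signatures shown in the diffs below (the inputs here are illustrative, not taken from the diff):

    use fsst::Compressor;

    fn main() {
        // Training now always takes a collection of byte slices.
        let lines: Vec<&[u8]> = vec!["hello world".as_bytes(), "hello again".as_bytes()];
        let compressor = Compressor::train(&lines);

        // Per-string and bulk compression, as used in the benchmarks below.
        let one = compressor.compress(lines[0]);
        let all = compressor.compress_bulk(&lines);
        println!("one line: {} bytes, bulk: {} outputs", one.len(), all.len());
    }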
105 changes: 28 additions & 77 deletions benches/compress.rs
@@ -63,21 +63,21 @@ fn bench_dbtext(c: &mut Criterion) {

let mut text = String::new();
let lines: Vec<&[u8]> = {
let mut file = File::open("benches/data/wikipedia").unwrap();
let mut file = File::open(path).unwrap();
file.read_to_string(&mut text).unwrap();

text.lines().map(|line| line.as_bytes()).collect()
};

group.bench_function("train-and-compress", |b| {
b.iter(|| {
let compressor = Compressor::train_bulk(&lines);
let compressor = Compressor::train(&lines);
let _ =
std::hint::black_box(compressor.compress_bulk(std::hint::black_box(&lines)));
});
});

let compressor = Compressor::train_bulk(&lines);
let compressor = Compressor::train(&lines);
group.throughput(Throughput::Bytes(
lines.iter().map(|l| l.len() as u64).sum::<u64>(),
));
@@ -89,6 +89,24 @@ fn bench_dbtext(c: &mut Criterion) {
});

group.finish();

// Report the compression factor for this dataset.
let uncompressed_size = lines.iter().map(|l| l.len()).sum::<usize>();
let compressor = Compressor::train(&lines);

// Show the symbols
for code in 256..compressor.symbol_table().len() {
let symbol = compressor.symbol_table()[code];
let code = code - 256;
println!("symbol[{code}] = {symbol:?}");
}

let compressed = compressor.compress_bulk(&lines);
let compressed_size = compressed.iter().map(|l| l.len()).sum::<usize>();
let ratio = 100.0 * (compressed_size as f64) / (uncompressed_size as f64);
println!(
"compressed {name} {uncompressed_size} => {compressed_size}B ({ratio}% of original)"
)
}

run_dataset_bench(
@@ -111,81 +129,14 @@ fn bench_dbtext(c: &mut Criterion) {
"benches/data/urls",
c,
);
}

fn bench_tpch_comments(c: &mut Criterion) {
let mut group = c.benchmark_group("tpch");

group.bench_function("train-only", |b| {
b.iter(|| {
let mut file = File::open("/Users/aduffy/code/cwi-fsst/build/comments").unwrap();
let mut text = String::new();
file.read_to_string(&mut text).unwrap();

let lines: Vec<&str> = text.lines().collect();
let lines_sliced: Vec<&[u8]> = lines.iter().map(|s| s.as_bytes()).collect();

let _ =
std::hint::black_box(Compressor::train_bulk(std::hint::black_box(&lines_sliced)));
// let _ = std::hint::black_box(compressor.compress_bulk(&lines_sliced));
});
});

let mut file = File::open("/Users/aduffy/code/cwi-fsst/build/comments").unwrap();
let mut text = String::new();
file.read_to_string(&mut text).unwrap();

let lines: Vec<&str> = text.lines().collect();
let lines_sliced: Vec<&[u8]> = lines.iter().map(|s| s.as_bytes()).collect();

let compressor = Compressor::train_bulk(&lines_sliced);

group.throughput(Throughput::Bytes(
lines.iter().map(|l| l.len() as u64).sum::<u64>(),
));
group.bench_function("compress-only", |b| {
b.iter(|| {
let _ = std::hint::black_box(compressor.compress_bulk(&lines_sliced));
});
});

group.bench_function("train-and-compress", |b| {
b.iter(|| {
let mut file = File::open("/Users/aduffy/code/cwi-fsst/build/comments").unwrap();
let mut text = String::new();
file.read_to_string(&mut text).unwrap();

let lines: Vec<&str> = text.lines().collect();
let lines_sliced: Vec<&[u8]> = lines.iter().map(|s| s.as_bytes()).collect();

let compressor = Compressor::train_bulk(&lines_sliced);
let _ = std::hint::black_box(compressor.compress_bulk(&lines_sliced));
});
});

group.finish();

let mut file = File::open("/Users/aduffy/code/cwi-fsst/build/comments").unwrap();
let mut text = String::new();
file.read_to_string(&mut text).unwrap();

let lines: Vec<&str> = text.lines().collect();
let lines_sliced: Vec<&[u8]> = lines.iter().map(|s| s.as_bytes()).collect();
let mut lines_total = Vec::new();
for slice in &lines_sliced {
lines_total.extend_from_slice(slice);
}

let compressor = Compressor::train_bulk(&lines_sliced);
let compressed = compressor.compress(&lines_total);

println!(
"compressed {} => {} ({}%)",
lines_total.len(),
compressed.len(),
100.0 * (compressed.len() as f64) / (lines_total.len() as f64),
)
run_dataset_bench(
"dbtext/urls",
"https://raw.githubusercontent.com/cwida/fsst/4e188a/paper/dbtext/urls",
"benches/data/urls",
c,
);
}

criterion_group!(compress_bench, bench_tpch_comments, bench_dbtext);
criterion_group!(compress_bench, bench_dbtext);
criterion_main!(compress_bench);
7 changes: 3 additions & 4 deletions examples/file_compressor.rs
@@ -37,12 +37,11 @@ fn main() {

let mut output = File::create(output_path).unwrap();

let compressor = Compressor::train_bulk(&lines);
let compressor = Compressor::train(&lines);
let mut compressed_size = 0;
for text in lines {
let compressed = compressor.compress(&text);
compressed_size += compressed.len();
output.write(&compressed).unwrap();
let compressed = compressor.compress(text);
compressed_size += output.write(&compressed).unwrap();
}

println!(
2 changes: 1 addition & 1 deletion examples/round_trip.rs
@@ -7,7 +7,7 @@ use fsst::Compressor;
fn main() {
// Train on a sample.
let sample = "the quick brown fox jumped over the lazy dog";
let trained = Compressor::train(sample.as_bytes());
let trained = Compressor::train(&vec![sample.as_bytes()]);
let compressed = trained.compress(sample.as_bytes());
println!("compressed: {} => {}", sample.len(), compressed.len());
// decompress now
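The hunk above stops at "// decompress now"; as a sketch of the decompression half of the round trip, using the `decompressor()` / `decompress()` calls that the fuzz target below relies on (this continuation is assumed, not part of the diff):

    // Sketch: decompress and verify we recover the original sample.
    let decompressed = trained.decompressor().decompress(&compressed);
    assert_eq!(decompressed, sample.as_bytes());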
2 changes: 1 addition & 1 deletion fuzz/fuzz_targets/fuzz_compress.rs
@@ -4,7 +4,7 @@ use libfuzzer_sys::fuzz_target;

fuzz_target!(|data: &[u8]| {
let compressor =
fsst::Compressor::train("the quick brown fox jumped over the lazy dog".as_bytes());
fsst::Compressor::train(&vec![b"the quick brown fox jumped over the lazy dog"]);
let compress = compressor.compress(data);
let decompress = compressor.decompressor().decompress(&compress);
assert_eq!(&decompress, data);
2 changes: 1 addition & 1 deletion fuzz/fuzz_targets/fuzz_train.rs
@@ -3,5 +3,5 @@
use libfuzzer_sys::fuzz_target;

fuzz_target!(|data: &[u8]| {
let _ = fsst::Compressor::train(data);
let _ = fsst::Compressor::train(&vec![data]);
});
115 changes: 47 additions & 68 deletions src/builder.rs
@@ -206,9 +206,9 @@ impl Counter {
///
/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
#[cfg(not(miri))]
const MAX_GENERATIONS: usize = 5;
const GENERATIONS: [usize; 5] = [8usize, 38, 68, 98, 128];
#[cfg(miri)]
const MAX_GENERATIONS: usize = 2;
const GENERATIONS: [usize; 3] = [8usize, 38, 128];

const FSST_SAMPLETARGET: usize = 1 << 14;
const FSST_SAMPLEMAX: usize = 1 << 15;
@@ -220,6 +220,7 @@ const FSST_SAMPLELINE: usize = 512;
/// returned slices are pointers into the `sample_buf`.
///
/// SAFETY: sample_buf must be >= FSST_SAMPLEMAX bytes long. Providing something less may cause unexpected failures.
#[allow(clippy::ptr_arg)]
fn make_sample<'a, 'b: 'a>(sample_buf: &'a mut Vec<u8>, str_in: &Vec<&'b [u8]>) -> Vec<&'a [u8]> {
debug_assert!(
sample_buf.capacity() >= FSST_SAMPLEMAX,
@@ -239,13 +240,13 @@ fn make_sample<'a, 'b: 'a>(sample_buf: &'a mut Vec<u8>, str_in: &Vec<&'b [u8]>)

while sample_buf_offset < sample_lim {
sample_rnd = fsst_hash(sample_rnd);
let mut line_nr = sample_rnd % str_in.len();
let mut line_nr = (sample_rnd as usize) % str_in.len();

// Find the first non-empty chunk starting at line_nr, wrapping around if
// necessary.
//
// TODO: this will loop infinitely if there are no non-empty lines in the sample
while str_in[line_nr].len() == 0 {
while str_in[line_nr].is_empty() {
if line_nr == str_in.len() {
line_nr = 0;
}
@@ -254,10 +255,9 @@ fn make_sample<'a, 'b: 'a>(sample_buf: &'a mut Vec<u8>, str_in: &Vec<&'b [u8]>)
let line = str_in[line_nr];
let chunks = 1 + ((line.len() - 1) / FSST_SAMPLELINE);
sample_rnd = fsst_hash(sample_rnd);
let chunk = FSST_SAMPLELINE * (sample_rnd % chunks);
let chunk = FSST_SAMPLELINE * ((sample_rnd as usize) % chunks);

let len = FSST_SAMPLELINE.min(line.len() - chunk);
// println!("extending sample with chunk str_in[{line_nr}][{chunk}...len={len}]");

sample_buf.extend_from_slice(&str_in[line_nr][chunk..chunk + len]);

@@ -273,7 +273,11 @@ fn make_sample<'a, 'b: 'a>(sample_buf: &'a mut Vec<u8>, str_in: &Vec<&'b [u8]>)
sample
}

fn fsst_hash(value: usize) -> usize {
/// Hash function used in various components of the library.
///
/// This is equivalent to the FSST_HASH macro from the C++ implementation.
#[inline]
pub(crate) fn fsst_hash(value: u64) -> u64 {
(value * 2971215073) ^ (value >> 15)
}

@@ -307,56 +311,24 @@ impl Compressor {
/// code).
///
/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
pub fn train(corpus: impl AsRef<[u8]>) -> Self {
pub fn train(values: &Vec<&[u8]>) -> Self {
let mut counters = Counter::new();
let mut compressor = Compressor::default();
// TODO(aduffy): handle truncating/sampling if corpus > requires sample size.
let sample = corpus.as_ref();
if sample.is_empty() {
return compressor;
}

// Make the sample for each iteration.
//
// The sample is just a vector of slices, so we don't actually have to move anything around.

let mut counter = Counter::new();
for _generation in 0..(MAX_GENERATIONS - 1) {
compressor.compress_count(sample, &mut counter);
compressor.optimize(&counter, 128);
counter.clear();
if values.is_empty() {
return compressor;
}

compressor.compress_count(sample, &mut counter);
compressor.optimize(&counter, 128);

compressor
}

/// Train on a collection of samples.
pub fn train_bulk(values: &Vec<&[u8]>) -> Self {
let mut sample_memory = Vec::with_capacity(FSST_SAMPLEMAX);
let sample = make_sample(&mut sample_memory, values);

let mut counters = Counter::new();
let mut compressor = Compressor::default();

for sample_frac in [8usize, 38, 68, 98, 128] {
// let mut skips = 0;
for i in 0..sample.len() {
if sample_frac < 128 {
if fsst_hash(i) & 127 > sample_frac {
// skips += 1;
continue;
}
for sample_frac in GENERATIONS {
for (i, line) in sample.iter().enumerate() {
if sample_frac < 128 && ((fsst_hash(i as u64) & 127) as usize) > sample_frac {
continue;
}

compressor.compress_count(sample[i], &mut counters);
compressor.compress_count(line, &mut counters);
}
// println!(
// "sampleFrac={sample_frac} -- skipped {} of {}",
// skips,
// sample.len()
// );

compressor.optimize(&counters, sample_frac);
counters.clear();
@@ -403,6 +375,11 @@ impl Compressor {
counter.record_count1(code_u16);
if prev_code != MAX_CODE {
counter.record_count2(prev_code, code_u16);
// Also record the first byte of the next code
let first_byte_code =
self.symbols[code_u16 as usize].first_byte() as u16;
counter.record_count1(first_byte_code);
counter.record_count2(prev_code, first_byte_code);
}
prev_code = code_u16;
}
@@ -494,9 +471,9 @@ impl Compressor {

// From the c++ impl:
// "improves both compression speed (less candidates), but also quality!!"
if count < 5 * sample_frac / 128 {
continue;
}
// if count < (5 * sample_frac / 128) {
// continue;
// }

let mut gain = count * symbol1_len;
// NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols.
@@ -512,7 +489,7 @@ impl Compressor {
});
}

// Skip on last round, or when symbol cannot be extended.
// Skip merges on last round, or when symbol cannot be extended.
if sample_frac >= 128 || symbol1_len == 8 {
continue;
}
@@ -552,19 +529,19 @@ impl Compressor {
//
// Note that because of the lossy hash table, we won't accidentally
// save the same ASCII character twice into the table.
// if include_ascii {
// for character in
// " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ[](){}:?/<>".bytes()
// {
// if n_symbols == 255 {
// break;
// }

// if self.insert(Symbol::from_u8(character)) {
// n_symbols += 1
// }
// }
// }
if sample_frac < 128 {
for character in
" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ[](){}:?/<>".bytes()
{
if n_symbols == 255 {
break;
}

if self.insert(Symbol::from_u8(character)) {
n_symbols += 1
}
}
}
}
}

@@ -613,11 +590,13 @@ mod test {
#[test]
fn test_builder() {
// Train a Compressor on the toy string
let text = "hello world";
let table = Compressor::train(text.as_bytes());
let text = b"hello hello hello hello";

// count of 5 is the cutoff for including a symbol in the table.
let table = Compressor::train(&vec![text, text, text, text, text]);

// Use the table to compress a string, see the values
let compressed = table.compress(text.as_bytes());
let compressed = table.compress(text);

// Ensure that the compressed string has no escape bytes
assert!(compressed.iter().all(|b| *b != ESCAPE_CODE));