Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get compress performance to match paper algorithm 4 #3

Merged
merged 20 commits into from
Aug 15, 2024
Merged
7 changes: 0 additions & 7 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,2 @@
/target
.idea/


# Added by cargo
#
# already existing elements were commented out

#/target
13 changes: 11 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 14 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
[package]
name = "fsst-rs"
version = "0.0.1"
description = "Pure-Rust implementation of Fast Static Symbol Tables algorithm for string compression"
authors = ["SpiralDB Developers <hello@spiraldb.com>"]
license = "Apache-2.0"
repository = "https://github.com/spiraldb/fsst"
edition = "2021"

[lints.rust]
Expand All @@ -22,7 +26,16 @@ use_debug = { level = "deny" }
criterion = "0.5"
lz4 = "1"

[[example]]
name = "round_trip"
bench = false
test = false

[[bench]]
name = "compress"
harness = false
bench = true

[[test]]
name = "correctness"
test = true
bench = false
36 changes: 7 additions & 29 deletions benches/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
//!
//! Also contains LZ4 baseline.
#![allow(missing_docs)]
use core::str;
use std::io::{Cursor, Read, Write};

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use lz4::liblz4::BlockChecksum;
use lz4::{BlockSize, ContentChecksum};

use fsst_rs::{train, Code};
use fsst_rs::{train, ESCAPE_CODE};

const CORPUS: &str = include_str!("dracula.txt");
const TEST: &str = "I found my smattering of German very useful here";
Expand All @@ -26,17 +27,17 @@ fn bench_fsst(c: &mut Criterion) {
let plaintext = TEST.as_bytes();

let compressed = table.compress(plaintext);
let escape_count = compressed
.iter()
.filter(|b| **b == Code::ESCAPE_CODE)
.count();
let escape_count = compressed.iter().filter(|b| **b == ESCAPE_CODE).count();
let ratio = (plaintext.len() as f64) / (compressed.len() as f64);
println!(
"Escapes = {escape_count}/{}, compression_ratio = {ratio}",
compressed.len()
);

assert_eq!(table.decompress(&compressed), TEST.as_bytes());
let decompressed = table.decompress(&compressed);
let decompressed = str::from_utf8(&decompressed).unwrap();
println!("DECODED: {}", decompressed);
assert_eq!(decompressed, TEST);

group.bench_function("compress-single", |b| {
b.iter(|| black_box(table.compress(black_box(plaintext))));
Expand All @@ -50,29 +51,6 @@ fn bench_fsst(c: &mut Criterion) {
fn bench_lz4(c: &mut Criterion) {
let mut group = c.benchmark_group("lz4");

// {
// let compressed = Vec::with_capacity(10_000);
// let mut encoder = lz4::EncoderBuilder::new()
// .block_size(BlockSize::Max64KB)
// .build(compressed)
// .unwrap();
//
// encoder.write_all(TEST.as_bytes()).unwrap();
// let (compressed, result) = encoder.finish();
// result.unwrap();
//
// let ratio = (TEST.as_bytes().len() as f64) / (compressed.len() as f64);
// println!("LZ4 compress_ratio = {ratio}");
//
// // ensure decodes cleanly
// let cursor = Cursor::new(compressed);
// let mut decoder = lz4::Decoder::new(cursor).unwrap();
// let mut output = String::new();
//
// decoder.read_to_string(&mut output).unwrap();
// assert_eq!(output.as_str(), TEST);
// }

group.bench_function("compress-single", |b| {
let mut compressed = Vec::with_capacity(100_000_000);
let mut encoder = lz4::EncoderBuilder::new()
Expand Down
70 changes: 70 additions & 0 deletions examples/file_compressor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#![allow(missing_docs, clippy::use_debug)]

//! This is a command line program that expects two input files as arguments.
//!
//! The first is the file to train a symbol table on.
//!
//! The second is the file to compress. The compressor will run and compress
//! in chunks of 16MB, logging the compression ratio for each chunk.
//!
//! Example:
//!
//! ```
//! cargo run --release --example file_compressor -- file1.csv file2.csv
//! ```
use std::{
fs::File,
io::Read,
os::unix::fs::{FileExt, MetadataExt},
path::Path,
};

fn main() {
let args: Vec<_> = std::env::args().skip(1).collect();
assert!(args.len() >= 2, "args TRAINING and FILE must be provided");

let train_path = Path::new(&args[0]);
let input_path = Path::new(&args[1]);

let mut train_bytes = Vec::new();
{
let mut f = File::open(train_path).unwrap();
f.read_to_end(&mut train_bytes).unwrap();
}

println!("building the compressor from {train_path:?}...");
let compressor = fsst_rs::train(&train_bytes);

println!("compressing blocks of {input_path:?} with compressor...");

let f = File::open(input_path).unwrap();
let size_bytes = f.metadata().unwrap().size() as usize;

const CHUNK_SIZE: usize = 16 * 1024 * 1024;

let mut chunk_idx = 1;
let mut pos = 0;
let mut chunk = vec![0u8; CHUNK_SIZE];
while pos + CHUNK_SIZE < size_bytes {
f.read_exact_at(&mut chunk, pos as u64).unwrap();
// Compress the chunk, don't write it anywhere.
let compact = compressor.compress(&chunk);
let compression_ratio = (CHUNK_SIZE as f64) / (compact.len() as f64);
println!("compressed chunk {chunk_idx} with ratio {compression_ratio}");

pos += CHUNK_SIZE;
chunk_idx += 1;
}

// Read last chunk with a new custom-sized buffer.
if pos < size_bytes {
let amount = size_bytes - pos;
chunk = vec![0u8; size_bytes - pos];
f.read_exact_at(&mut chunk, pos as u64).unwrap();
// Compress the chunk, don't write it anywhere.
let compact = compressor.compress(&chunk[0..amount]);
let compression_ratio = (amount as f64) / (compact.len() as f64);
println!("compressed chunk {chunk_idx} with ratio {compression_ratio}");
}
println!("done");
}
19 changes: 19 additions & 0 deletions examples/round_trip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//! Simple example where we show round-tripping a string through the static symbol table.

use core::str;

fn main() {
// Train on a sample.
let sample = "the quick brown fox jumped over the lazy dog";
let trained = fsst_rs::train(sample.as_bytes());
let compressed = trained.compress(sample.as_bytes());
println!("compressed: {} => {}", sample.len(), compressed.len());
// decompress now
let decode = trained.decompress(&compressed);
let output = str::from_utf8(&decode).unwrap();
println!(
"decoded to the original: len={} text='{}'",
decode.len(),
output
);
}
3 changes: 1 addition & 2 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[toolchain]
channel = "nightly-2024-06-19"
channel = "nightly-2024-08-14"
components = ["rust-src", "rustfmt", "clippy"]
profile = "minimal"

67 changes: 36 additions & 31 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;

use crate::{Code, Symbol, SymbolTable};
use crate::find_longest::FindLongestSymbol;
use crate::{Symbol, SymbolTable, MAX_CODE};

#[derive(Debug, Clone)]
struct Counter {
Expand All @@ -21,29 +22,29 @@ struct Counter {
impl Counter {
fn new() -> Self {
Self {
counts1: vec![0; Code::CODE_MAX as usize],
counts2: vec![vec![0; Code::CODE_MAX as usize]; Code::CODE_MAX as usize],
counts1: vec![0; MAX_CODE as usize],
counts2: vec![vec![0; MAX_CODE as usize]; MAX_CODE as usize],
}
}

#[inline]
fn record_count1(&mut self, code1: Code) {
self.counts1[code1.0 as usize] += 1;
fn record_count1(&mut self, code1: u16) {
self.counts1[code1 as usize] += 1;
}

#[inline]
fn record_count2(&mut self, code1: Code, code2: Code) {
self.counts2[code1.0 as usize][code2.0 as usize] += 1;
fn record_count2(&mut self, code1: u16, code2: u16) {
self.counts2[code1 as usize][code2 as usize] += 1;
}

#[inline]
fn count1(&self, code: Code) -> usize {
self.counts1[code.0 as usize]
fn count1(&self, code: u16) -> usize {
self.counts1[code as usize]
}

#[inline]
fn count2(&self, code1: Code, code2: Code) -> usize {
self.counts2[code1.0 as usize][code2.0 as usize]
fn count2(&self, code1: u16, code2: u16) -> usize {
self.counts2[code1 as usize][code2 as usize]
}
}

Expand All @@ -65,6 +66,9 @@ pub fn train(corpus: impl AsRef<[u8]>) -> SymbolTable {
let mut table = SymbolTable::default();
// TODO(aduffy): handle truncating/sampling if corpus > requires sample size.
let sample = corpus.as_ref();
if sample.is_empty() {
return table;
}
for _generation in 0..MAX_GENERATIONS {
let counter = table.compress_count(sample);
table = table.optimize(counter);
Expand All @@ -81,13 +85,13 @@ impl SymbolTable {
let len = sample.len();
let mut prev_code = self.find_longest_symbol(sample);
counter.record_count1(prev_code);
let mut pos = self.symbols[prev_code.0 as usize].len();
let mut pos = self.symbols[prev_code as usize].len();

while pos < len {
let code = self.find_longest_symbol(&sample[pos..len]);
counter.record_count1(code);
counter.record_count2(prev_code, code);
pos += self.symbols[code.0 as usize].len();
pos += self.symbols[code as usize].len();
prev_code = code;
}

Expand All @@ -100,17 +104,15 @@ impl SymbolTable {
let mut res = SymbolTable::default();
let mut pqueue = BinaryHeap::new();
for code1 in 0..511 {
let code1 = Code::from_u16(code1);
let symbol1 = self.symbols[code1.0 as usize];
let symbol1 = self.symbols[code1 as usize];
let gain = counters.count1(code1) * symbol1.len();
pqueue.push(Candidate {
symbol: symbol1,
gain,
});

for code2 in 0..511 {
let code2 = Code::from_u16(code2);
let symbol2 = &self.symbols[code2.0 as usize];
let symbol2 = &self.symbols[code2 as usize];
// If either symbol is zero-length, or if merging would yield a symbol of
// length greater than 8, skip.
if symbol1.len() + symbol2.len() >= 8 || symbol1.is_empty() || symbol2.is_empty() {
Expand All @@ -133,10 +135,13 @@ impl SymbolTable {
}

// Pop the 255 best symbols.
pqueue
.iter()
.take(255)
.for_each(|candidate| res.insert(candidate.symbol));
let mut n_symbols = 0;
while !pqueue.is_empty() && n_symbols < 255 {
let candidate = pqueue.pop().unwrap();
if res.insert(candidate.symbol) {
n_symbols += 1;
}
}

res
}
Expand Down Expand Up @@ -181,7 +186,7 @@ impl Ord for Candidate {

#[cfg(test)]
mod test {
use crate::{train, Code};
use crate::{train, ESCAPE_CODE};

#[test]
fn test_builder() {
Expand All @@ -193,26 +198,26 @@ mod test {
let compressed = table.compress(text.as_bytes());

// Ensure that the compressed string has no escape bytes
assert!(compressed.iter().all(|b| *b != Code::ESCAPE_CODE));
assert!(compressed.iter().all(|b| *b != ESCAPE_CODE));

// Ensure that we can compress a string with no values seen at training time.
// Ensure that we can compress a string with no values seen at training time, with escape bytes
let compressed = table.compress("xyz123".as_bytes());
assert_eq!(
compressed,
vec![
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'x',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'y',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'z',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'1',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'2',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'3',
]
)
);
}
}
5 changes: 5 additions & 0 deletions src/find_longest/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod naive;

pub trait FindLongestSymbol {
fn find_longest_symbol(&self, text: &[u8]) -> u16;
}
Loading
Loading