From ff10202b647b4e1f5c42c5fe76340f8ea551d82c Mon Sep 17 00:00:00 2001
From: Klaus Post
Date: Thu, 17 Aug 2023 12:09:32 +0200
Subject: [PATCH] WIP: Add dictionary builder

Functional and ok, but has failure modes.
---
 dict/builder.go            | 513 +++++++++++++++++++++++++++++++++++++
 dict/cmd/builddict/main.go |  93 +++++++
 s2/dict.go                 |  22 ++
 zstd/dict.go               | 213 ++++++++++++++++
 4 files changed, 841 insertions(+)
 create mode 100644 dict/builder.go
 create mode 100644 dict/cmd/builddict/main.go

diff --git a/dict/builder.go b/dict/builder.go
new file mode 100644
index 0000000000..d0430ad073
--- /dev/null
+++ b/dict/builder.go
@@ -0,0 +1,513 @@
+// Copyright 2023+ Klaus Post. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dict
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"math/rand"
+	"sort"
+	"time"
+
+	"github.com/klauspost/compress/s2"
+	"github.com/klauspost/compress/zstd"
+)
+
+// match is a hash with an occurrence count and, for indexed hashes, the summed first offsets.
+type match struct {
+	hash   uint32
+	n      uint32
+	offset int64
+}
+
+// matchValue is the content of a hash and counts of the wanted hashes following and preceding it.
+type matchValue struct {
+	value       []byte
+	followBy    map[uint32]uint32
+	preceededBy map[uint32]uint32
+}
+
+// Options are the options for building a dictionary.
+type Options struct {
+	// MaxDictSize is the max size of the backreference dictionary.
+	// Must be > 0.
+	MaxDictSize int
+
+	// HashBytes is the minimum length to index.
+	// Must be >=4 and <=8
+	HashBytes int
+
+	// Debug output
+	Output io.Writer
+
+	// ZstdDictID is the Zstd dictionary ID to use.
+	// Leave at zero to generate a random ID.
+	ZstdDictID uint32
+
+	outFormat int
+}
+
+// Output formats, set by the exported Build functions.
+const (
+	formatRaw = iota
+	formatZstd
+	formatS2
+)
+
+// BuildZstdDict builds a Zstandard dictionary from input.
+// A random ID is generated if o.ZstdDictID is 0.
+func BuildZstdDict(input [][]byte, o Options) ([]byte, error) {
+	o.outFormat = formatZstd
+	if o.ZstdDictID == 0 {
+		rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+		o.ZstdDictID = 32768 + uint32(rng.Int31n((1<<31)-32768))
+	}
+	return buildDict(input, o)
+}
+
+// BuildS2Dict builds an S2 dictionary from input.
+func BuildS2Dict(input [][]byte, o Options) ([]byte, error) {
+	o.outFormat = formatS2
+	return buildDict(input, o)
+}
+
+// BuildRawDict builds a dictionary from input and returns the raw content without any header.
+func BuildRawDict(input [][]byte, o Options) ([]byte, error) {
+	o.outFormat = formatRaw
+	return buildDict(input, o)
+}
+
+// buildDict indexes hashes of all inputs, stitches the most frequent ones together
+// and returns the result in the format selected by o.outFormat.
+// An error is returned if the options are invalid or no repeated content is found.
+func buildDict(input [][]byte, o Options) ([]byte, error) {
+	matches := make(map[uint32]uint32)
+	offsets := make(map[uint32]int64)
+	var total uint64
+
+	wantLen := o.MaxDictSize
+	hashBytes := o.HashBytes
+	if len(input) == 0 {
+		return nil, fmt.Errorf("no input provided")
+	}
+	if hashBytes < 4 || hashBytes > 8 {
+		return nil, fmt.Errorf("HashBytes must be >= 4 and <= 8")
+	}
+	if wantLen <= 0 {
+		return nil, fmt.Errorf("MaxDictSize must be > 0")
+	}
+	println := func(args ...interface{}) {
+		if o.Output != nil {
+			fmt.Fprintln(o.Output, args...)
+		}
+	}
+	printf := func(s string, args ...interface{}) {
+		if o.Output != nil {
+			fmt.Fprintf(o.Output, s, args...)
+		}
+	}
+	found := make(map[uint32]struct{})
+	for i, b := range input {
+		for k := range found {
+			delete(found, k)
+		}
+		for i := range b {
+			rem := b[i:]
+			if len(rem) < 8 {
+				break
+			}
+			h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes))
+			if _, ok := found[h]; ok {
+				// Only count first occurrence
+				continue
+			}
+			matches[h]++
+			offsets[h] += int64(i)
+			total++
+			found[h] = struct{}{}
+		}
+		printf("\r input %d indexed...", i)
+	}
+	if len(matches) == 0 {
+		// Avoid dividing by zero below.
+		return nil, fmt.Errorf("no input of at least 8 bytes provided")
+	}
+	threshold := uint32(total / uint64(len(matches)))
+	println("\nTotal", total, "match", len(matches), "avg", threshold)
+	sorted := make([]match, 0, len(matches)/2)
+	for k, v := range matches {
+		if v <= threshold {
+			continue
+		}
+		sorted = append(sorted, match{hash: k, n: v, offset: offsets[k]})
+	}
+	sort.Slice(sorted, func(i, j int) bool {
+		if sorted[i].n == sorted[j].n {
+			return sorted[i].offset < sorted[j].offset
+		}
+		return sorted[i].n > sorted[j].n
+	})
+	println("Sorted len:", len(sorted))
+	if len(sorted) == 0 {
+		// No hash occurred more often than the average, so there is nothing to index below.
+		return nil, fmt.Errorf("no repeated content found in input")
+	}
+	if len(sorted) > wantLen {
+		sorted = sorted[:wantLen]
+	}
+	lowestOcc := sorted[len(sorted)-1].n
+	println("Cropped len:", len(sorted), "Lowest occurrence:", lowestOcc)
+
+	wantMatches := make(map[uint32]uint32, len(sorted))
+	for _, v := range sorted {
+		wantMatches[v.hash] = v.n
+	}
+
+	output := make(map[uint32]matchValue, len(sorted))
+	var remainCnt [256]int
+	var remainTotal int
+	var firstOffsets []int
+	for i, b := range input {
+		for i := range b {
+			rem := b[i:]
+			if len(rem) < 8 {
+				break
+			}
+			var prev []byte
+			if i > hashBytes {
+				prev = b[i-hashBytes:]
+			}
+
+			h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes))
+			if _, ok := wantMatches[h]; !ok {
+				remainCnt[rem[0]]++
+				remainTotal++
+				continue
+			}
+			mv := output[h]
+			if len(mv.value) == 0 {
+				var tmp = make([]byte, hashBytes)
+				copy(tmp[:], rem)
+				mv.value = tmp[:]
+			}
+			if mv.followBy == nil {
+				mv.followBy = make(map[uint32]uint32, 4)
+				mv.preceededBy = make(map[uint32]uint32, 4)
+			}
+			if len(rem) > hashBytes+8 {
+				// Check if we should add next as well.
+				hNext := hashLen(binary.LittleEndian.Uint64(rem[hashBytes:]), 32, uint8(hashBytes))
+				if _, ok := wantMatches[hNext]; ok {
+					mv.followBy[hNext]++
+				}
+			}
+			if len(prev) >= 8 {
+				// Check if we should add prev as well.
+				hPrev := hashLen(binary.LittleEndian.Uint64(prev), 32, uint8(hashBytes))
+				if _, ok := wantMatches[hPrev]; ok {
+					mv.preceededBy[hPrev]++
+				}
+			}
+			output[h] = mv
+		}
+		printf("\rinput %d re-indexed...", i)
+	}
+	println("")
+	dst := make([][]byte, 0, wantLen/hashBytes)
+	added := 0
+	const printUntil = 500
+	for i, e := range sorted {
+		if added > o.MaxDictSize {
+			break
+		}
+		m, ok := output[e.hash]
+		if !ok {
+			// Already added
+			continue
+		}
+		var tmp = make([]byte, 0, hashBytes*2)
+		{
+			sortedPrev := make([]match, 0, len(m.preceededBy))
+			for k, v := range m.preceededBy {
+				if _, ok := output[k]; !ok {
+					continue
+				}
+				sortedPrev = append(sortedPrev, match{
+					hash: k,
+					n:    v,
+				})
+			}
+			if len(sortedPrev) > 0 {
+				sort.Slice(sortedPrev, func(i, j int) bool {
+					return sortedPrev[i].n > sortedPrev[j].n
+				})
+				bestPrev := output[sortedPrev[0].hash]
+				tmp = append(tmp, bestPrev.value...)
+			}
+		}
+		tmp = append(tmp, m.value...)
+		delete(output, e.hash)
+		wantLen := e.n / uint32(hashBytes) / 4
+		if wantLen <= lowestOcc {
+			wantLen = lowestOcc
+		}
+		for {
+			var nh uint32 // Next hash
+			stopAfter := false
+			if true {
+				sortedFollow := make([]match, 0, len(m.followBy))
+				for k, v := range m.followBy {
+					if _, ok := output[k]; !ok {
+						continue
+					}
+					sortedFollow = append(sortedFollow, match{
+						hash: k,
+						n:    v,
+					})
+				}
+				if len(sortedFollow) == 0 {
+					break
+				}
+				sort.Slice(sortedFollow, func(i, j int) bool {
+					return sortedFollow[i].n > sortedFollow[j].n
+				})
+				nh = sortedFollow[0].hash
+				stopAfter = sortedFollow[0].n < wantLen
+			}
+			m, ok = output[nh]
+			if !ok {
+				break
+			}
+			if len(tmp) > 0 {
+				// Delete all hashes that are in the current string to avoid stuttering.
+				var toDel [16 + 8]byte
+				copy(toDel[:], tmp[len(tmp)-hashBytes:])
+				copy(toDel[hashBytes:], m.value)
+				for i := range toDel[:hashBytes*2] {
+					delete(output, hashLen(binary.LittleEndian.Uint64(toDel[i:]), 32, uint8(hashBytes)))
+				}
+			}
+			tmp = append(tmp, m.value...)
+			//delete(output, nh)
+			if stopAfter {
+				// Last entry was not significant.
+				break
+			}
+		}
+		if i < printUntil {
+			printf("ENTRY %d: %q (%d occurrences, cutoff %d)\n", i, string(tmp), e.n, wantLen)
+		}
+		// Delete substrings already added.
+		if len(tmp) > hashBytes {
+			for j := range tmp[:len(tmp)-hashBytes+1] {
+				var t8 [8]byte
+				copy(t8[:], tmp[j:])
+				if i < 100 {
+					if false {
+						printf("DELETE %q\n", string(t8[:hashBytes]))
+					}
+				}
+				delete(output, hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes)))
+			}
+		}
+		dst = append(dst, tmp)
+		added += len(tmp)
+		// Find offsets
+		// TODO: This can be better if done as a global search.
+		if len(firstOffsets) < 3 {
+			if len(tmp) > 16 {
+				tmp = tmp[:16]
+			}
+			offCnt := make(map[int]int, len(input))
+			// Find first offsets
+			for _, b := range input {
+				off := bytes.Index(b, tmp)
+				if off == -1 {
+					continue
+				}
+				offCnt[off]++
+			}
+			for _, off := range firstOffsets {
+				// Very unlikely, but we delete it just in case
+				delete(offCnt, off-added)
+			}
+			maxCnt := 0
+			maxOffset := 0
+			for k, v := range offCnt {
+				if v == maxCnt && k > maxOffset {
+					// Prefer the longer offset on ties, since it is more expensive to encode
+					maxCnt = v
+					maxOffset = k
+					continue
+				}
+
+				if v > maxCnt {
+					maxCnt = v
+					maxOffset = k
+				}
+			}
+			if maxCnt > 1 {
+				firstOffsets = append(firstOffsets, maxOffset+added)
+				println(" - Offset:", len(firstOffsets), "at", maxOffset+added, "count:", maxCnt, "total added:", added, "src index", maxOffset)
+			}
+		}
+	}
+	out := bytes.NewBuffer(nil)
+	written := 0
+	for i, toWrite := range dst {
+		if len(toWrite)+written > wantLen {
+			toWrite = toWrite[:wantLen-written]
+		}
+		dst[i] = toWrite
+		written += len(toWrite)
+		if written >= wantLen {
+			dst = dst[:i+1]
+			break
+		}
+	}
+	// Write in reverse order.
+	for i := range dst {
+		toWrite := dst[len(dst)-i-1]
+		out.Write(toWrite)
+	}
+	if o.outFormat == formatRaw {
+		return out.Bytes(), nil
+	}
+
+	if o.outFormat == formatS2 {
+		dOff := 0
+		dBytes := out.Bytes()
+		if len(dBytes) > s2.MaxDictSize {
+			dBytes = dBytes[:s2.MaxDictSize]
+		}
+		for _, off := range firstOffsets {
+			myOff := len(dBytes) - off
+			if myOff < 0 || myOff > s2.MaxDictSrcOffset {
+				continue
+			}
+			dOff = myOff
+		}
+
+		dict := s2.MakeDictManual(dBytes, uint16(dOff))
+		if dict == nil {
+			return nil, fmt.Errorf("unable to create s2 dictionary")
+		}
+		return dict.Bytes(), nil
+	}
+	/*
+		avgSize := 256
+		println("\nHuffman: literal total:", remainTotal, "normalized counts on remainder size:", avgSize)
+		huffBuff := make([]byte, 0, avgSize)
+		// Target size
+		div := remainTotal / avgSize
+		if div < 1 {
+			div = 1
+		}
+		for i, n := range remainCnt[:] {
+			if n > 0 {
+				n = n / div
+				if n == 0 {
+					n = 1
+				}
+				huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
+				fmt.Printf("[%d: %d], ", i, n)
+			}
+		}
+		println("")
+		scratch := &huff0.Scratch{}
+		_, _, err := huff0.Compress1X(huffBuff, scratch)
+		if err != nil {
+			// TODO: Handle RLE
+			return nil, err
+		}
+		println("Huffman table:", len(scratch.OutTable), "bytes")
+	*/
+	offsetsZstd := [3]int{1, 4, 8}
+	for i, off := range firstOffsets {
+		if i >= 3 || off == 0 || off >= out.Len() {
+			break
+		}
+		offsetsZstd[i] = off
+	}
+	println("\nCompressing. Offsets:", offsetsZstd)
+	return zstd.BuildDict(zstd.BuildDictOptions{
+		ID:       o.ZstdDictID,
+		Contents: input,
+		History:  out.Bytes(),
+		Offsets:  offsetsZstd,
+	})
+}
+
+const (
+	prime3bytes = 506832829
+	prime4bytes = 2654435761
+	prime5bytes = 889523592379
+	prime6bytes = 227718039650203
+	prime7bytes = 58295818150454627
+	prime8bytes = 0xcf1bbcdcb7a56463
+)
+
+// hashLen returns a hash of the lowest mls bytes of u that fits in hashLog bits.
+// mls should be >=4 and <=8.
+// Lengths 5 to 8 are hashed.
+// Any other length passes the lowest 4 bytes of u straight through.
+// Preferably hashLog and mls should be constants.
+func hashLen(u uint64, hashLog, mls uint8) uint32 {
+	switch mls {
+	case 5:
+		return hash5(u, hashLog)
+	case 6:
+		return hash6(u, hashLog)
+	case 7:
+		return hash7(u, hashLog)
+	case 8:
+		return hash8(u, hashLog)
+	default:
+		return uint32(u)
+	}
+}
+
+// hash3 returns the hash of the lower 3 bytes of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <32.
+func hash3(u uint32, h uint8) uint32 {
+	return ((u << (32 - 24)) * prime3bytes) >> ((32 - h) & 31)
+}
+
+// hash4 returns the hash of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <32.
+func hash4(u uint32, h uint8) uint32 {
+	return (u * prime4bytes) >> ((32 - h) & 31)
+}
+
+// hash4x64 returns the hash of the lowest 4 bytes of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <32.
+func hash4x64(u uint64, h uint8) uint32 {
+	return (uint32(u) * prime4bytes) >> ((32 - h) & 31)
+}
+
+// hash5 returns the hash of the lowest 5 bytes of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <64.
+func hash5(u uint64, h uint8) uint32 {
+	return uint32(((u << (64 - 40)) * prime5bytes) >> ((64 - h) & 63))
+}
+
+// hash6 returns the hash of the lowest 6 bytes of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <64.
+func hash6(u uint64, h uint8) uint32 {
+	return uint32(((u << (64 - 48)) * prime6bytes) >> ((64 - h) & 63))
+}
+
+// hash7 returns the hash of the lowest 7 bytes of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <64.
+func hash7(u uint64, h uint8) uint32 {
+	return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & 63))
+}
+
+// hash8 returns the hash of u to fit in a hash table with h bits.
+// Preferably h should be a constant and should always be <64.
+func hash8(u uint64, h uint8) uint32 {
+	return uint32((u * prime8bytes) >> ((64 - h) & 63))
+}
diff --git a/dict/cmd/builddict/main.go b/dict/cmd/builddict/main.go
new file mode 100644
index 0000000000..f57325d7ac
--- /dev/null
+++ b/dict/cmd/builddict/main.go
@@ -0,0 +1,106 @@
+// Copyright 2023+ Klaus Post. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"os"
+	"path/filepath"
+
+	"github.com/klauspost/compress/dict"
+)
+
+// Command line flags.
+var (
+	wantLenFlag   = flag.Int("len", 112<<10, "Specify custom output size")
+	wantHashBytes = flag.Int("hash", 8, "Hash bytes match length. Minimum match length.")
+	wantMaxBytes  = flag.Int("max", 32<<10, "Max input length to index per input file")
+	wantOutput    = flag.String("o", "dictionary.bin", "Output name")
+	wantFormat    = flag.String("format", "zstd", `Output type. "zstd" "s2" or "raw"`)
+	wantZstdID    = flag.Uint("zstdid", 0, "Zstd dictionary ID. 0 will be random")
+	quiet         = flag.Bool("q", false, "Do not print progress")
+)
+
+// main reads up to -max bytes of every file below the given path and writes a dictionary built from them.
+func main() {
+	flag.Parse()
+	o := dict.Options{
+		MaxDictSize: *wantLenFlag,
+		HashBytes:   *wantHashBytes,
+		Output:      os.Stdout,
+		ZstdDictID:  uint32(*wantZstdID),
+	}
+	if *wantOutput == "" || *quiet {
+		o.Output = nil
+	}
+	var input [][]byte
+	base := flag.Arg(0)
+	if base == "" {
+		log.Fatal("no path with files specified")
+	}
+
+	// Index ALL hashes in all files.
+	err := filepath.Walk(base, func(path string, info os.FileInfo, err error) error {
+		if err != nil {
+			// info may be nil when err is set, so it must be checked first.
+			log.Print(err)
+			return nil
+		}
+		if info.IsDir() {
+			return nil
+		}
+
+		f, err := os.Open(path)
+		if err != nil {
+			log.Print(err)
+			return nil
+		}
+		defer f.Close()
+		b, err := io.ReadAll(io.LimitReader(f, int64(*wantMaxBytes)))
+		if err != nil {
+			log.Print(err)
+			return nil
+		}
+		if len(b) < 8 {
+			return nil
+		}
+		input = append(input, b)
+		if !*quiet {
+			fmt.Print("\r"+info.Name(), " read...")
+		}
+		return nil
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+	var out []byte
+	switch *wantFormat {
+	case "zstd":
+		out, err = dict.BuildZstdDict(input, o)
+	case "s2":
+		out, err = dict.BuildS2Dict(input, o)
+	case "raw":
+		out, err = dict.BuildRawDict(input, o)
+	default:
+		err = fmt.Errorf("unknown format %q", *wantFormat)
+	}
+	if err != nil {
+		log.Fatal(err)
+	}
+	if *wantOutput != "" {
+		err = os.WriteFile(*wantOutput, out, 0666)
+		if err != nil {
+			log.Fatal(err)
+		}
+	} else {
+		_, err = os.Stdout.Write(out)
+		if err != nil {
+			log.Fatal(err)
+		}
+	}
+}
diff --git a/s2/dict.go b/s2/dict.go
index 24f7ce80bc..93e858ba65 100644
--- a/s2/dict.go
+++ b/s2/dict.go
@@ -106,6 +106,26 @@ func MakeDict(data []byte, searchStart []byte) *Dict {
 	return &d
 }
 
+// MakeDictManual will create a dictionary.
+// 'data' must be at least MinDictSize and at most MaxDictSize bytes.
+// A manual first repeat index into data must be provided.
+// It must not be larger than len(data)-8.
+// nil is returned if these requirements are not met.
+func MakeDictManual(data []byte, firstIdx uint16) *Dict {
+	if len(data) < MinDictSize || int(firstIdx) > len(data)-8 || len(data) > MaxDictSize {
+		return nil
+	}
+	var d Dict
+	d.dict = data
+	if cap(d.dict) < len(d.dict)+16 {
+		// Reallocate with 16 bytes of spare capacity.
+		d.dict = append(make([]byte, 0, len(d.dict)+16), d.dict...)
+	}
+
+	d.repeat = int(firstIdx)
+	return &d
+}
+
 // Encode returns the encoded form of src. The returned slice may be a sub-
 // slice of dst if dst was large enough to hold the entire encoded block.
 // Otherwise, a newly allocated slice will be returned.
diff --git a/zstd/dict.go b/zstd/dict.go
index ca0951452e..5acd6457c4 100644
--- a/zstd/dict.go
+++ b/zstd/dict.go
@@ -1,10 +1,12 @@
 package zstd
 
 import (
+	"bytes"
 	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
+	"math"
 
 	"github.com/klauspost/compress/huff0"
 )
@@ -159,3 +161,226 @@ func InspectDictionary(b []byte) (interface {
 	d, err := loadDict(b)
 	return d, err
 }
+
+// BuildDictOptions contains the input used by BuildDict.
+type BuildDictOptions struct {
+	// Dictionary ID.
+	ID uint32
+
+	// Content to use to create dictionary tables.
+	Contents [][]byte
+
+	// History to use for all blocks.
+	History []byte
+
+	// Offsets to use.
+	Offsets [3]int
+}
+
+// BuildDict builds a dictionary from the provided history and offsets.
+// The entropy tables are created from statistics gathered by compressing o.Contents.
+// An error is returned if the input is unusable for creating tables.
+func BuildDict(o BuildDictOptions) ([]byte, error) {
+	initPredefined()
+	hist := o.History
+	contents := o.Contents
+	const debug = false
+	if len(hist) > dictMaxLength {
+		return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), dictMaxLength)
+	}
+	if len(hist) < 8 {
+		return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8)
+	}
+	if len(contents) == 0 {
+		return nil, errors.New("no content provided")
+	}
+	d := dict{
+		id:      o.ID,
+		litEnc:  nil,
+		llDec:   sequenceDec{},
+		ofDec:   sequenceDec{},
+		mlDec:   sequenceDec{},
+		offsets: o.Offsets,
+		content: hist,
+	}
+	block := blockEnc{lowMem: false}
+	block.init()
+	enc := &bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}}
+	var (
+		remain [256]int
+		ll     [256]int
+		ml     [256]int
+		of     [256]int
+	)
+	addValues := func(dst *[256]int, src []byte) {
+		for _, v := range src {
+			dst[v]++
+		}
+	}
+	addHist := func(dst *[256]int, src *[256]uint32) {
+		for i, v := range src {
+			dst[i] += int(v)
+		}
+	}
+	seqs := 0
+	nUsed := 0
+	litTotal := 0
+	for _, b := range contents {
+		block.reset(nil)
+		if len(b) < 8 {
+			continue
+		}
+		nUsed++
+		enc.Reset(&d, true)
+		enc.Encode(&block, b)
+		addValues(&remain, block.literals)
+		litTotal += len(block.literals)
+		seqs += len(block.sequences)
+		block.genCodes()
+		addHist(&ll, block.coders.llEnc.Histogram())
+		addHist(&ml, block.coders.mlEnc.Histogram())
+		addHist(&of, block.coders.ofEnc.Histogram())
+	}
+	if nUsed == 0 || seqs == 0 {
+		return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs)
+	}
+	if debug {
+		fmt.Println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal)
+	}
+	if seqs/nUsed < 512 {
+		// Use 512 as minimum.
+		nUsed = seqs / 512
+		if nUsed == 0 {
+			// Fewer than 512 sequences in total. Avoid dividing by zero in copyHist.
+			nUsed = 1
+		}
+	}
+	copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) {
+		hist := dst.Histogram()
+		var maxSym uint8
+		var maxCount int
+		var fakeLength int
+		for i, v := range src {
+			if v > 0 {
+				v = v / nUsed
+				if v == 0 {
+					v = 1
+				}
+			}
+			if v > maxCount {
+				maxCount = v
+			}
+			if v != 0 {
+				maxSym = uint8(i)
+			}
+			fakeLength += v
+			hist[i] = uint32(v)
+		}
+		dst.HistogramFinished(maxSym, maxCount)
+		dst.reUsed = false
+		dst.useRLE = false
+		err := dst.normalizeCount(fakeLength)
+		if err != nil {
+			return nil, err
+		}
+		if debug {
+			fmt.Println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength)
+		}
+		return dst.writeCount(nil)
+	}
+	if debug {
+		fmt.Print("Literal lengths: ")
+	}
+	llTable, err := copyHist(block.coders.llEnc, &ll)
+	if err != nil {
+		return nil, err
+	}
+	if debug {
+		fmt.Print("Match lengths: ")
+	}
+	mlTable, err := copyHist(block.coders.mlEnc, &ml)
+	if err != nil {
+		return nil, err
+	}
+	if debug {
+		fmt.Print("Offsets: ")
+	}
+	ofTable, err := copyHist(block.coders.ofEnc, &of)
+	if err != nil {
+		return nil, err
+	}
+
+	// Literal table
+	if litTotal == 0 {
+		// Without literals the division below would panic.
+		return nil, errors.New("no literals found in content")
+	}
+	avgSize := litTotal
+	if avgSize > huff0.BlockSizeMax/2 {
+		avgSize = huff0.BlockSizeMax / 2
+	}
+	huffBuff := make([]byte, 0, avgSize)
+	// Target size
+	div := litTotal / avgSize
+	if div < 1 {
+		div = 1
+	}
+	if debug {
+		fmt.Println("Huffman weights:")
+	}
+	for i, n := range remain[:] {
+		if n > 0 {
+			n = n / div
+			// Allow all entries to be represented.
+			if n == 0 {
+				n = 1
+			}
+			huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
+			if debug {
+				fmt.Printf("[%d: %d], ", i, n)
+			}
+		}
+	}
+	if remain[255]/div == 0 {
+		huffBuff = append(huffBuff, 255)
+	}
+	scratch := &huff0.Scratch{TableLog: 11}
+	_, _, err = huff0.Compress1X(huffBuff, scratch)
+	if err != nil {
+		// TODO: Handle RLE
+		return nil, err
+	}
+
+	var out bytes.Buffer
+	out.Write([]byte(dictMagic))
+	out.Write(binary.LittleEndian.AppendUint32(nil, o.ID))
+	out.Write(scratch.OutTable)
+	if debug {
+		fmt.Println("huff table:", len(scratch.OutTable), "bytes")
+		fmt.Println("of table:", len(ofTable), "bytes")
+		fmt.Println("ml table:", len(mlTable), "bytes")
+		fmt.Println("ll table:", len(llTable), "bytes")
+	}
+	out.Write(ofTable)
+	out.Write(mlTable)
+	out.Write(llTable)
+	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0])))
+	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1])))
+	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2])))
+	out.Write(hist)
+	if debug {
+		_, err := loadDict(out.Bytes())
+		if err != nil {
+			panic(err)
+		}
+		i, err := InspectDictionary(out.Bytes())
+		if err != nil {
+			panic(err)
+		}
+		fmt.Println("ID:", i.ID())
+		fmt.Println("Content size:", i.ContentSize())
+		fmt.Println("Encoder:", i.LitEncoder() != nil)
+		fmt.Println("Offsets:", i.Offsets())
+	}
+	return out.Bytes(), nil
+}