Improve compression speed on small blocks #4165

Merged: 10 commits, Oct 11, 2024
19 changes: 19 additions & 0 deletions lib/compress/zstd_compress_internal.h
@@ -557,6 +557,25 @@ MEM_STATIC int ZSTD_cParam_withinBounds(ZSTD_cParameter cParam, int value)
return 1;
}

/* ZSTD_selectAddr:
* @return index >= lowLimit ? candidate : backup,
* tries to force branchless codegen. */
MEM_STATIC const BYTE*
ZSTD_selectAddr(U32 index, U32 lowLimit, const BYTE* candidate, const BYTE* backup)
{
#if defined(__GNUC__) && defined(__x86_64__)
__asm__ (
"cmp %1, %2\n"
"cmova %3, %0\n"
: "+r"(candidate)
: "r"(index), "r"(lowLimit), "r"(backup)
);
return candidate;
#else
return index >= lowLimit ? candidate : backup;
#endif
}
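/* Illustrative usage sketch (not part of this diff; `table`, `h`, `lowLimit`,
 * `base` and `fallback` are hypothetical names):
 *
 *     U32 const matchIndex = table[h];
 *     const BYTE* const matchPtr =
 *         ZSTD_selectAddr(matchIndex, lowLimit, base + matchIndex, fallback);
 *
 * matchPtr equals base + matchIndex when matchIndex >= lowLimit, and fallback
 * otherwise; with GCC on x86-64 the selection compiles to cmp + cmova instead
 * of a conditional branch. */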

/* ZSTD_noCompressBlock() :
* Writes uncompressed block to dst buffer from given src.
* Returns the size of the block */
39 changes: 25 additions & 14 deletions lib/compress/zstd_double_fast.c
@@ -140,11 +140,17 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic(
U32 idxl1; /* the long match index for ip1 */

const BYTE* matchl0; /* the long match for ip */
const BYTE* matchl0_safe; /* matchl0 or safe address */
const BYTE* matchs0; /* the short match for ip */
const BYTE* matchl1; /* the long match for ip1 */
const BYTE* matchs0_safe; /* matchs0 or safe address */

const BYTE* ip = istart; /* the current position */
const BYTE* ip1; /* the next position */
/* Array of ~random data; it should have a low probability of matching real data.
* We load from here instead of from the hash tables when matchl0/matchl1 are
* invalid indices. Used to avoid unpredictable branches. */
const BYTE dummy[] = {0x12,0x34,0x56,0x78,0x9a,0xbc,0xde,0xf0,0xe2,0xb4};

DEBUGLOG(5, "ZSTD_compressBlock_doubleFast_noDict_generic");

@@ -191,24 +197,29 @@ size_t ZSTD_compressBlock_doubleFast_noDict_generic(

hl1 = ZSTD_hashPtr(ip1, hBitsL, 8);

if (idxl0 > prefixLowestIndex) {
/* check prefix long match */
if (MEM_read64(matchl0) == MEM_read64(ip)) {
mLength = ZSTD_count(ip+8, matchl0+8, iend) + 8;
offset = (U32)(ip-matchl0);
while (((ip>anchor) & (matchl0>prefixLowest)) && (ip[-1] == matchl0[-1])) { ip--; matchl0--; mLength++; } /* catch up */
goto _match_found;
}
/* idxl0 > prefixLowestIndex is a (somewhat) unpredictable branch.
* However, the expression below compiles into a conditional move. Since
* a match is unlikely, and we only *branch* on idxl0 > prefixLowestIndex
* when there is a match, all branches become predictable. */
matchl0_safe = ZSTD_selectAddr(idxl0, prefixLowestIndex, matchl0, &dummy[0]);

/* check prefix long match */
if (MEM_read64(matchl0_safe) == MEM_read64(ip) && matchl0_safe == matchl0) {
mLength = ZSTD_count(ip+8, matchl0+8, iend) + 8;
offset = (U32)(ip-matchl0);
while (((ip>anchor) & (matchl0>prefixLowest)) && (ip[-1] == matchl0[-1])) { ip--; matchl0--; mLength++; } /* catch up */
goto _match_found;
}

idxl1 = hashLong[hl1];
matchl1 = base + idxl1;

if (idxs0 > prefixLowestIndex) {
/* check prefix short match */
if (MEM_read32(matchs0) == MEM_read32(ip)) {
goto _search_next_long;
}
/* Same optimization as matchl0 above */
matchs0_safe = ZSTD_selectAddr(idxs0, prefixLowestIndex, matchs0, &dummy[0]);

/* check prefix short match */
if (MEM_read32(matchs0_safe) == MEM_read32(ip) && matchs0_safe == matchs0) {
goto _search_next_long;
}

if (ip1 >= nextStep) {
@@ -651,7 +662,7 @@ size_t ZSTD_compressBlock_doubleFast_extDict_generic(
size_t mLength;
hashSmall[hSmall] = hashLong[hLong] = curr; /* update hash table */

if (((ZSTD_index_overlap_check(prefixStartIndex, repIndex))
& (offset_1 <= curr+1 - dictStartIndex)) /* note: we are searching at curr+1 */
&& (MEM_read32(repMatch) == MEM_read32(ip+1)) ) {
const BYTE* repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend;
134 changes: 76 additions & 58 deletions lib/compress/zstd_fast.c
@@ -45,7 +45,7 @@ void ZSTD_fillHashTableForCDict(ZSTD_matchState_t* ms,
size_t const hashAndTag = ZSTD_hashPtr(ip + p, hBits, mls);
if (hashTable[hashAndTag >> ZSTD_SHORT_CACHE_TAG_BITS] == 0) { /* not yet filled */
ZSTD_writeTaggedIndex(hashTable, hashAndTag, curr + p);
} } } }
}

static
@@ -97,6 +97,50 @@ void ZSTD_fillHashTable(ZSTD_matchState_t* ms,
}


typedef int (*ZSTD_match4Found) (const BYTE* currentPtr, const BYTE* matchAddress, U32 matchIdx, U32 idxLowLimit);

static int
ZSTD_match4Found_cmov(const BYTE* currentPtr, const BYTE* matchAddress, U32 matchIdx, U32 idxLowLimit)
{
/* Array of ~random data; it should have a low probability of matching real data.
* Load from here if the index is invalid.
* Used to avoid unpredictable branches. */
static const BYTE dummy[] = {0x12,0x34,0x56,0x78};
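/* Note: only MEM_read32() reads from dummy here, so 4 bytes suffice;
 * the double_fast variant above reads 8 bytes via MEM_read64() and so
 * needs a larger array. */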

/* matchIdx >= idxLowLimit is a (somewhat) unpredictable branch.
* However, the expression below compiles into a conditional move.
*/
const BYTE* mvalAddr = ZSTD_selectAddr(matchIdx, idxLowLimit, matchAddress, dummy);
/* Note: this used to be written as: return test1 && test2;
* Unfortunately, once inlined, these tests become branches,
* and it then becomes critical that they are executed in the right order (test1 then test2).
* So the tests are written in a specific manner to enforce their ordering.
*/
if (MEM_read32(currentPtr) != MEM_read32(mvalAddr)) return 0;
/* force ordering of these tests, which matters once the function is inlined, as they become branches */
#if defined(__GNUC__)
__asm__("");
#endif
return matchIdx >= idxLowLimit;
}

static int
ZSTD_match4Found_branch(const BYTE* currentPtr, const BYTE* matchAddress, U32 matchIdx, U32 idxLowLimit)
{
/* using a branch instead of a cmov,
* because it's faster in scenarios where matchIdx >= idxLowLimit is generally true,
* aka almost all candidates are within range */
U32 mval;
if (matchIdx >= idxLowLimit) {
mval = MEM_read32(matchAddress);
} else {
mval = MEM_read32(currentPtr) ^ 1; /* guaranteed to not match. */
}

return (MEM_read32(currentPtr) == mval);
}
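/* Note: the choice between the two variants is made in ZSTD_compressBlock_fast()
 * below, via useCmov = (windowLog < 19): small windows make the "candidate in
 * range" test hard to predict, so they use the cmov variant, while larger
 * windows, where nearly all candidates are in range, use the branch variant. */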


/**
* If you squint hard enough (and ignore repcodes), the search operation at any
* given position is broken into 4 stages:
@@ -148,13 +192,12 @@ ZSTD_ALLOW_POINTER_OVERFLOW_ATTR
size_t ZSTD_compressBlock_fast_noDict_generic(
ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM],
void const* src, size_t srcSize,
U32 const mls, U32 const hasStep)
U32 const mls, int useCmov)
{
const ZSTD_compressionParameters* const cParams = &ms->cParams;
U32* const hashTable = ms->hashTable;
U32 const hlog = cParams->hashLog;
/* support stepSize of 0 */
size_t const stepSize = hasStep ? (cParams->targetLength + !(cParams->targetLength) + 1) : 2;
size_t const stepSize = cParams->targetLength + !(cParams->targetLength) + 1; /* min 2 */
const BYTE* const base = ms->window.base;
const BYTE* const istart = (const BYTE*)src;
const U32 endIndex = (U32)((size_t)(istart - base) + srcSize);
@@ -176,8 +219,7 @@ size_t ZSTD_compressBlock_fast_noDict_generic(

size_t hash0; /* hash for ip0 */
size_t hash1; /* hash for ip1 */
U32 idx; /* match idx for ip0 */
U32 mval; /* src value at match idx */
U32 matchIdx; /* match idx for ip0 */

U32 offcode;
const BYTE* match0;
@@ -190,6 +232,7 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
size_t step;
const BYTE* nextStep;
const size_t kStepIncr = (1 << (kSearchStrength - 1));
const ZSTD_match4Found matchFound = useCmov ? ZSTD_match4Found_cmov : ZSTD_match4Found_branch;

DEBUGLOG(5, "ZSTD_compressBlock_fast_generic");
ip0 += (ip0 == prefixStart);
@@ -218,7 +261,7 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
hash0 = ZSTD_hashPtr(ip0, hlog, mls);
hash1 = ZSTD_hashPtr(ip1, hlog, mls);

idx = hashTable[hash0];
matchIdx = hashTable[hash0];

do {
/* load repcode match for ip[2]*/
@@ -238,35 +281,25 @@
offcode = REPCODE1_TO_OFFBASE;
mLength += 4;

/* First write next hash table entry; we've already calculated it.
* This write is known to be safe because the ip1 is before the
/* Write next hash table entry: it's already calculated.
* This write is known to be safe because ip1 is before the
* repcode (ip2). */
hashTable[hash1] = (U32)(ip1 - base);

goto _match;
}

/* load match for ip[0] */
if (idx >= prefixStartIndex) {
mval = MEM_read32(base + idx);
} else {
mval = MEM_read32(ip0) ^ 1; /* guaranteed to not match. */
}

/* check match at ip[0] */
if (MEM_read32(ip0) == mval) {
/* found a match! */

/* First write next hash table entry; we've already calculated it.
* This write is known to be safe because the ip1 == ip0 + 1, so
* we know we will resume searching after ip1 */
if (matchFound(ip0, base + matchIdx, matchIdx, prefixStartIndex)) {
/* Write next hash table entry (it's already calculated).
* This write is known to be safe because ip1 == ip0 + 1,
* so searching will resume after ip1 */
hashTable[hash1] = (U32)(ip1 - base);

goto _offset;
}

/* lookup ip[1] */
idx = hashTable[hash1];
matchIdx = hashTable[hash1];

/* hash ip[2] */
hash0 = hash1;
@@ -281,36 +314,19 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
current0 = (U32)(ip0 - base);
hashTable[hash0] = current0;

/* load match for ip[0] */
if (idx >= prefixStartIndex) {
mval = MEM_read32(base + idx);
} else {
mval = MEM_read32(ip0) ^ 1; /* guaranteed to not match. */
}

/* check match at ip[0] */
if (MEM_read32(ip0) == mval) {
/* found a match! */

/* first write next hash table entry; we've already calculated it */
if (matchFound(ip0, base + matchIdx, matchIdx, prefixStartIndex)) {
/* Write next hash table entry, since it's already calculated */
if (step <= 4) {
/* We need to avoid writing an index into the hash table >= the
* position at which we will pick up our searching after we've
* taken this match.
*
* The minimum possible match has length 4, so the earliest ip0
* can be after we take this match will be the current ip0 + 4.
* ip1 is ip0 + step - 1. If ip1 is >= ip0 + 4, we can't safely
* write this position.
*/
/* Avoid writing an index if it's >= position where search will resume.
* The minimum possible match has length 4, so search can resume at ip0 + 4.
*/
hashTable[hash1] = (U32)(ip1 - base);
}

goto _offset;
}

/* lookup ip[1] */
idx = hashTable[hash1];
matchIdx = hashTable[hash1];

/* hash ip[2] */
hash0 = hash1;
@@ -332,7 +348,7 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
} while (ip3 < ilimit);

_cleanup:
/* Note that there are probably still a couple positions we could search.
/* Note that there are probably still a couple positions one could search.
* However, it seems to be a meaningful performance hit to try to search
* them. So let's not. */

@@ -361,7 +377,7 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
_offset: /* Requires: ip0, idx */

/* Compute the offset code. */
match0 = base + idx;
match0 = base + matchIdx;
rep_offset2 = rep_offset1;
rep_offset1 = (U32)(ip0-match0);
offcode = OFFSET_TO_OFFBASE(rep_offset1);
@@ -406,12 +422,12 @@ size_t ZSTD_compressBlock_fast_noDict_generic(
goto _start;
}

#define ZSTD_GEN_FAST_FN(dictMode, mls, step) \
static size_t ZSTD_compressBlock_fast_##dictMode##_##mls##_##step( \
#define ZSTD_GEN_FAST_FN(dictMode, mml, cmov) \
static size_t ZSTD_compressBlock_fast_##dictMode##_##mml##_##cmov( \
ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM], \
void const* src, size_t srcSize) \
{ \
return ZSTD_compressBlock_fast_##dictMode##_generic(ms, seqStore, rep, src, srcSize, mls, step); \
return ZSTD_compressBlock_fast_##dictMode##_generic(ms, seqStore, rep, src, srcSize, mml, cmov); \
}

ZSTD_GEN_FAST_FN(noDict, 4, 1)
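/* For illustration (derived from the macro above, not part of the diff):
 * ZSTD_GEN_FAST_FN(noDict, 4, 1) expands to roughly
 *
 *     static size_t ZSTD_compressBlock_fast_noDict_4_1(
 *             ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM],
 *             void const* src, size_t srcSize)
 *     {
 *         return ZSTD_compressBlock_fast_noDict_generic(ms, seqStore, rep, src, srcSize, 4, 1);
 *     }
 *
 * i.e. one specialized function per (minMatch, useCmov) pair, dispatched from
 * ZSTD_compressBlock_fast() below. */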
@@ -428,10 +444,12 @@ size_t ZSTD_compressBlock_fast(
ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM],
void const* src, size_t srcSize)
{
U32 const mls = ms->cParams.minMatch;
U32 const mml = ms->cParams.minMatch;
/* use cmov when "candidate in range" branch is likely unpredictable */
int const useCmov = ms->cParams.windowLog < 19;
assert(ms->dictMatchState == NULL);
if (ms->cParams.targetLength > 1) {
switch(mls)
if (useCmov) {
switch(mml)
{
default: /* includes case 3 */
case 4 :
@@ -444,7 +462,8 @@
return ZSTD_compressBlock_fast_noDict_7_1(ms, seqStore, rep, src, srcSize);
}
} else {
switch(mls)
/* use a branch instead */
switch(mml)
{
default: /* includes case 3 */
case 4 :
@@ -456,7 +475,6 @@
case 7 :
return ZSTD_compressBlock_fast_noDict_7_0(ms, seqStore, rep, src, srcSize);
}

}
}

@@ -546,7 +564,7 @@ size_t ZSTD_compressBlock_fast_dictMatchState_generic(
size_t const dictHashAndTag1 = ZSTD_hashPtr(ip1, dictHBits, mls);
hashTable[hash0] = curr; /* update hash table */

if ((ZSTD_index_overlap_check(prefixStartIndex, repIndex))
&& (MEM_read32(repMatch) == MEM_read32(ip0 + 1))) {
const BYTE* const repMatchEnd = repIndex < prefixStartIndex ? dictEnd : iend;
mLength = ZSTD_count_2segments(ip0 + 1 + 4, repMatch + 4, iend, repMatchEnd, prefixStart) + 4;