Commit 3b0ecfd: Cleanup code
tjake committed Jul 21, 2024 (1 parent: dde37fc)
Showing 23 changed files with 578 additions and 348 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -21,7 +21,7 @@ Implements:
* Flash Attention
* Mixture of Experts
* Huggingface [SafeTensors](https://github.com/huggingface/safetensors) model and tokenizer format
-* Support for F32, F16, BF16 models
+* Support for F32, F16, BF16 types
* Support for Q8, Q4 model quantization
* Fast GEMM operations
* Distributed Inference!
@@ -33,7 +33,7 @@ public static float bFloat16ToFloat32(short raw) {
}

public static short float32ToBFloat16(float n) {
-//if (true)
+// if (true)
// return (short) ((Float.floatToRawIntBits(n) >> 16) & 0xffff);
int nbits = Float.floatToRawIntBits(n);
// 32 bits has 1 sign bit, 8 exponent bits, 23 mantissa bits
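For reference, bfloat16 keeps only the top 16 bits of a float32 (1 sign, 8 exponent, 7 stored mantissa bits), so the commented-out line above is plain truncation. A round-to-nearest-even variant, sketched here for illustration rather than as a claim about jlama's exact code, biases the low half before shifting (NaN handling omitted):

static short float32ToBFloat16Rne(float n) {
    int bits = Float.floatToRawIntBits(n);
    // add 0x7FFF plus the LSB of the upper half, then keep the top 16 bits
    int rounded = bits + 0x7FFF + ((bits >>> 16) & 1);
    return (short) (rounded >>> 16);
}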
@@ -19,10 +19,9 @@
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.BiIntConsumer;
import com.github.tjake.jlama.util.PhysicalCoreExecutor;
+import com.google.common.base.Preconditions;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;
-
-import com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -33,9 +32,7 @@ public class VectorMath {
public static void pfor(int start, int end, IntConsumer action) {
PhysicalCoreExecutor.instance
.get()
-.execute(() -> IntStream.range(start, end)
-.parallel()
-.forEach(action));
+.execute(() -> IntStream.range(start, end).parallel().forEach(action));
}

public static void pchunk(int offset, int length, BiIntConsumer action) {
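pfor fans an index range out across the physical-core thread pool, while pchunk (whose signature appears above) hands each worker one contiguous (chunkStart, chunkLength) slice. A self-contained sketch of the same two primitives, using a plain parallel stream instead of jlama's PhysicalCoreExecutor:

import java.util.function.IntConsumer;
import java.util.stream.IntStream;

class ParallelSketch {
    // pfor: apply action to every index in [start, end) in parallel
    static void pfor(int start, int end, IntConsumer action) {
        IntStream.range(start, end).parallel().forEach(action);
    }

    interface ChunkAction {
        void accept(int chunkStart, int chunkLength);
    }

    // pchunk: split [offset, offset + length) into one contiguous chunk per worker
    static void pchunk(int offset, int length, int workers, ChunkAction action) {
        int chunk = Math.max(1, (length + workers - 1) / workers);
        pfor(0, (length + chunk - 1) / chunk, w -> {
            int start = offset + w * chunk;
            int size = Math.min(chunk, offset + length - start);
            action.accept(start, size);
        });
    }
}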
@@ -32,7 +32,6 @@
import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
@@ -255,8 +254,6 @@ public int sample(AbstractTensor output, float temperature, float uniformSample,
double maxv = Double.NEGATIVE_INFINITY;
for (int i = 0; i < c.vocabularySize; i++) {
float v = logits.get(0, i);
-//v = (float) (30.0f * Math.tanh(v / 30.0f));
-//logits.set(v, 0, i);
if (v > maxv) {
maxi = i;
maxv = v;
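The loop above tracks the running argmax of the logits, which is what greedy decoding (temperature 0) returns. A sketch of the fuller temperature-sampling scheme such a method typically implements, with illustrative names and plain arrays rather than jlama's tensors:

// Sample a token id from logits: greedy at temperature 0, otherwise
// softmax with temperature, inverted via a uniform draw in [0, 1).
static int sample(float[] logits, float temperature, float uniform) {
    int maxi = 0;
    for (int i = 1; i < logits.length; i++) if (logits[i] > logits[maxi]) maxi = i;
    if (temperature == 0f) return maxi;

    double sum = 0;
    double[] p = new double[logits.length];
    for (int i = 0; i < logits.length; i++) {
        // subtract the max for numerical stability before exponentiating
        p[i] = Math.exp((logits[i] - logits[maxi]) / temperature);
        sum += p[i];
    }
    double acc = 0;
    for (int i = 0; i < logits.length; i++) {
        acc += p[i] / sum;
        if (uniform < acc) return i;
    }
    return logits.length - 1;
}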
@@ -18,7 +18,6 @@
import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.tensor.AbstractTensor;
-import com.github.tjake.jlama.tensor.operations.TensorOperations;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.google.common.base.Preconditions;
import java.util.*;
@@ -122,9 +121,9 @@ public AbstractTensor forward(
int batchSize = input.shape().first();

try (AbstractTensor queryBatch = m.makeFullTensor(batchSize, c.embeddingLength);
-AbstractTensor tmpKeyBatch = m.makeFullTensor(batchSize, c.kvLength);
-AbstractTensor tmpValBatch = m.makeFullTensor(batchSize, c.kvLength);
-AbstractTensor valueBatch = m.makeFullTensor(batchSize, c.embeddingLength)) {
+AbstractTensor tmpKeyBatch = m.makeFullTensor(batchSize, c.kvLength);
+AbstractTensor tmpValBatch = m.makeFullTensor(batchSize, c.kvLength);
+AbstractTensor valueBatch = m.makeFullTensor(batchSize, c.embeddingLength)) {

if (c.isGQA) {
VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkLength) -> {
@@ -203,10 +202,10 @@ public AbstractTensor forward(
AbstractTensor value = valueBatch.slice(bi);

if (key.dType() != tmpKey.dType()) {
-try (AbstractTensor tmpKey2 = TensorOperationsProvider.get()
-.quantize(tmpKey, key.dType(), 0, c.kvLength);
-AbstractTensor tmpVal2 = TensorOperationsProvider.get()
-.quantize(tmpVal, val.dType(), 0, c.kvLength)) {
+try (AbstractTensor tmpKey2 =
+TensorOperationsProvider.get().quantize(tmpKey, key.dType(), 0, c.kvLength);
+AbstractTensor tmpVal2 =
+TensorOperationsProvider.get().quantize(tmpVal, val.dType(), 0, c.kvLength)) {
key.copyFrom(
tmpKey2,
tmpKey2.getOffset(0, c.kvSegmentStart()),
@@ -292,7 +291,7 @@ public AbstractTensor forward(
}
});

-//Attention
+// Attention
VectorMath.pfor(c.headStart(), c.headEnd(), h -> {
try (AbstractTensor attn = m.makeFullTensor(1, kvp.shape().first())) {
int xoffset = c.maybeMapToGroupHead(h) * c.headSize;
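The pfor above runs one task per attention head; with GQA, maybeMapToGroupHead points several query heads at the same key/value group. Per head the computation is standard scaled dot-product attention over the cached positions, roughly as follows (plain arrays and illustrative names, not jlama's tensor API):

// One head of scaled dot-product attention over T cached positions.
// q: [headSize], keys/values: [T][headSize], returns [headSize].
static float[] attendOneHead(float[] q, float[][] keys, float[][] values) {
    int T = keys.length, d = q.length;
    double[] scores = new double[T];
    double max = Double.NEGATIVE_INFINITY;
    for (int t = 0; t < T; t++) {
        double dot = 0;
        for (int i = 0; i < d; i++) dot += q[i] * keys[t][i];
        scores[t] = dot / Math.sqrt(d);  // scale by sqrt(head size)
        max = Math.max(max, scores[t]);
    }
    double sum = 0;
    for (int t = 0; t < T; t++) {
        scores[t] = Math.exp(scores[t] - max);
        sum += scores[t];
    }
    float[] out = new float[d];
    for (int t = 0; t < T; t++) {
        double w = scores[t] / sum;      // softmax weight for position t
        for (int i = 0; i < d; i++) out[i] += (float) (w * values[t][i]);
    }
    return out;
}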
@@ -51,7 +51,7 @@ public AbstractTensor forward(
}

if (reducer.isPresent()) {
-Pair<Float, Float> p = reducer.get().apply((float)ss, 0f);
+Pair<Float, Float> p = reducer.get().apply((float) ss, 0f);
ss = p.left;
}

@@ -60,7 +60,7 @@
ss = (1.0 / StrictMath.sqrt(ss));
// normalize and scale
for (int j = offset; j < limit; j++) {
-output.set((weightAdjustment + weights.get(0, j)) * ((float)ss * input.get(b, j)), b, j);
+output.set((weightAdjustment + weights.get(0, j)) * ((float) ss * input.get(b, j)), b, j);
}
}
return output;
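This forward pass is RMSNorm: ss accumulates the sum of squares (the reducer hook apparently merges partial sums across distributed shards), and every element is then scaled by the inverse root mean square. Restated compactly under assumed names:

// RMSNorm sketch: out[j] = w[j] * x[j] / sqrt(mean(x[k]^2) + eps)
static float[] rmsNorm(float[] x, float[] w, float eps) {
    double ss = 0;
    for (float v : x) ss += (double) v * v;
    double inv = 1.0 / Math.sqrt(ss / x.length + eps);
    float[] out = new float[x.length];
    for (int j = 0; j < x.length; j++) out[j] = (float) (w[j] * inv * x[j]);
    return out;
}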
@@ -30,7 +30,6 @@
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
-import com.github.tjake.jlama.tensor.operations.PanamaTensorOperations;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import java.util.Optional;
import java.util.stream.IntStream;
@@ -63,7 +62,8 @@ public GemmaModel(
super(inferenceType, config, weights, tokenizer, workingDType, workingQType, modelQType);
// https://github.com/huggingface/transformers/blob/1082361a1978d30db5c3932d1ee08914d74d9697/src/transformers/models/gemma/modeling_gemma.py#L898
// This is the scaling factor for the embedding layer, but Google's implementation rounds it to 16 bits (bfloat16)
-this.embeddingScalingFactor = FloatConversions.bFloat16ToFloat32(FloatConversions.float32ToBFloat16((float) Math.pow(c.embeddingLength, 0.5)));
+this.embeddingScalingFactor = FloatConversions.bFloat16ToFloat32(
+FloatConversions.float32ToBFloat16((float) Math.pow(c.embeddingLength, 0.5)));
}
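The round trip through bfloat16 deliberately reproduces the precision loss in Google's reference implementation. For instance, assuming an embedding length of 3072 (Gemma-7B's hidden size), the square root is 55.42563..., which has no exact bfloat16 representation; with round-to-nearest it snaps to 55.5 (plain truncation would give 55.25), and that coarser value is what the embeddings must be scaled by to match upstream outputs:

// Worked example, assuming embeddingLength == 3072 (Gemma-7B's hidden size)
float exact = (float) Math.pow(3072, 0.5);                 // 55.42563...
short asBf16 = FloatConversions.float32ToBFloat16(exact);  // keeps 8 exponent + 7 mantissa bits
float used = FloatConversions.bFloat16ToFloat32(asBf16);   // ~55.5 with round-to-nearest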

private AbstractTensor wte;
@@ -126,7 +126,8 @@ protected EmbedInput loadInputWeights() {
AbstractTensor embedding = makeTensor(c.embeddingLength);
AbstractTensor at = wte.slice(true, inputToken);
if (wte.dType() != embedding.dType())
-at = TensorOperationsProvider.get().quantize(at, embedding.dType(), c.embeddingSegmentStart(), c.embeddingSegmentLength());
+at = TensorOperationsProvider.get()
+.quantize(at, embedding.dType(), c.embeddingSegmentStart(), c.embeddingSegmentLength());

embedding.copyFrom(
at,
Expand All @@ -135,7 +136,8 @@ protected EmbedInput loadInputWeights() {
c.embeddingSegmentLength());

// This is important for Gemma, but not for Llama
-TensorOperationsProvider.get().scale(embeddingScalingFactor, embedding, c.embeddingSegmentStart(), c.embeddingSegmentLength());
+TensorOperationsProvider.get()
+.scale(embeddingScalingFactor, embedding, c.embeddingSegmentStart(), c.embeddingSegmentLength());

return embedding;
};
@@ -48,7 +48,7 @@ protected Optional<Character> maybeDecodeTokenAsCharacter(long id) {

@Override
protected String preProcess(String sentence) {
-sentence = sentence.replace(" ", SPIECE_UNDERLINE);
+sentence = sentence.replace(" ", SPIECE_UNDERLINE);

return sentence;
}
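This is the SentencePiece convention: spaces are rewritten to the U+2581 "lower one eighth block" marker before token pieces are matched. For example, assuming SPIECE_UNDERLINE is that character:

String SPIECE_UNDERLINE = "\u2581";                            // "▁"
String pieces = "Hello world".replace(" ", SPIECE_UNDERLINE);
// pieces == "Hello▁world", so the vocabulary can contain word-boundary tokens like "▁world"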
@@ -18,12 +18,11 @@
import com.hubspot.jinjava.Jinjava;
import com.hubspot.jinjava.JinjavaConfig;
import com.hubspot.jinjava.lib.fn.ELFunctionDefinition;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

/**
* This class also renders the prompt templates of the huggingface model format (using jinja templates)
@@ -39,7 +38,9 @@ public class PromptSupport {
.build());

static {
-jinjava.getGlobalContext().registerFunction(new ELFunctionDefinition("", "raise_exception", PromptSupport.class, "raiseException", String.class));
+jinjava.getGlobalContext()
+.registerFunction(new ELFunctionDefinition(
+"", "raise_exception", PromptSupport.class, "raiseException", String.class));
}

private final TokenizerModel m;
@@ -150,8 +151,7 @@ public String build() {
"eos_token",
m.eosToken(),
"bos_token",
-m.bosToken()
-));
+m.bosToken()));
}
}
}
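PromptSupport renders Hugging Face chat templates through Jinjava, binding the message list plus the tokenizer's special tokens. A minimal self-contained illustration; the template string here is invented for the example, not taken from any real model:

import com.hubspot.jinjava.Jinjava;
import java.util.List;
import java.util.Map;

Jinjava jinjava = new Jinjava();
String template = "{{ bos_token }}{% for m in messages %}<|{{ m.role }}|>{{ m.content }}{% endfor %}";
String prompt = jinjava.render(template, Map.of(
        "messages", List.of(Map.of("role", "user", "content", "Hello")),
        "bos_token", "<s>",
        "eos_token", "</s>"));
// prompt == "<s><|user|>Hello"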
@@ -288,7 +288,8 @@ public AbstractTensor quantize(DType dType) {

public AbstractTensor quantize(DType dType, boolean force) {

-if (!force && (this.shape().first() == 1 || this.dType == dType || this.dType.size() < dType.size())) return this;
+if (!force && (this.shape().first() == 1 || this.dType == dType || this.dType.size() < dType.size()))
+return this;

if (shape.isSparse()) {
logger.info("Quantizing sparse tensor is not supported");
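The guard above skips quantizing tiny or already-narrower tensors. For orientation, Q8 quantization in this family is typically block-wise: each fixed-size block stores one float scale plus signed bytes. A generic sketch; the block size of 32 is an assumption for illustration, not necessarily jlama's exact layout:

// Quantize n floats to int8 in blocks of 32, one scale per block.
static void quantizeQ8(float[] x, byte[] out, float[] scales) {
    final int B = 32;
    for (int b = 0; b * B < x.length; b++) {
        int end = Math.min(x.length, (b + 1) * B);
        float max = 0f;
        for (int i = b * B; i < end; i++) max = Math.max(max, Math.abs(x[i]));
        float scale = max / 127f;  // maps [-max, max] onto [-127, 127]
        scales[b] = scale;
        for (int i = b * B; i < end; i++)
            out[i] = (byte) (scale == 0f ? 0 : Math.round(x[i] / scale));
    }
}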
@@ -91,7 +91,7 @@ private Pair<RandomAccessFile, AbstractTensor> makeKvBuffer(UUID session) {
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();

-t = new FloatBufferTensor(fb, s, true);
+t = new FloatBufferTensor(fb, s, true);
} else if (model.getWorkingDType() == DType.BF16) {
ShortBuffer sb = raf.getChannel()
.map(FileChannel.MapMode.READ_WRITE, 0, bytes)
@@ -103,7 +103,6 @@ private Pair<RandomAccessFile, AbstractTensor> makeKvBuffer(UUID session) {
throw new UnsupportedOperationException("Only F32/BF16 is supported for now");
}


return Pair.create(raf, t);

} catch (IOException e) {
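makeKvBuffer memory-maps the cache file, so the KV state lives off-heap and can be paged by the OS. The core mapping pattern in isolation (path and size are illustrative):

import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;

static FloatBuffer mapKvFile(String path, long bytes) throws IOException {
    RandomAccessFile raf = new RandomAccessFile(path, "rw");
    return raf.getChannel()
            .map(FileChannel.MapMode.READ_WRITE, 0, bytes) // extends the file to `bytes`
            .order(ByteOrder.LITTLE_ENDIAN)
            .asFloatBuffer();
}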
@@ -113,8 +113,7 @@ public void scale(float factor, AbstractTensor x, int offset, int length) {
int limit = offset + length;

for (int b = 0; b < x.shape().first(); b++)
-for (int i = offset; i < limit; ++i)
-x.set(x.get(b, i) * factor, b, i);
+for (int i = offset; i < limit; ++i) x.set(x.get(b, i) * factor, b, i);
}

@Override
@@ -42,7 +42,6 @@ public final class PanamaTensorOperations implements TensorOperations {

static final IntVector BF16_BYTE_SHIFT = IntVector.broadcast(IntVector.SPECIES_PREFERRED, 16);


static final IntVector BF16_BYTE_SHIFT_512 = IntVector.broadcast(IntVector.SPECIES_512, 16);
static final FloatVector F32_ROUND_UP_512 = FloatVector.broadcast(FloatVector.SPECIES_512, 0.5f);

@@ -100,7 +99,7 @@ public void batchDotProduct(
int rowChunkSize) {
Preconditions.checkArgument(a.dims() == 2 && b.dims() == 2 && result.dims() == 2);
Preconditions.checkArgument(a.shape().dim(0) == result.shape().dim(0), "BAD M");
-//Preconditions.checkArgument(b.shape().dim(0) == result.shape().dim(1), "BAD N");
+// Preconditions.checkArgument(b.shape().dim(0) == result.shape().dim(1), "BAD N");
// Preconditions.checkArgument(a.shape().dim(1) == b.shape().dim(1), "BAD K");

int M = a.shape().dim(0);
@@ -132,9 +131,11 @@
};
case BF16 -> switch (b.dType()) {
case BF16 -> new GemmerBF16(K, a, b, result, aColumnOffset, bColumnOffset);
-default -> throw new UnsupportedOperationException(b.dType().name());
+default -> throw new UnsupportedOperationException(
+b.dType().name());
};
-default -> throw new UnsupportedOperationException(a.dType().name() + " " + b.dType().name());
+default -> throw new UnsupportedOperationException(
+a.dType().name() + " " + b.dType().name());
};

gemm.matmul(0, M, bRowOffset, bRowOffset + N);
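The switch picks a SIMD GEMM kernel per dtype pairing, but every kernel computes the same thing. In scalar terms (ignoring column offsets and quantized layouts), roughly:

// result[m][n] = sum_k a[m][k] * b[n][k]   (b's rows index the N dimension)
static void batchDotProductScalar(float[][] a, float[][] b, float[][] result) {
    for (int m = 0; m < a.length; m++)
        for (int n = 0; n < b.length; n++) {
            float acc = 0f;
            for (int k = 0; k < a[m].length; k++) acc += a[m][k] * b[n][k];
            result[m][n] = acc;
        }
}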
@@ -1466,10 +1467,10 @@ protected int pickKernel(int m0, int m, int n0, int n) {
nc = 4;
kernel(m0, m, 1, n0, n, 4, matmul1x4);
} else {*/
-mc = 1;
-nc = 1;
-kernel(m0, m, 1, n0, n, 1, matmul1x1);
-//}
+mc = 1;
+nc = 1;
+kernel(m0, m, 1, n0, n, 1, matmul1x1);
+// }

return (mc << 4) | nc;
}
@@ -1492,7 +1493,6 @@ protected BiIntConsumer initMatmul1x1() {
.lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
.reinterpretAsFloats();


ShortVector sb = b.getVector(ShortVector.SPECIES_PREFERRED, j, boffset);
FloatVector vb0 = sb.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0)
.lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
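The convert-then-shift idiom above is vectorized bfloat16 widening; per lane it is equivalent to this scalar helper:

// A bfloat16 value is the top 16 bits of an IEEE-754 float32 pattern,
// so widening is just a 16-bit left shift of the raw bits.
static float bf16ToF32(short bits) {
    return Float.intBitsToFloat((bits & 0xffff) << 16);
}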
@@ -1717,7 +1717,8 @@ protected BiIntConsumer initMatmul1x1() {
int slen = ShortVector.SPECIES_PREFERRED.length();
for (; aoffset < alim && boffset < blim; aoffset += slen, boffset += slen) {
FloatVector va0 = a.getVector(FloatVector.SPECIES_PREFERRED, i, aoffset);
-FloatVector va1 = a.getVector(FloatVector.SPECIES_PREFERRED, i, aoffset + FloatVector.SPECIES_PREFERRED.length());
+FloatVector va1 = a.getVector(
+FloatVector.SPECIES_PREFERRED, i, aoffset + FloatVector.SPECIES_PREFERRED.length());

ShortVector sb = b.getVector(ShortVector.SPECIES_PREFERRED, j, boffset);
FloatVector vb0 = sb.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0)
@@ -1835,7 +1836,8 @@ public BFloat16BufferTensor quantizeBF16(FloatBufferTensor ft, final int offset,
.lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT)
.convertShape(VectorOperators.I2S, ShortVector.SPECIES_PREFERRED, -1);

-VectorMask<Short> mask = VectorMask.fromLong(ShortVector.SPECIES_PREFERRED, (1L << FloatVector.SPECIES_PREFERRED.length()) - 1);
+VectorMask<Short> mask = VectorMask.fromLong(
+ShortVector.SPECIES_PREFERRED, (1L << FloatVector.SPECIES_PREFERRED.length()) - 1);
mask = mask.not(); // Invert the mask to select the second half

var r = r0.blend(r1, mask);
@@ -2276,7 +2278,6 @@ public Q8ByteBufferTensor quantizeBF16_Q8_arm(BFloat16BufferTensor ft, int offse
return qft;
}


@Override
public void maccumulate(AbstractTensor aBatch, AbstractTensor bBatch, int offset, int limit) {
Preconditions.checkArgument(aBatch.dType() == bBatch.dType());
@@ -2616,12 +2617,19 @@ void saxpyF32(float alpha, FloatBufferTensor x, FloatBufferTensor y, int xoffset
}

@Override
-public void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int batchSize) {
+public void saxpy(
+AbstractTensor alpha,
+AbstractTensor x,
+AbstractTensor y,
+int xoffset,
+int yoffset,
+int limit,
+int batchSize) {
Preconditions.checkArgument(limit % 2 == 0);

switch (x.dType()) {
case F32:
-saxpyF32(alpha, (FloatBufferTensor) x, (FloatBufferTensor)y, xoffset, yoffset, limit, batchSize);
+saxpyF32(alpha, (FloatBufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit, batchSize);
break;
case BF16:
switch (y.dType()) {
@@ -2649,7 +2657,6 @@ public void saxpyF32(
int limit,
int batchSize) {


int upperBound = FloatVector.SPECIES_PREFERRED.loopBound(limit);

// Use Nearest multiple of 4
@@ -2724,8 +2731,7 @@ public void saxpyBF16F32(
}
}

-void saxpyBF16(
-float alpha, BFloat16BufferTensor a, BFloat16BufferTensor b, int aoffset, int boffset, int limit) {
+void saxpyBF16(float alpha, BFloat16BufferTensor a, BFloat16BufferTensor b, int aoffset, int boffset, int limit) {
int upperBound = ShortVector.SPECIES_PREFERRED.loopBound(limit);
Preconditions.checkArgument(upperBound == limit);

@@ -2774,9 +2780,7 @@ void saxpyBF16(
}
}


-void saxpyBF16F32(
-float alpha, BFloat16BufferTensor a, FloatBufferTensor b, int aoffset, int boffset, int limit) {
+void saxpyBF16F32(float alpha, BFloat16BufferTensor a, FloatBufferTensor b, int aoffset, int boffset, int limit) {
int upperBound = ShortVector.SPECIES_PREFERRED.loopBound(limit);
Preconditions.checkArgument(upperBound == limit);

@@ -122,7 +122,8 @@ default void saxpy(
int yoffset,
int limit,
int batchSize) {
-Preconditions.checkArgument(alpha.shape().last() == x.shape().first() && y.shape().first() == 1);
+Preconditions.checkArgument(
+alpha.shape().last() == x.shape().first() && y.shape().first() == 1);

for (int i = 0; i < batchSize; i++) {
saxpy(alpha.get(0, i), x.slice(i), y, xoffset, yoffset, limit);
@@ -25,7 +25,7 @@
*/
public class PhysicalCoreExecutor {
private static volatile int physicalCoreCount =
-Math.max(1, Runtime.getRuntime().availableProcessors()/2);
+Math.max(1, Runtime.getRuntime().availableProcessors() / 2);
private static final AtomicBoolean started = new AtomicBoolean(false);

/**