Support Mistral Nemo, Llama 3.1, better bf16 support
tjake committed Jul 29, 2024
1 parent 83c61a9 commit e7dc2a5
Showing 22 changed files with 261 additions and 60 deletions.
@@ -33,8 +33,6 @@ public static float bFloat16ToFloat32(short raw) {
}

public static short float32ToBFloat16(float n) {
- // if (true)
- // return (short) ((Float.floatToRawIntBits(n) >> 16) & 0xffff);
int nbits = Float.floatToRawIntBits(n);
// 32 bits has 1 sign bit, 8 exponent bits, 23 mantissa bits
int s = (nbits >>> 16) & 0x8000;
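The hunk above drops a commented-out shortcut that narrowed float32 to bfloat16 by plain truncation, leaving the rounding implementation as the only path. For reference, a self-contained sketch (not the jlama implementation) of the two usual strategies; bfloat16 is simply the upper 16 bits of an IEEE-754 float32 (1 sign, 8 exponent, 7 mantissa bits):

```java
// Stand-alone sketch of float32 <-> bfloat16 conversion, for reference only.
public final class BFloat16Sketch {

    // Truncation: keep the top 16 bits. Cheap, but always rounds the mantissa
    // toward zero. This is what the deleted commented-out shortcut did.
    static short float32ToBF16Truncate(float f) {
        return (short) (Float.floatToRawIntBits(f) >>> 16);
    }

    // Round-to-nearest-even: add 0x7FFF plus the LSB of the kept half before
    // shifting (NaN edge cases omitted in this sketch).
    static short float32ToBF16RoundNearest(float f) {
        int bits = Float.floatToRawIntBits(f);
        int lsb = (bits >>> 16) & 1;
        return (short) ((bits + 0x7FFF + lsb) >>> 16);
    }

    // Widening is exact: the bfloat16 bits become the high half of a float32.
    static float bf16ToFloat32(short raw) {
        return Float.intBitsToFloat((raw & 0xFFFF) << 16);
    }

    public static void main(String[] args) {
        float x = 3.14159f;
        System.out.printf("truncate: %f%n", bf16ToFloat32(float32ToBF16Truncate(x)));
        System.out.printf("nearest : %f%n", bf16ToFloat32(float32ToBF16RoundNearest(x)));
    }
}
```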
@@ -32,7 +32,9 @@ public class VectorMath {
public static void pfor(int start, int end, IntConsumer action) {
PhysicalCoreExecutor.instance
.get()
- .execute(() -> IntStream.range(start, end).parallel().forEach(action));
+ .execute(() -> IntStream.range(start, end)
+ .parallel()
+ .forEach(action));
}

public static void pchunk(int offset, int length, BiIntConsumer action) {
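pfor and pchunk are the two parallel helpers the attention code further down relies on: pfor runs the action once per index on the physical-core executor, while pchunk hands out contiguous (offset, length) chunks. A hedged usage sketch follows; the loop bodies are placeholders, not jlama code:

```java
// Usage sketch only; the bodies are placeholders.
// One task per attention head:
VectorMath.pfor(c.headStart(), c.headEnd(), h -> {
    // ... score and combine values for head h ...
});

// One task per contiguous slice of output columns:
VectorMath.pchunk(0, attentionLength, (chunkStart, chunkLength) -> {
    // ... compute columns [chunkStart, chunkStart + chunkLength) ...
});
```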
@@ -29,6 +29,7 @@
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.DebugSupport;
import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
@@ -44,11 +45,11 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static com.github.tjake.jlama.util.DebugSupport.debug;

public abstract class AbstractModel implements Generator {
private static final Logger logger = LoggerFactory.getLogger(AbstractModel.class);

- static final boolean DEBUG = false;

public enum InferenceType {
INPUT_TO_EMBEDDING(true, false, false),
OUTPUT_TO_TOKEN(false, true, false),
@@ -113,6 +114,11 @@ protected AbstractModel(
workingMemoryQType = DType.F32;
}

// FIXME: This is a hack to avoid mixed Q8/BF16 evals by keeping the working memory in BF16
if (modelDType == DType.BF16 && workingMemoryQType != DType.BF16 && modelQType.isEmpty()) {
workingMemoryQType = DType.BF16;
}

if (workingMemoryQType != workingMemoryDType) {
boolean supportsQType;
AbstractTensor tmp = makeTensor(Q8ByteBufferTensor.BLOCK_SIZE);
@@ -203,6 +209,9 @@ public AbstractTensor forward(
Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
AbstractTensor embedding = embedInput.inputTokenToEmbedding(token_id, pos);

debug("EMBEDDING TOKEN", token_id);
debug("TOKEN POSITION", pos);

for (int i = c.layerStart(); i < c.layerEnd(); i++) {
AbstractTensor kvlayer = kvbuf.slice(true, i);
AbstractTensor ref = embedding; // reference so we can free
@@ -217,7 +226,6 @@ protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, AbstractTensor kvbuf) {
AbstractTensor last = null;
for (int i = 0; i < token_ids.length; i++) {
if (last != null) last.close();

last = forward(token_ids[i], startPos + i, kvbuf);
}

@@ -329,7 +337,7 @@ public Response generate(
long start = System.currentTimeMillis();
long promptStart = start;
// Batch Process Prompt
- AbstractTensor last = DEBUG
+ AbstractTensor last = DebugSupport.isDebug()
? batchForwardSlow(promptTokens, startPos, kvmem)
: batchForward(promptTokens, startPos, kvmem);

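The hard-coded DEBUG flag is replaced by a DebugSupport utility whose source is not part of this excerpt. A minimal sketch of what such a helper could look like, assuming it gates on a single flag (the jlama.debug system property below is an assumption) and matching the call shapes used above, debug(String, Object) and debug(String, AbstractTensor, int):

```java
// Hypothetical sketch of a DebugSupport-style helper; the real class is not shown
// in this diff. The idea is that the debug(...) calls sprinkled through the forward
// pass cost nothing when the flag is off.
package com.github.tjake.jlama.util;

import com.github.tjake.jlama.tensor.AbstractTensor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class DebugSupport {
    private static final Logger logger = LoggerFactory.getLogger(DebugSupport.class);
    private static final boolean DEBUG = Boolean.getBoolean("jlama.debug"); // assumed toggle

    public static boolean isDebug() {
        return DEBUG;
    }

    // Scalar values (token ids, positions, ...)
    public static void debug(String name, Object value) {
        if (DEBUG) logger.debug("{} = {}", name, value);
    }

    // Tensor snapshots, tagged with a layer index or position
    public static void debug(String name, AbstractTensor t, int index) {
        if (DEBUG) logger.debug("[{}] {} = {}", index, name, t);
    }
}
```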
@@ -17,12 +17,18 @@

import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.BFloat16BufferTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.operations.TensorOperations;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.google.common.base.Preconditions;
import java.util.*;
import java.util.function.Consumer;

import static com.github.tjake.jlama.util.DebugSupport.debug;

public class CausalSelfAttention {
private final AbstractModel m;
private final Config c;
@@ -40,6 +46,8 @@ public class CausalSelfAttention {
private final AbstractTensor outputProjectionWeights;

private final float attentionScale;
private final int attentionLength;
private final boolean attentionQVSizeMismatch;

private final AbstractTensor[] qkvResults;
private final AbstractTensor[] qkvWeights;
@@ -105,7 +113,9 @@ public CausalSelfAttention(

this.outputProjectionBias = outputProjectionBias;
this.outputProjectionWeights = outputProjectionWeights;
this.attentionLength = c.numberOfHeads * c.headSize;

this.attentionQVSizeMismatch = c.embeddingLength != attentionLength;
this.attentionScale = (float) (1.0 / StrictMath.sqrt(c.headSize));

this.qkvResults = new AbstractTensor[3];
@@ -120,13 +130,13 @@ public AbstractTensor forward(
Preconditions.checkArgument(input.dims() == 2 && input.shape().last() == c.embeddingLength);
int batchSize = input.shape().first();

- try (AbstractTensor queryBatch = m.makeFullTensor(batchSize, c.embeddingLength);
+ try (AbstractTensor queryBatch = m.makeFullTensor(batchSize, attentionLength);
AbstractTensor tmpKeyBatch = m.makeFullTensor(batchSize, c.kvLength);
AbstractTensor tmpValBatch = m.makeFullTensor(batchSize, c.kvLength);
- AbstractTensor valueBatch = m.makeFullTensor(batchSize, c.embeddingLength)) {
+ AbstractTensor valueBatch = m.makeFullTensor(batchSize, attentionLength)) {

if (c.isGQA) {
- VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkLength) -> {
+ VectorMath.pchunk(0, attentionLength, (chunkStart, chunkLength) -> {
TensorOperationsProvider.get()
.dotProductChunk(
queryBatch,
@@ -186,6 +196,10 @@ public AbstractTensor forward(
valueAttnBias.ifPresent(bias -> TensorOperationsProvider.get()
.accumulate(tmpValBatch, bias, c.kvSegmentStart(), c.kvSegmentLength()));

debug("query", queryBatch, 0);
debug("key", tmpKeyBatch, 0);
debug("value", tmpValBatch, 0);

// This is our memory of the key and value vectors for each position
for (int position = startPosition, bi = 0; position < startPosition + batchSize; position++, bi++) {
int finalPostion = position;
Expand Down Expand Up @@ -222,12 +236,12 @@ public AbstractTensor forward(
tmpKey,
tmpKey.getOffset(0, c.kvSegmentStart()),
key.getOffset(0, c.kvSegmentStart()),
- c.kvSegmentLength());
+ c.kvLength);
val.copyFrom(
tmpVal,
tmpVal.getOffset(0, c.kvSegmentStart()),
val.getOffset(0, c.kvSegmentStart()),
- c.kvSegmentLength());
+ c.kvLength);
}

// apply RoPE if present (accounting for huggingface permutation)
@@ -241,6 +255,11 @@
for (int h = c.headStart(); h < c.headEnd(); h++) {
// get the q vectors for this head
int offset = h * c.headSize;

// skip if we are out of bounds
if (offset >= query.shape().last())
break;

int goffset = c.maybeMapToGroupHead(h) * c.headSize;
// rotate q by the freq theta and freq r
for (int i = offset, g = goffset; i < (offset + headPiece); i++, g++) {
@@ -257,6 +276,8 @@
for (int h = c.groupHeadStart(); h < c.groupHeadEnd(); h++) {
// get the k vectors for this head
int offset = h * c.headSize;
if (offset >= key.shape().last())
break;
// rotate k by the freq theta and freq r
for (int i = offset; i < (offset + headPiece); i++) {
float k00 = key.get(0, i);
@@ -289,14 +310,20 @@ public AbstractTensor forward(
}
}
}
debug("query+rope", query, finalPostion);
debug("key+rope", key, finalPostion);
});


// Attention
VectorMath.pfor(c.headStart(), c.headEnd(), h -> {
try (AbstractTensor attn = m.makeFullTensor(1, kvp.shape().first())) {
int xoffset = c.maybeMapToGroupHead(h) * c.headSize;
int yoffset = h * c.headSize;

if (yoffset >= query.shape().last())
return;

// compute attention scores by multiplying query and key for every position
TensorOperationsProvider.get()
.batchDotProduct(attn, query, kvp, yoffset, xoffset, c.headSize, 0, finalPostion + 1);
@@ -312,6 +339,8 @@
});
}

debug("after_attention", valueBatch, 0);

// matmul the projection and sum into input
// input += c_proj_weight @ ybuf + c_proj_bias
AbstractTensor result = m.makeFullTensor(batchSize, c.embeddingLength);
@@ -323,7 +352,7 @@ public AbstractTensor forward(
vq,
outputProjectionWeights,
c.embeddingSegmentStart(),
- c.embeddingSegmentLength(),
+ attentionQVSizeMismatch ? attentionLength : c.embeddingSegmentLength(),
chunkStart,
chunkSize);
});
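Most of the changes in this file exist because newer checkpoints such as Mistral Nemo decouple the attention width from the hidden size: numberOfHeads * headSize no longer equals embeddingLength. Query and value buffers are therefore sized by attentionLength, the RoPE loops skip heads that fall outside the narrower tensors, and the output projection maps attentionLength back to embeddingLength. A toy illustration with assumed Nemo-like numbers (not read from any config):

```java
// Toy illustration of the size bookkeeping; the concrete values are assumptions.
public final class AttentionSizes {
    public static void main(String[] args) {
        int numberOfHeads = 32;
        int headSize = 128;          // per-head dimension
        int embeddingLength = 5120;  // model hidden size (residual stream width)

        int attentionLength = numberOfHeads * headSize;                       // 4096
        boolean attentionQVSizeMismatch = embeddingLength != attentionLength; // true

        // Q and V activations live in attentionLength-wide buffers, the residual
        // stream stays embeddingLength wide, and the output projection bridges the two.
        System.out.printf("attentionLength=%d, embeddingLength=%d, mismatch=%b%n",
                attentionLength, embeddingLength, attentionQVSizeMismatch);
    }
}
```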
@@ -18,6 +18,7 @@
import com.github.tjake.jlama.model.functions.FeedForward;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.DebugSupport;
import com.github.tjake.jlama.util.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -27,11 +28,14 @@
import java.util.function.BiFunction;
import java.util.function.Consumer;

import static com.github.tjake.jlama.util.DebugSupport.debug;

public class TransformerBlock {

private static final Logger logger = LoggerFactory.getLogger(TransformerBlock.class);

private final AbstractModel model;
final int layerIndex;
final Optional<LayerNorm> preAttentionNorm;
final CausalSelfAttention attention;
final LayerNorm postAttentionNorm;
@@ -40,11 +44,13 @@ public class TransformerBlock {

public TransformerBlock(
AbstractModel model,
int layerIndex,
LayerNorm preAttentionNorm,
CausalSelfAttention attention,
LayerNorm postAttentionNorm,
FeedForward ffBlock) {
this.model = model;
this.layerIndex = layerIndex;
this.preAttentionNorm = Optional.of(preAttentionNorm);
this.attention = attention;

@@ -56,11 +62,13 @@

public TransformerBlock(
AbstractModel model,
int layerIndex,
CausalSelfAttention attention,
LayerNorm postAttentionNorm,
FeedForward ffBlock,
LayerNorm postFFNorm) {
this.model = model;
this.layerIndex = layerIndex;
this.preAttentionNorm = Optional.empty();
this.attention = attention;

@@ -81,38 +89,44 @@ public AbstractTensor forward(
Optional<BiFunction<Float, Float, Pair<Float, Float>>> normReducer,
Optional<Consumer<List<AbstractTensor>>> tensorReducer) {

- if (AbstractModel.DEBUG)
- logger.debug("embedding: {}" + embedding);
+ debug("input_emb", embedding, layerIndex);

AbstractTensor lnemb =
preAttentionNorm.map(ln -> ln.forward(embedding, normReducer)).orElse(embedding);

- if (AbstractModel.DEBUG)
- logger.debug("lnemb: {}" + lnemb);
+ debug("ln_emb", lnemb, layerIndex);

AbstractTensor postAttention;
try (AbstractTensor qlnemb = model.maybeQuantize(lnemb)) {
postAttention = attention.forward(qlnemb, position, kvBuffer, tensorReducer);
}

- if (AbstractModel.DEBUG)
- logger.debug("postAttention: {}" + postAttention);
+ debug("post_attn", postAttention, layerIndex);

// residual connection
TensorOperationsProvider.get()
.accumulate(
postAttention, embedding, model.c.embeddingSegmentStart(), model.c.embeddingSegmentLength());

debug("post_attn_res", postAttention, layerIndex);

AbstractTensor lnemb2 = postAttentionNorm.forward(postAttention, normReducer);

debug("ln_emb2", lnemb2, layerIndex);

AbstractTensor postFF;
try (AbstractTensor qlnemb2 = model.maybeQuantize(lnemb2)) {
postFF = ffBlock.forward(qlnemb2, tensorReducer);
debug("post_ff", postFF, layerIndex);
}

// residual connection
TensorOperationsProvider.get()
.accumulate(postFF, postAttention, model.c.embeddingSegmentStart(), model.c.embeddingSegmentLength());

debug("post_ff_res", postFF, layerIndex);

// Release any tmp buffers
if (lnemb != embedding) lnemb.close();

@@ -122,6 +136,7 @@ public AbstractTensor forward(
return postFFNorm
.map(ln -> {
AbstractTensor lnout = ln.forward(postFF, normReducer);
debug("ln_out", lnout, layerIndex);
postFF.close();
return lnout;
})
@@ -114,7 +114,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
LayerNorm postMlpNorm = new LayerNorm(
this, weights.load(b + "output.LayerNorm.bias"), weights.load(b + "output.LayerNorm.weight"));

- transformerBlocks[i] = new TransformerBlock(this, attention, postAttentionNorm, mlpBlock, postMlpNorm);
+ transformerBlocks[i] = new TransformerBlock(this, i, attention, postAttentionNorm, mlpBlock, postMlpNorm);
}

return transformerBlocks;
@@ -97,7 +97,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
weights.load(prefix + "up_proj.weight", c.offset()).quantize(qType)); // w3

transformerBlocks[i] = new TransformerBlock(
- this,
+ this, i,
new RMSNorm(
this,
weights.load(base + "input_layernorm.weight", c.offset())
@@ -97,7 +97,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
LayerNorm layerNorm1 = new LayerNorm(this, weights.load(b + "ln_1.bias"), weights.load(b + "ln_1.weight"));
LayerNorm layerNorm2 = new LayerNorm(this, weights.load(b + "ln_2.bias"), weights.load(b + "ln_2.weight"));

- transformerBlocks[i] = new TransformerBlock(this, layerNorm1, attention, layerNorm2, mlpBlock);
+ transformerBlocks[i] = new TransformerBlock(this, i, layerNorm1, attention, layerNorm2, mlpBlock);
}

return transformerBlocks;
@@ -19,6 +19,8 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.safetensors.Config;

import java.util.List;
import java.util.Map;

public class LlamaConfig extends Config {
@@ -34,7 +36,7 @@ public LlamaConfig(
@JsonProperty("rms_norm_eps") float layerNormEps,
@JsonProperty("vocab_size") int vocabularySize,
@JsonProperty("bos_token_id") int bosToken,
@JsonProperty("eos_token_id") int eosToken,
@JsonProperty("eos_token_id") Object eosToken,
@JsonProperty("hidden_act") ActivationFunction.Type activationFunction,
@JsonProperty("rope_theta") Double ropeFreqsTheta,
@JsonProperty("rope_scaling") Map<String, String> ropeScaling) {
@@ -48,9 +50,9 @@ public LlamaConfig(
layerNormEps,
vocabularySize,
bosToken,
- eosToken,
+ eosToken instanceof List ? ((List<Integer>)eosToken).get(((List<Integer>)eosToken).size() - 1) : (Integer) eosToken, //for llama3.1
activationFunction,
ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta,
- ropeScaling == null ? 1.0 : Double.parseDouble(ropeScaling.get("factor")));
+ ropeScaling == null || !("linear".equals(ropeScaling.get("rope_type"))) ? 1.0 : Double.parseDouble(ropeScaling.get("factor")));
}
}
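Two Llama 3.1 config quirks are handled here: eos_token_id may now arrive as a JSON list of end-of-sequence ids rather than a single integer (the code keeps the last entry), and rope_scaling is only interpreted as a linear factor when its rope_type actually says "linear" (Llama 3.1 publishes a "llama3" rope type instead). A stand-alone sketch of the same normalization, with hypothetical helper names:

```java
// Hypothetical illustration of the two config quirks handled above; the helper
// names are assumptions, only the logic mirrors the diff.
import java.util.List;
import java.util.Map;

final class LlamaConfigQuirks {
    // Accept either a single integer or a list of integers; keep the last entry.
    static int normalizeEosToken(Object eosToken) {
        if (eosToken instanceof List) {
            List<?> ids = (List<?>) eosToken;
            return (Integer) ids.get(ids.size() - 1);
        }
        return (Integer) eosToken;
    }

    // Only apply the scaling factor when the scaling type is plain "linear".
    static double ropeScalingFactor(Map<String, String> ropeScaling) {
        if (ropeScaling == null || !"linear".equals(ropeScaling.get("rope_type"))) {
            return 1.0;
        }
        return Double.parseDouble(ropeScaling.get("factor"));
    }
}
```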
@@ -110,7 +110,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
weights.load(prefix + "up_proj.weight", c.offset()).quantize(qType)); // w3

transformerBlocks[i] = new TransformerBlock(
- this,
+ this, i,
new RMSNorm(
this,
weights.load(base + "input_layernorm.weight", c.offset())