Commit

GQA
tjake committed Feb 4, 2024
1 parent 726a73d commit 345b569
Showing 15 changed files with 180 additions and 76 deletions.
3 changes: 3 additions & 0 deletions Dockerfile
@@ -14,6 +14,9 @@ COPY mvnw .
RUN --mount=type=cache,target=/root/.m2 ./mvnw clean package

FROM openjdk:21-slim
RUN apt-get update
RUN apt-get install -y procps curl

COPY inlinerules.json inlinerules.json
COPY run-cli.sh run-cli.sh
COPY conf/logback.xml logback.xml
12 changes: 9 additions & 3 deletions docker-compose.yaml
@@ -15,19 +15,25 @@ services:
cpus: 1
command:
- cluster-coordinator
- --threads=4
- --worker-count=16
- --threads=2
- --worker-count=8
- /models/Llama-2-7b-chat-hf-jlama-Q4/
volumes:
- "./models:/models"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/ui/index.html"]
interval: 5s
jlama-worker:
image: jlama
restart: always
depends_on:
jlama-coordinator:
condition: service_healthy
environment:
- JLAMA_JVM_ARGS_EXTRA=-Xmx500M -Djava.net.preferIPv4Stack=true -Djlama.use_hostname_as_workerid=true
deploy:
mode: replicated
replicas: 16
replicas: 8
resources:
limits:
cpus: 1
@@ -252,7 +252,7 @@ public void generate(UUID sessionId, String prompt, String cleanPrompt, float te
if (ntokens > c.contextLength)
ntokens = c.contextLength;

AbstractTensor kvmem = makeTensor(c.getNumberOfLayers(), ntokens, 2, c.embeddingLength); //k and v are last 2 dims
AbstractTensor kvmem = makeTensor(c.getNumberOfLayers(), ntokens, 2, c.kvLength); //k and v are last 2 dims
AbstractTensor logits = makeTensor(c.vocabularySize);

int[] promptTokens = new int[useEOS ? (1 + encoded.length + 1) : (1 + encoded.length)];
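
The per-session KV cache is now sized by c.kvLength instead of c.embeddingLength, so a grouped-query model only caches one key/value pair per KV head rather than one per query head. A minimal sketch of the assumed arithmetic follows; field names such as numberOfKeyValueHeads and headSize are taken from the call sites in this commit and the concrete values below are illustrative, not read from any model.

public class KvCacheSizing {
    public static void main(String[] args) {
        // Illustrative values in the style of a Llama-2-70B config.
        int numberOfLayers = 80;
        int contextLength = 4096;
        int headSize = 128;
        int numberOfHeads = 64;          // query heads
        int numberOfKeyValueHeads = 8;   // KV heads; equals numberOfHeads when GQA is not used

        int embeddingLength = numberOfHeads * headSize;    // width of the query projection (8192)
        int kvLength = numberOfKeyValueHeads * headSize;   // width of the key/value projections (1024)

        // Shape allocated by makeTensor(...) above: [layers, tokens, 2 (k and v), kvLength]
        long kvFloats = (long) numberOfLayers * contextLength * 2 * kvLength;
        System.out.printf("embeddingLength=%d kvLength=%d kvCacheFloats=%d (%.1fx smaller than before)%n",
                embeddingLength, kvLength, kvFloats,
                (double) numberOfHeads / numberOfKeyValueHeads);
    }
}
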
@@ -2,21 +2,13 @@

import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.operations.TensorOperations;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;

import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;

public class CausalSelfAttention {
private final AbstractModel m;
@@ -34,7 +26,6 @@ public class CausalSelfAttention {

private final AbstractTensor outputProjectionWeights;

private final Optional<float[][]> ropeFrequencies;

private final float attentionScale;

@@ -45,21 +36,21 @@ public class CausalSelfAttention {


public CausalSelfAttention(AbstractModel m, AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights,
AbstractTensor outputProjectionWeights, Optional<float[][]> ropeFrequencies)
AbstractTensor outputProjectionWeights)
{
this(m, Optional.empty(), Optional.empty(), Optional.empty(), queryAttnWeights, keyAttnWeights, valueAttnWeights, Optional.empty(), outputProjectionWeights, ropeFrequencies);
this(m, Optional.empty(), Optional.empty(), Optional.empty(), queryAttnWeights, keyAttnWeights, valueAttnWeights, Optional.empty(), outputProjectionWeights);
}

public CausalSelfAttention(AbstractModel m, AbstractTensor queryAttnBias, AbstractTensor keyAttnBias, AbstractTensor valueAttnBias,
AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights,
AbstractTensor outputProjectionBias, AbstractTensor outputProjectionWeights, Optional<float[][]> ropeFrequencies) {
this(m, Optional.of(queryAttnBias), Optional.of(keyAttnBias), Optional.of(valueAttnBias), queryAttnWeights, keyAttnWeights, valueAttnWeights, Optional.of(outputProjectionBias), outputProjectionWeights, ropeFrequencies);
AbstractTensor outputProjectionBias, AbstractTensor outputProjectionWeights) {
this(m, Optional.of(queryAttnBias), Optional.of(keyAttnBias), Optional.of(valueAttnBias), queryAttnWeights, keyAttnWeights, valueAttnWeights, Optional.of(outputProjectionBias), outputProjectionWeights);
}


public CausalSelfAttention(AbstractModel m, Optional<AbstractTensor> queryAttnBias, Optional<AbstractTensor> keyAttnBias, Optional<AbstractTensor> valueAttnBias,
AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights,
Optional<AbstractTensor> outputProjectionBias, AbstractTensor outputProjectionWeights, Optional<float[][]> ropeFrequencies)
Optional<AbstractTensor> outputProjectionBias, AbstractTensor outputProjectionWeights)
{
this.m = m;
this.c = m.c;
@@ -75,7 +66,6 @@ public CausalSelfAttention(AbstractModel m, Optional<AbstractTensor> queryAttnBi

this.attentionScale = (float) (1.0 / StrictMath.sqrt(c.headSize));

this.ropeFrequencies = ropeFrequencies;
this.flashAttnHeads = new float[c.contextLength][c.numberOfHeads];

this.qkvResults = new AbstractTensor[3];
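
CausalSelfAttention no longer takes the RoPE frequency table as a constructor argument; the forward pass below reads it from the model Config as c.ropeFreqs, alongside several GQA-related helpers (isGQA, kvLength, kvSegmentStart/kvSegmentLength, maybeMapToGroupHead, groupHeadStart/groupHeadEnd). A rough sketch of what those helpers are assumed to look like, inferred only from how this diff calls them; kvSegmentStart/kvSegmentLength are presumably the key/value counterparts of embeddingSegmentStart/embeddingSegmentLength used for distributed sharding and are not modeled here.

import java.util.Optional;

// Hypothetical shape of the Config additions used below; names come from the
// call sites in this commit, the actual implementation may differ.
public class ConfigGqaSketch {
    public final int numberOfHeads;             // query heads
    public final int numberOfKeyValueHeads;     // KV heads (== numberOfHeads without GQA)
    public final int headSize;
    public final int kvLength;                  // numberOfKeyValueHeads * headSize
    public final boolean isGQA;
    public final Optional<float[][]> ropeFreqs; // per-position (cos, sin) pairs, if the model uses RoPE

    public ConfigGqaSketch(int numberOfHeads, int numberOfKeyValueHeads, int headSize,
                           Optional<float[][]> ropeFreqs) {
        this.numberOfHeads = numberOfHeads;
        this.numberOfKeyValueHeads = numberOfKeyValueHeads;
        this.headSize = headSize;
        this.kvLength = numberOfKeyValueHeads * headSize;
        this.isGQA = numberOfKeyValueHeads < numberOfHeads;
        this.ropeFreqs = ropeFreqs;
    }

    // Maps a query head index to the KV head whose cached key/value it shares.
    public int maybeMapToGroupHead(int head) {
        if (!isGQA) return head;
        int headsPerGroup = numberOfHeads / numberOfKeyValueHeads;
        return head / headsPerGroup;
    }
}
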
@@ -88,8 +78,8 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
try (AbstractTensor flashAttn_m = m.makeTensor(c.numberOfHeads);
AbstractTensor flashAttn_l = m.makeTensor(c.numberOfHeads);
AbstractTensor query = m.makeFullTensor(c.embeddingLength);
AbstractTensor tmpKey = m.makeFullTensor(c.embeddingLength);
AbstractTensor tmpVal = m.makeFullTensor(c.embeddingLength);
AbstractTensor tmpKey = m.makeFullTensor(c.kvLength);
AbstractTensor tmpVal = m.makeFullTensor(c.kvLength);
AbstractTensor value = m.makeFullTensor(c.embeddingLength))
{
//This is our memory of the key and value vectors for each position
@@ -98,63 +88,105 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
AbstractTensor key = kvp.slice(0);
AbstractTensor val = kvp.slice(1);


qkvResults[0] = query;
qkvResults[1] = tmpKey;
qkvResults[2] = tmpVal;

// compute the query vector
VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkLength) -> {
TensorOperationsProvider.get().dotProductBatchChunk(qkvResults, input, qkvWeights, c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkLength);
});

if (c.isGQA) {
VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkLength) -> {
TensorOperationsProvider.get().dotProductChunk(query, input, queryAttnWeights, c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkLength);
});
VectorMath.pchunk(0, c.kvLength, (chunkStart, chunkLength) -> {
TensorOperationsProvider.get().dotProductChunk(tmpKey, input, keyAttnWeights, c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkLength);
TensorOperationsProvider.get().dotProductChunk(tmpVal, input, valueAttnWeights, c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkLength);
});
} else {
qkvResults[0] = query;
qkvResults[1] = tmpKey;
qkvResults[2] = tmpVal;

// compute the query vector
VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkLength) -> {
TensorOperationsProvider.get().dotProductBatchChunk(qkvResults, input, qkvWeights, c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkLength);
});
}
// For distributed sum of tensor
tensorReducer.ifPresent(func -> func.accept(List.of(query, tmpKey, tmpVal)));

queryAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(query, bias, c.embeddingSegmentStart(), c.embeddingSegmentLength()));
keyAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(tmpKey, bias, c.embeddingSegmentStart(), c.embeddingSegmentLength()));
valueAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(tmpVal, bias, c.embeddingSegmentStart(), c.embeddingSegmentLength()));
keyAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(tmpKey, bias, c.kvSegmentStart(), c.kvSegmentLength()));
valueAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(tmpVal, bias, c.kvSegmentStart(), c.kvSegmentLength()));

key.copyFrom(tmpKey, tmpKey.getOffset(c.embeddingSegmentStart()), key.getOffset(c.embeddingSegmentStart()), c.embeddingSegmentLength());
val.copyFrom(tmpVal, tmpVal.getOffset(c.embeddingSegmentStart()), val.getOffset(c.embeddingSegmentStart()), c.embeddingSegmentLength());
key.copyFrom(tmpKey, tmpKey.getOffset(c.kvSegmentStart()), key.getOffset(c.kvSegmentStart()), c.kvSegmentLength());
val.copyFrom(tmpVal, tmpVal.getOffset(c.kvSegmentStart()), val.getOffset(c.kvSegmentStart()), c.kvSegmentLength());

// apply RoPE if present (accounting for huggingface permutation)
// https://github.com/huggingface/transformers/blob/d533465150532b0c5de167b574e59f64c68b1154/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
ropeFrequencies.ifPresent(rf -> {
c.ropeFreqs.ifPresent(rf -> {
int headPiece = c.headSize / 2;
int poffset = position * headPiece;
// apply RoPE rotation to the q and k vectors for each head
for (int h = c.headStart(); h < c.headEnd(); h++) {
// get the q and k vectors for this head
int offset = h * c.headSize;
// rotate q and k by the freq theta and freq r
for (int i = offset; i < (offset + headPiece); i++) {
float q0 = query.get(i);
float q1 = query.get(i + headPiece); //hf permutation is 0,64,1,65 etc...
float k0 = key.get(i);
float k1 = key.get(i + headPiece);
float[] f = rf[poffset + i];
float fcr = f[0];
float fci = f[1];
query.set(q0 * fcr - q1 * fci, i);
query.set(q0 * fci + q1 * fcr, i + headPiece);
key.set(k0 * fcr - k1 * fci, i);
key.set(k0 * fci + k1 * fcr, i + headPiece);

if (c.isGQA) {
// apply RoPE rotation to the q and k vectors for each head
for (int h = c.headStart(); h < c.headEnd(); h++) {
// get the q vectors for this head
int offset = h * c.headSize;
// rotate q by the freq theta and freq r
for (int i = offset; i < (offset + headPiece); i++) {
float q0 = query.get(i);
float q1 = query.get(i + headPiece); //hf permutation is 0,64,1,65 etc...
float[] f = rf[poffset + i];
float fcr = f[0];
float fci = f[1];
query.set(q0 * fcr - q1 * fci, i);
query.set(q0 * fci + q1 * fcr, i + headPiece);
}
}

//float[][] grf = c.groupRopeFreqs.get();
for (int h = c.groupHeadStart(); h < c.groupHeadEnd(); h++) {
// get the k vectors for this head
int offset = h * c.headSize;
// rotate k by the freq theta and freq r
for (int i = offset; i < (offset + headPiece); i++) {
float k0 = key.get(i);
float k1 = key.get(i + headPiece); //hf permutation is 0,64,1,65 etc...
float[] f = rf[poffset + i];
float fcr = f[0];
float fci = f[1];
key.set(k0 * fcr - k1 * fci, i);
key.set(k0 * fci + k1 * fcr, i + headPiece);
}
}
} else {
// apply RoPE rotation to the q and k vectors for each head
for (int h = c.headStart(); h < c.headEnd(); h++) {
// get the q and k vectors for this head
int offset = h * c.headSize;
// rotate q and k by the freq theta and freq r
for (int i = offset; i < (offset + headPiece); i++) {
float q0 = query.get(i);
float q1 = query.get(i + headPiece); //hf permutation is 0,64,1,65 etc...
float k0 = key.get(i);
float k1 = key.get(i + headPiece);
float[] f = rf[poffset + i];
float fcr = f[0];
float fci = f[1];
query.set(q0 * fcr - q1 * fci, i);
query.set(q0 * fci + q1 * fcr, i + headPiece);
key.set(k0 * fcr - k1 * fci, i);
key.set(k0 * fci + k1 * fcr, i + headPiece);
}
}
}
});

// with all key-value entries populated, compute attention
// the softmax is incrementally aggregated using the flash attention technique
AbstractTensor k0 = kvMem.slice(true, 0).slice(0);
AbstractTensor k0 = kvMem.slice(true,0).slice(0);
AbstractTensor v0 = kvMem.slice(true,0).slice(1);

// value is initially the position 0 value for all heads
value.copyFrom(v0, v0.getOffset(c.embeddingSegmentStart()), value.getOffset(c.embeddingSegmentStart()), c.embeddingSegmentLength());

//POSITION ZERO
for (int i = c.headStart(); i < c.headEnd(); i++) {
float a = TensorOperationsProvider.get().dotProduct(query, k0, i * c.headSize, i * c.headSize, c.headSize) * attentionScale;
value.copyFrom(v0, v0.getOffset(c.maybeMapToGroupHead(i) * c.headSize), value.getOffset(i * c.headSize), c.headSize);
float a = TensorOperationsProvider.get().dotProduct(query, k0, i * c.headSize, c.maybeMapToGroupHead(i) * c.headSize, c.headSize) * attentionScale;
flashAttn_m.set(a, i);
flashAttn_l.set(1, i);
}
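
The key change in the attention loops above is the key/value offset: query head i now dots against the cached key at c.maybeMapToGroupHead(i) * c.headSize rather than i * c.headSize, so a whole group of query heads reads the same key/value block. For the same reason, the GQA branch applies RoPE to query heads (headStart()..headEnd()) and key heads (groupHeadStart()..groupHeadEnd()) in separate loops, since the key tensor only has kvLength entries. A small worked example of the assumed mapping, using 64 query heads grouped onto 8 KV heads:

// Illustration only: how query heads map onto shared KV heads under GQA.
public class GroupHeadMapping {
    public static void main(String[] args) {
        int numberOfHeads = 64;         // query heads
        int numberOfKeyValueHeads = 8;  // KV heads
        int headSize = 128;
        int headsPerGroup = numberOfHeads / numberOfKeyValueHeads; // 8

        for (int h = 0; h < numberOfHeads; h += headsPerGroup) {
            int groupHead = h / headsPerGroup; // group index, as in the maybeMapToGroupHead sketch above
            System.out.printf("query heads %2d..%2d read keys/values at offset %d (kv head %d)%n",
                    h, h + headsPerGroup - 1, groupHead * headSize, groupHead);
        }
    }
}
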
@@ -165,7 +197,7 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
//KEY OFFSET
AbstractTensor kk = kvMem.slice(true, i + 1).slice(0);
for(int h = c.headStart(); h < c.headEnd(); h++){
flashAttnHeads[i][h] = TensorOperationsProvider.get().dotProduct(query, kk, h * c.headSize, h * c.headSize, c.headSize) * attentionScale;
flashAttnHeads[i][h] = TensorOperationsProvider.get().dotProduct(query, kk, h * c.headSize, c.maybeMapToGroupHead(h) * c.headSize, c.headSize) * attentionScale;
}
});

@@ -177,12 +209,12 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
float a = flashAttnHeads[i][h];
if (a > flashAttn_m.get(h)) {
float e = (float) Math.exp(flashAttn_m.get(h) - a);
TensorOperationsProvider.get().sxpby(e, vv, value, (h * c.headSize), h * c.headSize, c.headSize);
TensorOperationsProvider.get().sxpby(e, vv, value, c.maybeMapToGroupHead(h) * c.headSize, h * c.headSize, c.headSize);
flashAttn_l.set(1 + e * flashAttn_l.get(h), h);
flashAttn_m.set(a, h);
} else {
float e = (float) Math.exp(a - flashAttn_m.get(h));
TensorOperationsProvider.get().saxpy(e, vv, value, (h * c.headSize), h * c.headSize, c.headSize);
TensorOperationsProvider.get().saxpy(e, vv, value, c.maybeMapToGroupHead(h) * c.headSize, h * c.headSize, c.headSize);
flashAttn_l.set(flashAttn_l.get(h) + e, h);
}
}
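
The per-position loop above aggregates attention with a streaming softmax in the flash-attention style: flashAttn_m holds the running maximum score per head, flashAttn_l the running normalizer, and the accumulated value vector is rescaled (sxpby) or extended (saxpy) depending on whether the new score beats the running max; with GQA, the value slice is again fetched at the group-head offset. Below is a compact single-head restatement of that update rule with plain arrays, assuming the final division by the normalizer happens later in the method. It is an illustrative sketch, not the project's API.

// Streaming (online) softmax accumulation for a single head, mirroring the
// flashAttn_m / flashAttn_l bookkeeping above.
public class OnlineSoftmaxSketch {

    // scores[t] is the scaled dot product of the query with key t;
    // values[t] is the value vector cached for position t.
    static float[] attend(float[] scores, float[][] values) {
        int headSize = values[0].length;
        float m = scores[0];                 // running max score (flashAttn_m)
        float l = 1.0f;                      // running normalizer (flashAttn_l)
        float[] acc = values[0].clone();     // running weighted value

        for (int t = 1; t < scores.length; t++) {
            float a = scores[t];
            if (a > m) {
                float e = (float) Math.exp(m - a);   // rescale what was accumulated so far
                for (int j = 0; j < headSize; j++) acc[j] = values[t][j] + e * acc[j];
                l = 1 + e * l;
                m = a;
            } else {
                float e = (float) Math.exp(a - m);   // down-weight the new contribution
                for (int j = 0; j < headSize; j++) acc[j] += e * values[t][j];
                l += e;
            }
        }
        for (int j = 0; j < headSize; j++) acc[j] /= l;  // final normalization
        return acc;
    }

    public static void main(String[] args) {
        float[] scores = {0.5f, 2.0f, -1.0f};
        float[][] values = {{1, 0}, {0, 1}, {1, 1}};
        float[] out = attend(scores, values);
        System.out.println(out[0] + " " + out[1]);
    }
}
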
@@ -13,6 +13,6 @@ public BertConfig( @JsonProperty("max_position_embeddings") int contextLength,
@JsonProperty("num_hidden_layers") int numberOfLayers,
@JsonProperty("layer_norm_eps") float layerNormEps,
@JsonProperty("vocab_size") int vocabularySize) {
super(contextLength, embeddingLength, hiddenLength, numberOfHeads, numberOfLayers, layerNormEps, vocabularySize, 0, 0);
super(contextLength, embeddingLength, hiddenLength, numberOfHeads, numberOfHeads, numberOfLayers, layerNormEps, vocabularySize, 0, 0, null);
}
}
@@ -69,7 +69,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
CausalSelfAttention attention = new CausalSelfAttention(this,
keyBias, queryBias, valueBias,
keyWeight, queryWeight, valueWeight,
outputBias, outputWeight, Optional.empty());
outputBias, outputWeight);

prefix = b;
MLPBlock mlpBlock = new MLPBlock(this, ActivationFunction.Type.GELU,
@@ -15,6 +15,6 @@ public GPT2Config( @JsonProperty("n_ctx") int contextLength,
@JsonProperty("vocab_size") int vocabularySize,
@JsonProperty("bos_token_id") int bosToken,
@JsonProperty("eos_token_id") int eosToken) {
super(contextLength, embeddingLength, embeddingLength * 4, numberOfHeads, numberOfLayers, layerNormEps, vocabularySize, bosToken, eosToken);
super(contextLength, embeddingLength, embeddingLength * 4, numberOfHeads, numberOfHeads, numberOfLayers, layerNormEps, vocabularySize, bosToken, eosToken, null);
}
}
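
Both config classes now pass two extra arguments to the shared Config constructor: a key/value head count (set equal to the query head count here, since BERT and GPT-2 do not use grouped-query attention) and a trailing RoPE-related argument (null, since neither model uses rotary embeddings). The annotation below restates one of these call sites with each positional argument labeled; the meanings are inferred from the JsonProperty names above, and the actual Config parameter names are not shown in this diff.

// Annotated form of the updated super(...) call; argument meanings are assumptions.
super(contextLength,        // max_position_embeddings / n_ctx
      embeddingLength,
      hiddenLength,         // embeddingLength * 4 for GPT-2
      numberOfHeads,        // query heads
      numberOfHeads,        // new: KV heads, same as query heads because there is no GQA
      numberOfLayers,
      layerNormEps,
      vocabularySize,
      bosToken,             // 0 for BERT
      eosToken,             // 0 for BERT
      null);                // new: RoPE parameter, unused by BERT and GPT-2
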
@@ -50,7 +50,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
AbstractTensor[] attnWeights = weights.load(prefix + "c_attn.weight").transpose().split(3, 0);
CausalSelfAttention attention = new CausalSelfAttention(this, attnBias[0], attnBias[1], attnBias[2],
attnWeights[0], attnWeights[1], attnWeights[2],
weights.load(prefix + "c_proj.bias"), weights.load(prefix + "c_proj.weight").transpose(), Optional.empty());
weights.load(prefix + "c_proj.bias"), weights.load(prefix + "c_proj.weight").transpose());

prefix = b + "mlp.";
MLPBlock mlpBlock = new MLPBlock(this, ActivationFunction.Type.GELU,
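
As unchanged context in the hunk above shows, GPT-2 stores its three attention projections in a single fused c_attn.weight matrix, which the loader transposes and splits into three equal slices along dimension 0 (split(3, 0)) before handing them to CausalSelfAttention. A plain-array sketch of that split, under the assumption that the fused matrix stacks the Q rows, then K rows, then V rows:

// Illustration of splitting a fused QKV projection into per-projection matrices.
// The Q/K/V row layout is an assumption made for this sketch.
public class FusedQkvSplit {
    static float[][][] split3(float[][] fused) {
        int rows = fused.length / 3;           // rows per projection
        float[][][] parts = new float[3][rows][];
        for (int p = 0; p < 3; p++)
            for (int r = 0; r < rows; r++)
                parts[p][r] = fused[p * rows + r];
        return parts;                          // parts[0]=Q, parts[1]=K, parts[2]=V
    }

    public static void main(String[] args) {
        float[][] fused = new float[6][4];     // toy 6x4 matrix -> three 2x4 matrices
        float[][][] qkv = split3(fused);
        System.out.println(qkv[0].length + " rows per projection");
    }
}
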