Add basic paged attention for kv cache (no copy on write)
tjake committed Sep 15, 2024
1 parent 090178c commit 946a469
Showing 18 changed files with 473 additions and 215 deletions.
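This commit moves the key/value cache from a single contiguous tensor per session to a paged layout: the cache is stored as fixed-size pages, and attention walks the pages that cover the context so far instead of slicing one large buffer (copy-on-write is explicitly out of scope, as the title notes). The sketch below only illustrates the indexing idea; the page size and names are assumptions, not part of this commit.

// Illustrative only: how a token position maps to a fixed-size page and a row within it.
// PAGE_SIZE and the class/method names are assumed for this sketch.
public class PagedIndexSketch {
    static final int PAGE_SIZE = 256; // positions per page (assumed)

    static int pageIndex(int position)  { return position / PAGE_SIZE; }
    static int pageOffset(int position) { return position % PAGE_SIZE; }

    public static void main(String[] args) {
        int position = 700;
        System.out.println("position " + position + " -> page " + pageIndex(position)
                + ", row " + pageOffset(position));
        // position 700 -> page 2, row 188: only pages 0..2 need to exist so far
    }
}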
===== changed file =====
@@ -44,7 +44,6 @@
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Collectors;

@@ -190,8 +189,8 @@ protected AbstractTensor maybeQuantize(AbstractTensor t) {
return t2;
}

protected AbstractTensor forward(int token_id, int pos, AbstractTensor kvbuf) {
return forward(token_id, pos, kvbuf, Optional.empty(), Optional.empty());
protected AbstractTensor forward(int token_id, int pos, KvBufferCache.KvBuffer kvbuf) {
return forward(token_id, pos, kvbuf, Optional.empty());
}

/**
@@ -207,8 +206,7 @@ protected AbstractTensor forward(int token_id, int pos, AbstractTensor kvbuf) {
public AbstractTensor forward(
int token_id,
int pos,
AbstractTensor kvbuf,
Optional<BiFunction<Float, Float, Pair<Float, Float>>> normReducer,
KvBufferCache.KvBuffer kvbuf,
Optional<Consumer<List<AbstractTensor>>> tensorReducer
) {
AbstractTensor embedding = embedInput.inputTokenToEmbedding(token_id, pos);
@@ -217,16 +215,15 @@ public AbstractTensor forward(
debug("TOKEN POSITION", pos);

for (int i = c.dctx().layerStart; i < c.dctx().layerEnd; i++) {
AbstractTensor kvlayer = kvbuf.slice(true, i);
AbstractTensor ref = embedding; // reference so we can free
embedding = transformerBlocks[i].forward(embedding, pos, kvlayer, normReducer, tensorReducer);
embedding = transformerBlocks[i].forward(embedding, pos, kvbuf, tensorReducer);
ref.close();
}

return embedding;
}

protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, AbstractTensor kvbuf) {
protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, KvBufferCache.KvBuffer kvbuf) {
AbstractTensor last = null;
for (int i = 0; i < token_ids.length; i++) {
if (last != null) last.close();
@@ -236,13 +233,12 @@ protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, Abstrac
return last;
}

protected AbstractTensor batchForward(int[] token_ids, int startPos, AbstractTensor kvbuf) {
protected AbstractTensor batchForward(int[] token_ids, int startPos, KvBufferCache.KvBuffer kvbuf) {

AbstractTensor embedding = embedInput.batchInputsToEmbeddings(token_ids, startPos);
for (int i = c.dctx().layerStart; i < c.dctx().layerEnd; i++) {
AbstractTensor kvlayer = kvbuf.slice(true, i);
AbstractTensor ref = embedding; // reference so we can free
embedding = transformerBlocks[i].forward(embedding, startPos, kvlayer, Optional.empty(), Optional.empty());
embedding = transformerBlocks[i].forward(embedding, startPos, kvbuf, Optional.empty());
ref.close();
}

@@ -303,11 +299,10 @@ public Response generate(
encoded = Arrays.copyOfRange(encoded, 1, encoded.length);
}

Preconditions.checkArgument(encoded.length < c.contextLength);
Preconditions.checkArgument(encoded.length < c.contextLength && encoded.length < ntokens, "Prompt exceeds max tokens");

AbstractTensor kvmem = kvBufferCache.getKvBuffer(sessionId); // k and v for context window
Integer startPos = (Integer) kvmem.getMetadata(KvBufferCache.TOKEN_COUNT); // Number of tokens in the buffer
if (startPos == null) startPos = 0;
KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(sessionId); // k and v for context window
int startPos = kvmem.getCurrentContextPosition(); // Number of tokens in the buffer

logger.debug("Starting at token {} for session {}", startPos, sessionId);

@@ -366,7 +361,7 @@ public Response generate(
if (logger.isTraceEnabled()) logger.trace("Sampled token {} with temperature {}", next, temperature);
output.close();

kvmem.setMetadata(KvBufferCache.TOKEN_COUNT, i);
kvmem.incrementContextPosition();

// Model may tell us it's done
if (c.eosTokens.contains(next)) {
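With the paged cache, generate() no longer stores the token count as tensor metadata (the old KvBufferCache.TOKEN_COUNT entry); the KvBuffer tracks it itself, read via getCurrentContextPosition() when a session resumes and bumped with incrementContextPosition() after each sampled token. A minimal sketch of that bookkeeping, assuming it is just an int counter inside the buffer:

// Sketch of the position bookkeeping implied above; only the counter is shown,
// the real KvBuffer also owns the key/value pages.
class KvBufferPositionSketch {
    private int currentContextPosition = 0;

    int getCurrentContextPosition() { return currentContextPosition; }

    void incrementContextPosition() { currentContextPosition++; }
}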
===== changed file =====
@@ -20,6 +20,7 @@
import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.google.common.base.Preconditions;
import java.util.*;
@@ -28,6 +29,7 @@
public class CausalSelfAttention {
private final AbstractModel m;
private final Config c;
private final int layerIndex;
private final DistributedContext dctx;
private final Optional<AbstractTensor> queryAttnBias;
private final Optional<AbstractTensor> keyAttnBias;
@@ -44,20 +46,20 @@ public class CausalSelfAttention {

private final float attentionScale;
private final int attentionLength;
private final boolean attentionQVSizeMismatch;

private final AbstractTensor[] qkvResults;
private final AbstractTensor[] qkvWeights;

public CausalSelfAttention(
AbstractModel m,
int layerIndex,
AbstractTensor queryAttnWeights,
AbstractTensor keyAttnWeights,
AbstractTensor valueAttnWeights,
AbstractTensor outputProjectionWeights
) {
this(
m,
m, layerIndex,
Optional.empty(),
Optional.empty(),
Optional.empty(),
@@ -71,6 +73,7 @@ public CausalSelfAttention(

public CausalSelfAttention(
AbstractModel m,
int layerIndex,
AbstractTensor queryAttnBias,
AbstractTensor keyAttnBias,
AbstractTensor valueAttnBias,
@@ -82,6 +85,7 @@ public CausalSelfAttention(
) {
this(
m,
layerIndex,
Optional.of(queryAttnBias),
Optional.of(keyAttnBias),
Optional.of(valueAttnBias),
@@ -95,6 +99,7 @@ public CausalSelfAttention(

public CausalSelfAttention(
AbstractModel m,
int layerIndex,
Optional<AbstractTensor> queryAttnBias,
Optional<AbstractTensor> keyAttnBias,
Optional<AbstractTensor> valueAttnBias,
@@ -105,6 +110,7 @@ public CausalSelfAttention(
AbstractTensor outputProjectionWeights
) {
this.m = m;
this.layerIndex = layerIndex;
this.c = m.c;
this.dctx = m.c.dctx();
this.queryAttnBias = queryAttnBias;
@@ -118,7 +124,6 @@ public CausalSelfAttention(
this.outputProjectionWeights = outputProjectionWeights;
this.attentionLength = c.numberOfHeads * c.headSize;

this.attentionQVSizeMismatch = c.embeddingLength != attentionLength;
this.attentionScale = (float) (1.0 / StrictMath.sqrt(c.headSize));

this.qkvResults = new AbstractTensor[3];
@@ -128,7 +133,7 @@ public CausalSelfAttention(
public AbstractTensor forward(
AbstractTensor input,
int startPosition,
AbstractTensor kvMem,
KvBufferCache.KvBuffer kvMem,
Optional<Consumer<List<AbstractTensor>>> tensorReducer
) {
Preconditions.checkArgument(input.dims() == 2 && input.shape().last() == c.embeddingLength);
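Every CausalSelfAttention now receives its layer index because the attention code addresses the shared KvBuffer itself, by layer and position, rather than being handed a pre-sliced per-layer tensor. The calls used further down suggest an API shaped roughly like the following (a sketch of the shape only, not the exact jlama interface):

// Rough shape of the per-layer, per-position lookups used below (assumed interface, for illustration).
interface KvBufferLookups {
    AbstractTensor getKeyTensorForPosition(int layerIndex, int position);      // row to write this step
    AbstractTensor getValTensorForPosition(int layerIndex, int position);
    AbstractTensor[] getKeyTensorsUptoPosition(int layerIndex, int position);  // pages to read from
    AbstractTensor[] getValTensorsUptoPosition(int layerIndex, int position);
}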
@@ -214,11 +219,12 @@ public AbstractTensor forward(
for (int position = startPosition, bi = 0; position < startPosition + batchSize; position++, bi++) {
int finalPostion = position;

AbstractTensor kvp = kvMem.slice(true, 0);
AbstractTensor vvp = kvMem.slice(true, 1);
AbstractTensor key = kvMem.getKeyTensorForPosition(layerIndex, position);
AbstractTensor val = kvMem.getValTensorForPosition(layerIndex, position);

AbstractTensor[] kvp = kvMem.getKeyTensorsUptoPosition(layerIndex, position);
AbstractTensor[] vvp = kvMem.getValTensorsUptoPosition(layerIndex, position);

AbstractTensor key = kvp.slice(position);
AbstractTensor val = vvp.slice(position);

AbstractTensor tmpKey = tmpKeyBatch.slice(bi);
AbstractTensor tmpVal = tmpValBatch.slice(bi);
@@ -318,21 +324,34 @@ public AbstractTensor forward(

// Attention
VectorMath.pfor(dctx.headStart, dctx.headEnd, h -> {
try (AbstractTensor attn = m.makeDenseTensor(1, kvp.shape().first())) {
int xoffset = c.maybeMapToGroupHead(h) * c.headSize;
int yoffset = h * c.headSize;
int xoffset = c.maybeMapToGroupHead(h) * c.headSize;
int yoffset = h * c.headSize;

if (yoffset >= query.shape().last()) return;
if (yoffset >= query.shape().last()) return;

try (AbstractTensor attn = m.makeDenseTensor(1, finalPostion + 1)) {
// compute attention scores by multiplying query and key for every position
TensorOperationsProvider.get().batchDotProduct(attn, query, kvp, yoffset, xoffset, c.headSize, 0, finalPostion + 1);
// Do this for each page
for (int i = 0; i < kvp.length; i++) {
int len = kvp[i].shape().first();
int offset = i * len;
int size = i == kvp.length - 1 ? (finalPostion + 1) - offset : len;
TensorOperationsProvider.get().batchDotProduct(attn, query, kvp[i], yoffset, xoffset, c.headSize, offset, 0, size);
}

TensorOperationsProvider.get().scale(attentionScale, attn, 0, finalPostion + 1);

// softmax the scores to get attention weights, from 0..pos inclusively
VectorMath.softMax(attn, 0, finalPostion + 1);

// apply adjusted attention weights to value vectors
TensorOperationsProvider.get().saxpy(attn, vvp, value, xoffset, yoffset, c.headSize, finalPostion + 1);
// do this for each page
for (int i = 0; i < vvp.length; i++) {
int len = vvp[i].shape().first();
int offset = i * len;
int size = i == vvp.length - 1 ? (finalPostion + 1) - offset : len;
TensorOperationsProvider.get().saxpy(attn, vvp[i], value, xoffset, yoffset, c.headSize, offset, 0, size);
}
}
});
}
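Since the cached keys and values for a layer now come back as an array of pages, both the query-key dot products and the value accumulation loop over pages, passing a global offset and a per-page size to batchDotProduct and saxpy. Each page holds len positions, page i starts at global position i * len, and only the last page is partially filled. A worked example of that arithmetic, with an assumed page length of 128 and position 300:

// Worked example of the per-page offset/size arithmetic in the loops above.
// The page length (128) and position (300) are assumed values for illustration.
public class PagedOffsetsExample {
    public static void main(String[] args) {
        int finalPostion = 300;               // last attended position (spelling follows the diff)
        int len = 128;                        // positions per page
        int pages = (finalPostion / len) + 1; // pages touched so far

        for (int i = 0; i < pages; i++) {
            int offset = i * len;                                        // global start of page i
            int size = (i == pages - 1) ? (finalPostion + 1) - offset    // last page: only what is filled
                                        : len;                           // otherwise the full page
            System.out.printf("page %d: offset %d, %d positions%n", i, offset, size);
        }
        // page 0: offset 0,   128 positions
        // page 1: offset 128, 128 positions
        // page 2: offset 256,  45 positions
    }
}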
===== changed file =====
@@ -34,21 +34,16 @@ public LayerNorm(AbstractModel m, AbstractTensor bias, AbstractTensor weights) {
}

public AbstractTensor forward(AbstractTensor input) {
return forward(input, Optional.empty());
}

public AbstractTensor forward(AbstractTensor input, Optional<BiFunction<Float, Float, Pair<Float, Float>>> reducer) {
Preconditions.checkArgument(input.shape().dims() == 2);
int size = input.shape().last();
Preconditions.checkArgument(size == m.c.embeddingLength);
return forward(input, 0, m.c.embeddingLength, reducer);
return forward(input, 0, m.c.embeddingLength);
}

public AbstractTensor forward(
AbstractTensor input,
int offset,
int length,
Optional<BiFunction<Float, Float, Pair<Float, Float>>> reducer
int length
) {

int batchSize = input.shape().first();
===== changed file =====
@@ -16,9 +16,7 @@
package com.github.tjake.jlama.model;

import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.util.Pair;
import java.util.Optional;
import java.util.function.BiFunction;


public class RMSNorm extends LayerNorm {
private final float weightAdjustment;
@@ -36,8 +34,7 @@ public RMSNorm(AbstractModel m, AbstractTensor weights, float weightAdjustment)
public AbstractTensor forward(
AbstractTensor input,
int offset,
int length,
Optional<BiFunction<Float, Float, Pair<Float, Float>>> reducer
int length
) {

int batchSize = input.shape().first();
===== changed file =====
@@ -19,6 +19,7 @@

import com.github.tjake.jlama.model.functions.FeedForward;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.Pair;
import java.util.List;
@@ -78,21 +79,20 @@ public TransformerBlock(
this.postFFNorm = Optional.of(postFFNorm);
}

public AbstractTensor forward(AbstractTensor embedding, int position, AbstractTensor kvBuffer) {
return forward(embedding, position, kvBuffer, Optional.empty(), Optional.empty());
public AbstractTensor forward(AbstractTensor embedding, int position, KvBufferCache.KvBuffer kvBuffer) {
return forward(embedding, position, kvBuffer, Optional.empty());
}

public AbstractTensor forward(
AbstractTensor embedding,
int position,
AbstractTensor kvBuffer,
Optional<BiFunction<Float, Float, Pair<Float, Float>>> normReducer,
KvBufferCache.KvBuffer kvBuffer,
Optional<Consumer<List<AbstractTensor>>> tensorReducer
) {

debug("input_emb", embedding, layerIndex);

AbstractTensor lnemb = preAttentionNorm.map(ln -> ln.forward(embedding, normReducer)).orElse(embedding);
AbstractTensor lnemb = preAttentionNorm.map(ln -> ln.forward(embedding)).orElse(embedding);

debug("ln_emb", lnemb, layerIndex);

Expand All @@ -109,7 +109,7 @@ public AbstractTensor forward(

debug("post_attn_res", postAttention, layerIndex);

AbstractTensor lnemb2 = postAttentionNorm.forward(postAttention, normReducer);
AbstractTensor lnemb2 = postAttentionNorm.forward(postAttention);

debug("ln_emb2", lnemb2, layerIndex);

@@ -131,7 +131,7 @@ public AbstractTensor forward(
postAttention.close();

return postFFNorm.map(ln -> {
AbstractTensor lnout = ln.forward(postFF, normReducer);
AbstractTensor lnout = ln.forward(postFF);
debug("ln_out", lnout, layerIndex);
postFF.close();
return lnout;
===== changed file =====
@@ -25,10 +25,12 @@
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.util.Arrays;
import java.util.Optional;
import java.util.UUID;

public class BertModel extends AbstractModel {

@@ -100,6 +102,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
AbstractTensor outputWeight = weights.load(prefix + "output.dense.weight");
CausalSelfAttention attention = new CausalSelfAttention(
this,
i,
keyBias,
queryBias,
valueBias,
@@ -147,8 +150,7 @@ public float[] embed(String input) {
Preconditions.checkArgument(encoded.length < c.contextLength);
float[] outputEmbedding = new float[c.embeddingLength];

try (AbstractTensor kvmem = makeDenseTensor(c.dctx().numberOfLayers, 2, encoded.length, c.embeddingLength)) { // 2 for key and value

try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(UUID.randomUUID())) {
int promptLength = encoded.length;
float avgp = 1.0f / promptLength;

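BertModel.embed no longer allocates its own dense cache tensor; it asks the shared KvBufferCache for a buffer under a throwaway UUID and opens it in try-with-resources, so the pages are released as soon as the single embedding pass is done. A self-contained sketch of that lifecycle (the buffer class here is a stand-in, not the jlama type):

import java.util.UUID;

// Stand-in illustration of a session-scoped, auto-closed KV buffer; KvBufferCache.KvBuffer is
// assumed to be AutoCloseable, as the try-with-resources in the diff implies.
public class ScopedKvBufferSketch {
    static class FakeKvBuffer implements AutoCloseable {
        final UUID session;
        FakeKvBuffer(UUID session) { this.session = session; }
        @Override public void close() { System.out.println("released pages for session " + session); }
    }

    public static void main(String[] args) {
        try (FakeKvBuffer kv = new FakeKvBuffer(UUID.randomUUID())) {
            System.out.println("running the embedding pass against session " + kv.session);
        } // buffer and its pages released here
    }
}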
===== changed file =====
@@ -85,6 +85,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
String prefix = base + "self_attn.";
CausalSelfAttention attention = new CausalSelfAttention(
this,
i,
weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType),
weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType),
weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType),
===== changed file =====
@@ -75,6 +75,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
AbstractTensor[] attnWeights = weights.load(prefix + "c_attn.weight").transpose().split(3, 0);
CausalSelfAttention attention = new CausalSelfAttention(
this,
i,
attnBias[0],
attnBias[1],
attnBias[2],
===== changed file =====
@@ -97,6 +97,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
String prefix = base + "self_attn.";
CausalSelfAttention attention = new CausalSelfAttention(
this,
i,
weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType),
weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType),
weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType),
===== changed file =====
@@ -74,6 +74,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
String prefix = base + "self_attn.";
CausalSelfAttention attention = new CausalSelfAttention(
this,
i,
weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType),
weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType),
weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType),
(remaining changed files not shown)
