Fix attention bug and other bugs
tjake committed Oct 22, 2023
1 parent 84b2194 commit e1264f3
Showing 7 changed files with 35 additions and 32 deletions.
@@ -34,15 +34,21 @@ public abstract class AbstractModel {
protected final DType modelDType;
protected final DType workingDType;
protected final DType workingQType;
+ protected final Optional<DType> modelQType;
private static final ThreadLocal<AbstractTensor[]> tmpArray = new ThreadLocal<>();

protected AbstractModel(Config c, WeightLoader w, Tokenizer t, DType workingMemoryDType, DType workingMemoryQType)
{
+ this(c, w, t, workingMemoryDType, workingMemoryQType, Optional.empty());
+ }
+ protected AbstractModel(Config c, WeightLoader w, Tokenizer t, DType workingMemoryDType, DType workingMemoryQType, Optional<DType> modelQType)
+ {
this.c = c;
this.weights = w;
this.tokenizer = t;
this.modelDType = w.getModelDType();
this.workingDType = workingMemoryDType;
+ this.modelQType = modelQType;

if (workingMemoryQType != workingMemoryDType) {
boolean supportsQType;
@@ -92,7 +92,7 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor

queryAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(query, bias));
keyAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(key, bias));
- valueAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(value, bias));
+ valueAttnBias.ifPresent(bias -> TensorOperationsProvider.get().accumulate(val, bias));

// apply RoPE if present (accounting for huggingface permutation)
// https://github.com/huggingface/transformers/blob/d533465150532b0c5de167b574e59f64c68b1154/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
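Note: the comment above is the only hint at what the rotation does, so here is a hedged sketch of RoPE for a single head, assuming the HuggingFace llama pairing of element i with element i + headSize/2 that the linked conversion script produces. The method name, signature, and the base of 10000.0 are illustrative assumptions, not this project's API.

// Hedged RoPE sketch for one head; the pairing and base are assumptions (see note above).
static void applyRope(float[] head, int headSize, int position) {
    int half = headSize / 2;
    for (int i = 0; i < half; i++) {
        double freq = 1.0 / Math.pow(10000.0, (2.0 * i) / headSize);  // per-pair rotation frequency
        float cos = (float) Math.cos(position * freq);
        float sin = (float) Math.sin(position * freq);
        float re = head[i];          // first element of the pair
        float im = head[i + half];   // second element of the pair
        head[i]        = re * cos - im * sin;   // 2-D rotation of the pair
        head[i + half] = re * sin + im * cos;
    }
}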
@@ -122,10 +122,11 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor

// with all key-value entries populated, compute attention
// the softmax is incrementally aggregated using the flash attention technique
- AbstractTensor k0 = kvMem.slice(0).slice(1);
+ AbstractTensor k0 = kvMem.slice(0).slice(0);
+ AbstractTensor v0 = kvMem.slice(0).slice(1);

// value is initially the first value for all heads
- value.copyFrom(k0, 0, 0, c.embeddingLength);
+ value.copyFrom(v0, 0, 0, c.embeddingLength);

//POSITION ZERO
for (int i = 0; i < c.numberOfHeads; i++) {
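This hunk is the attention bug from the commit title. With key and value now stored on a dedicated axis of the cache (plane 0 for keys, plane 1 for values, matching the makeTensor change below), the old code read k0 from plane 1 and then seeded the per-head value buffer from k0 instead of v0. A minimal sketch of the fixed addressing, assuming a per-layer cache view shaped [position][2][embeddingLength]; the names and slice() semantics mirror the calls above but are illustrative:

// Fixed KV-cache addressing at position zero (sketch; assumes [position][2][embedding] layout).
AbstractTensor pos0 = kvMem.slice(0);   // cached key/value pair for position 0
AbstractTensor k0 = pos0.slice(0);      // plane 0 holds the key
AbstractTensor v0 = pos0.slice(1);      // plane 1 holds the value
value.copyFrom(v0, 0, 0, c.embeddingLength);  // seed from the value plane, not the key plane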
@@ -62,8 +62,9 @@ public AbstractTensor forward(AbstractTensor lnemb) {
buf.set(w1a, i);
});

- if (upProjectionWeights != null)
+ if (upProjectionWeights != null) {
TensorOperationsProvider.get().maccumulate(buf, buf2);
+ }

//matmul the projection and sum into input
AbstractTensor result = model.makeTensor(model.c.embeddingLength);
@@ -99,7 +99,7 @@ public float[] embed(String input) {
long[] encoded = tokenizer.encode(input);
Preconditions.checkArgument(encoded.length < c.contextLength);

- AbstractTensor kvmem = makeTensor(c.numberOfLayers, encoded.length, c.embeddingLength * 2); //k and v are concatenated
+ AbstractTensor kvmem = makeTensor(c.numberOfLayers, encoded.length, 2, c.embeddingLength); // 2 for key and value

int promptLength = encoded.length;
float avgp = 1.0f/promptLength;
@@ -29,13 +29,14 @@ public class LlamaModel extends AbstractModel {

private final AbstractTensor classificationWeights;

- public LlamaModel(Config config, WeightLoader weights, Tokenizer tokenizer, DType workingDType, DType workingQType) {
-     super(config, weights, tokenizer, workingDType, workingQType);
+ public LlamaModel(Config config, WeightLoader weights, Tokenizer tokenizer, DType workingDType, DType workingQType, DType modelQType) {
+     super(config, weights, tokenizer, workingDType, workingQType, Optional.ofNullable(modelQType));

- DType qType = DType.Q4;
- logger.info("Quantizing model with {} - Please hold...", qType);

+ DType qType = modelQType != null ? modelQType : this.modelDType;
+ if (modelQType != this.modelDType) {
+     logger.info("Quantizing model with {} - Please hold...", qType);
+ }

this.wte = weights.load("model.embed_tokens.weight").quantize(workingDType); //Don't quantize this, it's used for the embedding layer
this.outputLayerNorm = new RMSNorm(this, weights.load("model.norm.weight").quantize(qType));
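The constructor now accepts an explicit model quantization type, wraps it with Optional.ofNullable for the superclass, and only logs the quantization notice when the requested type actually differs from the weights' native dtype. A usage sketch based on the updated tests further down; passing null keeps the native dtype:

// Load with F32 working memory, I8 working quantization, and Q4 model weights
// (mirrors the LlamaRun test below; config/weights/tokenizer setup omitted).
LlamaModel model = new LlamaModel(config, weights, tokenizer, DType.F32, DType.I8, DType.Q4);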
@@ -45,12 +45,12 @@ public String name()

@Override
public boolean requiresOffHeapTensor() {
- return false;
+ return true;
}

@Override
public float dotProduct(AbstractTensor a, AbstractTensor b, int aoffset, int boffset, int limit) {
- Preconditions.checkArgument(limit % 32 == 0);
+ Preconditions.checkArgument(limit % 2 == 0, "Limit must be a multiple of 2, not " + limit);

return switch (a.dType()) {
case F32 -> switch (b.dType()) {
@@ -671,11 +671,6 @@ private float dotProductF32Q4_256(FloatBufferTensor a, Q4ByteBufferTensor b, int
return acc.reduceLanes(VectorOperators.ADD);
}

- private FloatVector helpF32Q4(FloatVector acc, float scalef, FloatBufferTensor a, Q4ByteBufferTensor b, int aoffset, int boffset) {
-
- return acc;
- }

private float dotProductF32Q4_512(FloatBufferTensor a, Q4ByteBufferTensor b, int aoffset, int boffset, int limit) {
Preconditions.checkArgument(
boffset % Q4ByteBufferTensor.BLOCK_SIZE == 0 &&
@@ -1062,10 +1057,10 @@ void maccumulateF32(FloatBufferTensor a, FloatBufferTensor b) {
FloatVector vb = b.getVector(FloatVector.SPECIES_PREFERRED, i);
a.intoTensor(va.mul(vb), i);
}

// tail
for (; i < a.size(); i++) {
- a.set(a.get(i) * b.get(i));
+ a.set(a.get(i) * b.get(i), i);
}
}
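The tail fixes in this hunk and the ones that follow are all the same bug: the scalar cleanup loop computed the right value but dropped the destination index when writing it back. For context, this is the standard Panama main-loop-plus-scalar-tail idiom, which is presumably also why the dotProduct, saxpy, and sxpby preconditions can relax from multiples of 32 or 8 down to limit % 2 == 0. A self-contained sketch over plain float[] arrays (needs --add-modules jdk.incubator.vector; the project's tensor classes are not used here):

import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorSpecies;

// Elementwise multiply-accumulate a[i] *= b[i]: vector body plus scalar tail.
static void maccumulate(float[] a, float[] b) {
    VectorSpecies<Float> S = FloatVector.SPECIES_PREFERRED;
    int upperBound = S.loopBound(a.length);
    int i = 0;
    for (; i < upperBound; i += S.length()) {      // full vectors
        FloatVector va = FloatVector.fromArray(S, a, i);
        FloatVector vb = FloatVector.fromArray(S, b, i);
        va.mul(vb).intoArray(a, i);
    }
    for (; i < a.length; i++) {                    // leftover elements
        a[i] = a[i] * b[i];                        // note the write goes back to index i
    }
}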

@@ -1103,7 +1098,7 @@ void accumulateF32(FloatBufferTensor a, FloatBufferTensor b) {

// tail
for (; i < a.size(); i++) {
- a.set(a.get(i) + b.get(i));
+ a.set(a.get(i) + b.get(i), i);
}
}

@@ -1135,7 +1130,7 @@ void accumulateBF16_256(BFloat16BufferTensor a, BFloat16BufferTensor b) {

// tail
for (; i < a.size(); i++) {
- a.set(a.get(i) + b.get(i));
+ a.set(a.get(i) + b.get(i), i);
}
}

@@ -1167,7 +1162,7 @@ void accumulateBF16_512(BFloat16BufferTensor a, BFloat16BufferTensor b) {

// tail
for (; i < a.size(); i++) {
- a.set(a.get(i) + b.get(i));
+ a.set(a.get(i) + b.get(i), i);
}
}

@@ -1193,7 +1188,7 @@ public void scale(float factor, AbstractTensor a, int offset, int length)

public void scaleF32(float factor, FloatBufferTensor a, int offset, int length)
{
- int upperBound = FloatVector.SPECIES_PREFERRED.loopBound(offset + length);
+ int upperBound = FloatVector.SPECIES_PREFERRED.loopBound(length) + offset;
int i = offset;

FloatVector sf = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, factor);
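The scale fixes in this file all change where the vector loop stops: loopBound(offset + length) aligns the bound to the absolute index, which can over-run the valid range whenever offset itself is not lane-aligned, while loopBound(length) + offset aligns the number of elements processed instead. A worked example assuming 8 float lanes:

// offset = 6, length = 18, so the valid range is [6, 24); i starts at 6 and steps by 8.
int lanes = 8;
int oldBound = ((6 + 18) / lanes) * lanes;  // 24: admits i = 22, and [22, 30) reads past the end
int newBound = (18 / lanes) * lanes + 6;    // 22: last full vector is [14, 22), tail covers [22, 24)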
@@ -1210,7 +1205,7 @@ public void scaleF32(float factor, FloatBufferTensor a, int offset, int length)

public void scaleBF16_512(float factor, BFloat16BufferTensor a, int offset, int length)
{
- int upperBound = FloatVector.SPECIES_512.loopBound(offset + length);
+ int upperBound = FloatVector.SPECIES_512.loopBound(length) + offset;
int i = offset;

FloatVector sf = FloatVector.broadcast(FloatVector.SPECIES_512, factor);
@@ -1235,7 +1230,7 @@ public void scaleBF16_512(float factor, BFloat16BufferTensor a, int offset, int

public void scaleBF16_256(float factor, BFloat16BufferTensor a, int offset, int length)
{
- int upperBound = FloatVector.SPECIES_256.loopBound(offset + length);
+ int upperBound = FloatVector.SPECIES_256.loopBound(length) + offset;
int i = offset;

FloatVector sf = FloatVector.broadcast(FloatVector.SPECIES_256, factor);
@@ -1261,7 +1256,7 @@ public void scaleBF16_256(float factor, BFloat16BufferTensor a, int offset, int
@Override
public void saxpy(float alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit) {
Preconditions.checkArgument(x.dType() == y.dType());
- Preconditions.checkArgument(limit % 8 == 0);
+ Preconditions.checkArgument(limit % 2 == 0);

switch (x.dType()) {
case F32: saxpyF32(alpha, (FloatBufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit); break;
@@ -1370,7 +1365,7 @@ void saxpyBF16_512(float alpha, BFloat16BufferTensor a, BFloat16BufferTensor b,
@Override
public void sxpby(float beta, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit) {
Preconditions.checkArgument(x.dType() == y.dType());
- Preconditions.checkArgument(limit % 8 == 0);
+ Preconditions.checkArgument(limit % 2 == 0);

switch (x.dType()) {
case F32: sxpbyF32(beta, (FloatBufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit); break;
@@ -56,7 +56,7 @@ public void GPT2Run() throws IOException {
String prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, " +
"previously unexplored valley, in the Andes Mountains. " +
"Even more surprising to the researchers was the fact that the unicorns spoke perfect English.";
- gpt2.generate(prompt, 0.6f, 256, false, makeOutHandler());
+ gpt2.generate(prompt, 0.8f, 256, false, makeOutHandler());
}
}
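The second argument to generate is the sampling temperature, raised here from 0.6f to 0.8f for a more varied continuation. Assuming the conventional definition (this project's sampler is not shown in the diff), temperature divides the logits before the softmax, so T > 1 flattens the distribution and T < 1 sharpens it. A self-contained sketch:

// Conventional temperature scaling before sampling (max subtraction for numeric stability).
static float[] softmaxWithTemperature(float[] logits, float temperature) {
    float max = Float.NEGATIVE_INFINITY;
    for (float l : logits) max = Math.max(max, l);
    float[] p = new float[logits.length];
    float sum = 0f;
    for (int i = 0; i < logits.length; i++) {
        p[i] = (float) Math.exp((logits[i] - max) / temperature);
        sum += p[i];
    }
    for (int i = 0; i < p.length; i++) p[i] /= sum;
    return p;
}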

@@ -67,8 +67,7 @@ public void LlamaRun() throws Exception {
try (SafeTensorIndex weights = SafeTensorIndex.loadWithWeights(Path.of(modelPrefix))) {
LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
- LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8);
-
+ LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, DType.Q4);
String prompt = "Simply put, the theory of relativity states that";
model.generate(prompt, 0.7f, 256, false, makeOutHandler());
}
@@ -85,10 +84,10 @@ public void TinyLlamaRun() throws Exception {
Weights weights = SafeTensorSupport.readWeights(bb);
LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
- LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32);
+ LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, DType.F32);

String prompt = "Lily picked up a flower and gave it to";
- model.generate(prompt, 0.9f, 128, false, makeOutHandler());
+ model.generate(prompt, 0.7f, 128, false, makeOutHandler());
}
}

