Fix CLI as constructor changed
tjake committed Oct 22, 2023
1 parent e1264f3 commit f721980
Showing 8 changed files with 23 additions and 25 deletions.
@@ -6,6 +6,7 @@
import java.lang.reflect.InvocationTargetException;
import java.nio.file.Path;
import java.util.Objects;
+ import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

@@ -91,8 +92,8 @@ protected AbstractModel loadModel(File model) {
Tokenizer t = modelType.tokenizerClass.getConstructor(Path.class).newInstance(baseDir.toPath());
WeightLoader wl = SafeTensorSupport.loadWeights(baseDir);

- return modelType.modelClass.getConstructor(Config.class, WeightLoader.class, Tokenizer.class, DType.class, DType.class)
-         .newInstance(c, wl, t, workingMemoryType, workingQuantizationType);
+ return modelType.modelClass.getConstructor(Config.class, WeightLoader.class, Tokenizer.class, DType.class, DType.class, Optional.class)
+         .newInstance(c, wl, t, workingMemoryType, workingQuantizationType, Optional.ofNullable(modelQuantization));

} catch (IOException | NoSuchMethodException | InvocationTargetException | InstantiationException |
IllegalAccessException e) {
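
This hunk appears to be the CLI fix the commit title refers to: the reflective constructor lookup must list the new Optional parameter, and a possibly-null quantization setting is wrapped with Optional.ofNullable. Below is a minimal, self-contained sketch of that pattern; DemoModel and its String parameters are hypothetical stand-ins, not classes from this repository.

```java
import java.lang.reflect.InvocationTargetException;
import java.util.Optional;

public class ReflectiveOptionalDemo {

    // Hypothetical stand-in for a model class whose constructor grew an Optional parameter.
    public static class DemoModel {
        final Optional<String> modelQuantization;

        public DemoModel(String config, Optional<String> modelQuantization) {
            this.modelQuantization = modelQuantization;
        }
    }

    public static void main(String[] args) throws NoSuchMethodException, InvocationTargetException,
            InstantiationException, IllegalAccessException {
        String quantizationFlag = null; // e.g. a CLI option the user did not set

        // Generics are erased at runtime, so the reflective lookup names the raw
        // Optional.class; Optional.ofNullable turns a null flag into Optional.empty().
        DemoModel model = DemoModel.class
                .getConstructor(String.class, Optional.class)
                .newInstance("config", Optional.ofNullable(quantizationFlag));

        System.out.println(model.modelQuantization.isPresent()); // false
    }
}
```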
@@ -37,10 +37,6 @@ public abstract class AbstractModel {
protected final Optional<DType> modelQType;
private static final ThreadLocal<AbstractTensor[]> tmpArray = new ThreadLocal<>();

- protected AbstractModel(Config c, WeightLoader w, Tokenizer t, DType workingMemoryDType, DType workingMemoryQType)
- {
-     this(c, w, t, workingMemoryDType, workingMemoryQType, Optional.empty());
- }
protected AbstractModel(Config c, WeightLoader w, Tokenizer t, DType workingMemoryDType, DType workingMemoryQType, Optional<DType> modelQType)
{
this.c = c;
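
For context, the hunk above removes a telescoping constructor: the short overload only forwarded Optional.empty() to the full one, so deleting it forces every caller (including the CLI above) to state the quantization argument explicitly. A sketch of the deleted pattern, with String standing in for the real Config/WeightLoader/Tokenizer/DType types:

```java
import java.util.Optional;

public class TelescopingDemo {
    // String stands in for the real Config/WeightLoader/Tokenizer/DType types.
    static class BaseModel {
        final Optional<String> modelQType;

        // The overload removed by this commit: it only defaulted the last argument,
        // so callers now pass Optional.empty() (or a value) themselves.
        BaseModel(String config) {
            this(config, Optional.empty());
        }

        BaseModel(String config, Optional<String> modelQType) {
            this.modelQType = modelQType;
        }
    }

    public static void main(String[] args) {
        BaseModel m = new BaseModel("config", Optional.empty()); // the new, explicit form
        System.out.println(m.modelQType.isEmpty()); // true
    }
}
```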
@@ -22,8 +22,8 @@ public class BertModel extends AbstractModel {

private final TransformerBlock[] transformerBlocks;

- public BertModel(Config c, Weights w, Tokenizer tokenizer, DType workingDType, DType workingQType) {
-     super(c, w, tokenizer, workingDType, workingQType);
+ public BertModel(Config c, Weights w, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional<DType> modelQType) {
+     super(c, w, tokenizer, workingDType, workingQType, modelQType);

this.we = w.load("embeddings.word_embeddings.weight");
this.wte = w.load("embeddings.token_type_embeddings.weight");
@@ -15,8 +15,8 @@ public class GPT2Model extends AbstractModel {

private final TransformerBlock[] transformerBlocks;

- public GPT2Model(Config c, WeightLoader w, Tokenizer tokenizer, DType workingDType, DType workingQType) {
-     super(c, w, tokenizer, workingDType, workingQType);
+ public GPT2Model(Config c, WeightLoader w, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional<DType> modelQType) {
+     super(c, w, tokenizer, workingDType, workingQType, modelQType);

this.wte = w.load("wte.weight");
this.wpe = w.load("wpe.weight");
@@ -29,12 +29,12 @@ public class LlamaModel extends AbstractModel {

private final AbstractTensor classificationWeights;

- public LlamaModel(Config config, WeightLoader weights, Tokenizer tokenizer, DType workingDType, DType workingQType, DType modelQType) {
-     super(config, weights, tokenizer, workingDType, workingQType, Optional.ofNullable(modelQType));
+ public LlamaModel(Config config, WeightLoader weights, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional<DType> modelQType) {
+     super(config, weights, tokenizer, workingDType, workingQType, modelQType);

- DType qType = modelQType != null ? modelQType : this.modelDType;
+ DType qType = modelQType.orElse(this.modelDType);

- if (modelQType != this.modelDType) {
+ if (qType != this.modelDType) {
logger.info("Quantizing model with {} - Please hold...", qType);
}
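
Two things change in this file: with the parameter now an Optional<DType>, the null-check ternary becomes orElse, and the guard compares the resolved qType rather than the raw parameter, so the "Quantizing model" log line fires only when a conversion will actually happen. A sketch of the equivalence, using String in place of DType:

```java
import java.util.Optional;

public class OrElseDemo {
    public static void main(String[] args) {
        String modelDType = "F32";
        String requested = null; // no quantization requested

        // Old shape: nullable parameter plus a ternary.
        String qOld = requested != null ? requested : modelDType;

        // New shape: the same fallback, stated with Optional.
        String qNew = Optional.ofNullable(requested).orElse(modelDType);

        System.out.println(qOld.equals(qNew)); // true

        // The guard now tests the resolved value, so quantization work (and the
        // log line) only happens when qType actually differs from the stored dtype.
        boolean willQuantize = !qNew.equals(modelDType);
        System.out.println(willQuantize); // false
    }
}
```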

@@ -42,7 +42,7 @@ private DType findDType() {
}

//FIXME don't really support BF16 atm
- return maxType == DType.BF16 ? DType.F32 : maxType;
+ return maxType == DType.BF16 || maxType == DType.F16 ? DType.F32 : maxType;
}

@Override
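
This hunk widens the fallback in findDType to treat F16 like BF16: neither is fully supported as a working dtype, so both resolve to F32. A sketch of the resolution rule; the DType enum below is a local stand-in, not the project's:

```java
public class DTypeFallbackDemo {
    enum DType { F32, F16, BF16, I8, Q4 }

    // Sketch of the widened rule: BF16 *and* F16 both fall back to F32,
    // since neither is supported as a working dtype yet.
    static DType resolve(DType maxType) {
        return (maxType == DType.BF16 || maxType == DType.F16) ? DType.F32 : maxType;
    }

    public static void main(String[] args) {
        System.out.println(resolve(DType.BF16)); // F32
        System.out.println(resolve(DType.F16));  // F32
        System.out.println(resolve(DType.Q4));   // Q4
    }
}
```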
@@ -70,14 +70,14 @@ public float dotProduct(AbstractTensor a, AbstractTensor b, int aoffset, int bof
case F32 -> NativeSimd.dot_product_f32(flags, a.getMemorySegment(), aoffset, b.getMemorySegment(), boffset, limit);
case I8 -> NativeSimd.dot_product_f32_q8(flags, a.getMemorySegment(), aoffset, ((Q8ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit);
case Q4 -> NativeSimd.dot_product_f32_q4(flags, a.getMemorySegment(), aoffset, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit);
- default -> throw new UnsupportedOperationException();
+ default -> throw new UnsupportedOperationException(b.dType().name());
};
case I8 -> switch (b.dType()) {
case Q4 -> NativeSimd.dot_product_q8_q4(flags, ((Q8ByteBufferTensor)a).getBlockF().getMemorySegment(), a.getMemorySegment(), aoffset, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit);
//case I8 -> NativeSimd.dot_product_q8(flags, ((Q8ByteBufferTensor)a).getBlockF().getMemorySegment(), a.getMemorySegment(), aoffset, ((Q8ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), boffset, limit);
- default -> throw new UnsupportedOperationException();
+ default -> throw new UnsupportedOperationException(b.dType().name());
};
- default -> throw new UnsupportedOperationException();
+ default -> throw new UnsupportedOperationException(a.dType().name());
};
}

@@ -88,16 +88,16 @@ public void dotProductChunk(AbstractTensor r, AbstractTensor a, AbstractTensor b
case F32: NativeSimd.dot_product_f32_chunked(flags, r.getMemorySegment(), a.getMemorySegment(), 0, b.getMemorySegment(), 0, limit, chunkStart, chunkSize); break;
case I8: NativeSimd.dot_product_f32_q8_chunked(flags, r.getMemorySegment(), a.getMemorySegment(), 0, ((Q8ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), 0, limit, chunkStart, chunkSize); break;
case Q4: NativeSimd.dot_product_f32_q4_chunked(flags, r.getMemorySegment(), a.getMemorySegment(), 0, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), 0, limit, chunkStart, chunkSize); break;
- default: throw new UnsupportedOperationException();
+ default: throw new UnsupportedOperationException(b.dType().name());
}
break;
case I8: switch (b.dType()) {
case Q4: NativeSimd.dot_product_q8_q4_chunked(flags, r.getMemorySegment(), ((Q8ByteBufferTensor)a).getBlockF().getMemorySegment(), a.getMemorySegment(), 0, ((Q4ByteBufferTensor)b).getBlockF().getMemorySegment(), b.getMemorySegment(), 0, limit, chunkStart, chunkSize);
break;
- default: throw new UnsupportedOperationException();
+ default: throw new UnsupportedOperationException(b.dType().name());
}
break;
- default: throw new UnsupportedOperationException();
+ default: throw new UnsupportedOperationException(a.dType().name());
}
}
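
The pattern across these hunks is uniform: every default arm of the dtype dispatch now names the offending dtype instead of throwing a bare UnsupportedOperationException. A self-contained sketch of the two-level dispatch showing why that message helps; the kernel names are illustrative strings, not the real native calls:

```java
public class DispatchDemo {
    enum DType { F32, I8, Q4, BF16 }

    // Sketch of the two-level dispatch in dotProduct: the outer switch picks the
    // left-hand dtype, the inner one the right-hand dtype. Naming the offending
    // dtype in the exception tells you exactly which kernel combination is missing.
    static String kernelFor(DType a, DType b) {
        return switch (a) {
            case F32 -> switch (b) {
                case F32 -> "dot_product_f32";
                case I8 -> "dot_product_f32_q8";
                case Q4 -> "dot_product_f32_q4";
                default -> throw new UnsupportedOperationException(b.name());
            };
            case I8 -> switch (b) {
                case Q4 -> "dot_product_q8_q4";
                default -> throw new UnsupportedOperationException(b.name());
            };
            default -> throw new UnsupportedOperationException(a.name());
        };
    }

    public static void main(String[] args) {
        System.out.println(kernelFor(DType.F32, DType.Q4)); // dot_product_f32_q4
        try {
            kernelFor(DType.BF16, DType.F32);
        } catch (UnsupportedOperationException e) {
            System.out.println("unsupported: " + e.getMessage()); // unsupported: BF16
        }
    }
}
```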

@@ -25,6 +25,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
+ import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

@@ -51,7 +52,7 @@ public void GPT2Run() throws IOException {
Weights v = SafeTensorSupport.readWeights(bb);
Tokenizer tokenizer = new GPT2Tokenizer(Paths.get(modelPrefix));
Config c = om.readValue(new File(modelPrefix + "/config.json"), GPT2Config.class);
- GPT2Model gpt2 = new GPT2Model(c, v, tokenizer, DType.F32, DType.F32);
+ GPT2Model gpt2 = new GPT2Model(c, v, tokenizer, DType.F32, DType.F32, Optional.of(DType.F32));

String prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, " +
"previously unexplored valley, in the Andes Mountains. " +
@@ -67,7 +68,7 @@ public void LlamaRun() throws Exception {
try (SafeTensorIndex weights = SafeTensorIndex.loadWithWeights(Path.of(modelPrefix))) {
LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
- LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, DType.Q4);
+ LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, Optional.of(DType.Q4));
String prompt = "Simply put, the theory of relativity states that";
model.generate(prompt, 0.7f, 256, false, makeOutHandler());
}
@@ -84,7 +85,7 @@ public void TinyLlamaRun() throws Exception {
Weights weights = SafeTensorSupport.readWeights(bb);
LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
- LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, DType.F32);
+ LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.of(DType.F32));

String prompt = "Lily picked up a flower and gave it to";
model.generate(prompt, 0.7f, 128, false, makeOutHandler());
@@ -102,7 +103,7 @@ public void BertRun() throws Exception {
Weights weights = SafeTensorSupport.readWeights(bb);
Tokenizer tokenizer = new BertTokenizer(Paths.get(modelPrefix));
Config c = om.readValue(new File(modelPrefix + "/config.json"), BertConfig.class);
- BertModel model = new BertModel(c, weights, tokenizer, DType.F32, DType.F32);
+ BertModel model = new BertModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.of(DType.F32));

String base = "A man is eating food.";
String[] examples = new String[]{
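
Taken together, these test updates show the caller-side contract after this commit: wrap the quantization target in Optional.of (as LlamaRun does with Q4), or pass the stored dtype or Optional.empty() to leave the weights untouched. A sketch of the two call shapes, with a hypothetical FakeModel record standing in for the real constructors:

```java
import java.util.Optional;

public class CallerDemo {
    enum DType { F32, I8, Q4 }

    // Hypothetical stand-in for a model constructor that now takes Optional<DType> last.
    record FakeModel(DType working, DType workingQ, Optional<DType> modelQ) {}

    public static void main(String[] args) {
        // Quantize the weights to Q4 at load time, as LlamaRun now does:
        FakeModel quantized = new FakeModel(DType.F32, DType.I8, Optional.of(DType.Q4));

        // Passing the stored dtype (or Optional.empty()) skips quantization:
        FakeModel plain = new FakeModel(DType.F32, DType.F32, Optional.of(DType.F32));

        System.out.println(quantized.modelQ().orElse(quantized.working())); // Q4
        System.out.println(plain.modelQ().orElse(plain.working()));         // F32
    }
}
```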
