
Commit

Fix another small model bug
tjake committed Oct 20, 2024
1 parent 1754cf2 commit 7ada7dd
Showing 2 changed files with 12 additions and 4 deletions.
@@ -123,6 +123,14 @@ protected AbstractModel(
             workingMemoryQType = DType.BF16;
         }
 
+        // Check to make sure the model is big enough to support Q4I8 computations
+        // If not, fall back to F32
+        if (modelDType == DType.Q4 && workingMemoryQType == DType.I8 &&
+                ((c.embeddingLength / Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE) != 0 ||
+                 (c.hiddenLength / Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE) != 0)) {
+            workingMemoryQType = DType.F32;
+        }
+
         // Check to make sure the model is big enough to support Q4I8 computations
         // If not, fall back to F32
         if (modelDType == DType.Q4 && workingMemoryQType == DType.I8 &&
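A note on the new guard above: the added check keeps Q4 weights with I8 working memory only when both c.embeddingLength and c.hiddenLength, divided by the Q8 block size, are multiples of the preferred float-vector lane count; smaller or oddly sized models fall back to F32 working memory, presumably because the vectorized Q4/I8 kernels consume a full FloatVector of per-block values at a time. Below is a minimal standalone sketch of that arithmetic; the block size of 32 and the model dimensions are assumptions for illustration only (the real constant is Q8ByteBufferTensor.BLOCK_SIZE, which this diff does not restate).

// Q4I8FallbackSketch.java -- hypothetical helper, not part of this commit.
// Compile/run with: javac --add-modules jdk.incubator.vector Q4I8FallbackSketch.java
//                   java  --add-modules jdk.incubator.vector Q4I8FallbackSketch
import jdk.incubator.vector.FloatVector;

public class Q4I8FallbackSketch {

    // Assumed Q8 block size; in jlama the real value comes from Q8ByteBufferTensor.BLOCK_SIZE.
    static final int Q8_BLOCK_SIZE = 32;

    // Mirrors the added condition: Q4 weights + I8 working memory are kept only when the
    // per-row Q8 block counts are multiples of the preferred float lane count.
    static boolean supportsQ4I8(int embeddingLength, int hiddenLength) {
        int lanes = FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE; // e.g. 8 on AVX2, 16 on AVX-512
        return (embeddingLength / Q8_BLOCK_SIZE) % lanes == 0
                && (hiddenLength / Q8_BLOCK_SIZE) % lanes == 0;
    }

    public static void main(String[] args) {
        int lanes = FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE;
        System.out.println("preferred float lanes: " + lanes);
        // Made-up dimensions, purely for illustration:
        // 2048/32 = 64 blocks, 5504/32 = 172 blocks -> 172 is not a multiple of 8 or 16, so fall back to F32
        System.out.println("2048 x 5504  -> Q4I8 ok? " + supportsQ4I8(2048, 5504));
        // 4096/32 = 128 blocks, 11008/32 = 344 blocks -> multiple of 8 but not 16, so it depends on the hardware
        System.out.println("4096 x 11008 -> Q4I8 ok? " + supportsQ4I8(4096, 11008));
    }
}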
@@ -48,7 +48,7 @@
 public class TestModels {
 
     static {
-        System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0");
+        System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "2");
         //System.setProperty("jlama.force_panama_tensor_operations", "true");
     }
 
@@ -212,15 +212,15 @@ public void MixtralRun() throws Exception {
 
     @Test
     public void GemmaRun() throws Exception {
-        String modelPrefix = "../models/Yi-Coder-1.5B-Chat";
+        String modelPrefix = "../models/01-ai_Yi-Coder-1.5B-Chat-JQ4";
         Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));
         try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) {
             LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
             LlamaConfig c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
-            LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.BF16, Optional.empty());
+            LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.I8, Optional.empty());
             String prompt = "Write a java function that takes a list of integers and returns the sum of all the integers in the list.";
             PromptContext p = model.promptSupport().get().builder().addUserMessage(prompt).build();
-            model.generate(UUID.randomUUID(), p, 0.3f, c.contextLength, makeOutHandler());
+            model.generate(UUID.randomUUID(), p, 0.7f, c.contextLength, makeOutHandler());
         }
     }
 
