Skip to content

Commit

Permalink
Bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Sep 15, 2024
1 parent 7dd8e37 commit eba680c
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 24 deletions.
4 changes: 2 additions & 2 deletions .run/TestModels.LlamaRun.run.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<configuration default="false" name="TestModels.LlamaRun" type="JUnit" factoryName="JUnit">
<classpathModifications>
<entry exclude="true" path="$PROJECT_DIR$/jlama-native/target/classes" />
<entry path="$PROJECT_DIR$/jlama-native/target/jlama-native-0.3.1-linux-x86_64.jar" />
<entry path="$PROJECT_DIR$/jlama-native/target/jlama-native-0.4.0-linux-x86_64.jar" />
</classpathModifications>
<module name="jlama-tests" />
<extension name="coverage">
Expand All @@ -17,7 +17,7 @@
<option name="MAIN_CLASS_NAME" value="com.github.tjake.jlama.model.TestModels" />
<option name="METHOD_NAME" value="LlamaRun" />
<option name="TEST_OBJECT" value="method" />
<option name="VM_PARAMETERS" value="-ea --add-modules=jdk.incubator.vector -Djava.library.path=../jlama-native/target/native-lib-only" />
<option name="VM_PARAMETERS" value="-ea --add-modules=jdk.incubator.vector -Djava.library.path=../jlama-native/target/native-lib-only --enable-preview" />
<method v="2">
<option name="Make" enabled="true" />
</method>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ public class VectorMath {
private static final Logger logger = LoggerFactory.getLogger(VectorMath.class);

public static void pfor(int start, int end, IntConsumer action) {
PhysicalCoreExecutor.instance.get().execute(() -> IntStream.range(start, end).parallel().forEach(action));
PhysicalCoreExecutor.instance.get().execute(() -> IntStream.range(start, end)
.parallel()
.forEach(action));
}

public static void pchunk(int offset, int length, BiIntConsumer action) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ public AbstractTensor forward(
bias -> TensorOperationsProvider.get().accumulate(tmpValBatch, bias, dctx.kvSegmentStart, dctx.kvSegmentLength)
);

debug("query", queryBatch, 0);
debug("key", tmpKeyBatch, 0);
debug("value", tmpValBatch, 0);
debug("query", queryBatch, layerIndex);
debug("key", tmpKeyBatch, layerIndex);
debug("value", tmpValBatch, layerIndex);

// This is our memory of the key and value vectors for each position
for (int position = startPosition, bi = 0; position < startPosition + batchSize; position++, bi++) {
Expand Down Expand Up @@ -329,7 +329,7 @@ public AbstractTensor forward(

if (yoffset >= query.shape().last()) return;

try (AbstractTensor attn = m.makeDenseTensor(1, finalPostion + 1)) {
try (AbstractTensor attn = m.makeDenseTensor(1, kvp[0].shape().first() * kvp.length)) { // chunky so the cache isn't thrashed
// compute attention scores by multiplying query and key for every position
// Do this for each page
for (int i = 0; i < kvp.length; i++) {
Expand All @@ -356,7 +356,7 @@ public AbstractTensor forward(
});
}

debug("after_attention", valueBatch, 0);
debug("after_attention", valueBatch, layerIndex);

// matmul the projection and sum into input
// input += c_proj_weight @ ybuf + c_proj_bias
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ class KvBufferPage implements AutoCloseable {

KvBufferPage(KvPageContext pageCtx, String pageId) {
this.pageCtx = pageCtx;
this.pageId = pageId
;
this.pageId = pageId;

if (model.getConfig().workingDirectory().isEmpty()) {
this.raf = null;
Expand Down Expand Up @@ -279,7 +278,7 @@ private AbstractTensor getTensorForPosition(int layerIndex, int position, int in
// Calculate page indices and relative indices
int layerPageIndex = layerIndex / pageContext.layersPerPage;
int contextPageIndex = position / pageContext.contextLengthPerPage;
int relativeLayerIndex = layerPageIndex % pageContext.layersPerPage;
int relativeLayerIndex = layerIndex % pageContext.layersPerPage;
int relativeContextIndex = position % pageContext.contextLengthPerPage;

KvBufferPage page = pages[layerPageIndex][contextPageIndex];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public void batchDotProduct(
) {
Preconditions.checkArgument(a.dims() == 2 && b.dims() == 2 && result.dims() == 2);
Preconditions.checkArgument(a.shape().dim(0) == result.shape().dim(0), "BAD M");
Preconditions.checkArgument(rOffset >= bRowOffset, "Result offset must be >= b row offset");
Preconditions.checkArgument(rOffset == 0 || rOffset >= bRowOffset, "Result offset must be >= b row offset");
// Preconditions.checkArgument(b.shape().dim(0) == result.shape().dim(1), "BAD N");
// This check breaks for GQA
// Preconditions.checkArgument(a.shape().dim(1) == b.shape().dim(1), "BAD K" + a.shape() + " " + b.shape() + " "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public class TestModels {

static {
System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0");
// System.setProperty("jlama.force_panama_tensor_operations", "true");
//System.setProperty("jlama.force_panama_tensor_operations", "true");
}

private static final Logger logger = LoggerFactory.getLogger(TestModels.class);
Expand Down Expand Up @@ -282,20 +282,19 @@ public void testQuantize() throws Exception {

@Test
public void TinyLlamaRun() throws Exception {
String modelPrefix = "models/TinyLLama";
String modelPrefix = "../models/TinyLlama-1.1B-Chat-v1.0-jlama-Q4";
Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));

try (RandomAccessFile sc = new RandomAccessFile(modelPrefix + "/model.safetensors", "r")) {
ByteBuffer bb = sc.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, sc.length());
AbstractModel model = ModelSupport.loadModel(new File(modelPrefix), DType.F32, DType.I8);

Weights weights = SafeTensorSupport.readWeights(bb);
LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class);
LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.of(DType.F32));

String prompt = "Lily picked up a flower and gave it to";
model.generate(UUID.randomUUID(), PromptContext.of(prompt), 0.7f, 128, makeOutHandler());
}
String prompt = "What is the best season to plant avocados?";
PromptContext promptContext = model.promptSupport()
.get()
.builder()
.addSystemMessage("You are a helpful chatbot who writes short responses.")
.addUserMessage(prompt)
.build();
model.generate(UUID.randomUUID(), promptContext, 0.0f, 256, makeOutHandler());
}

@Test
Expand Down

0 comments on commit eba680c

Please sign in to comment.