diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java index 8ff35db..330da77 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java @@ -89,7 +89,13 @@ public void run() { builder.addUserMessage(prompt); PromptContext builtPrompt = builder.build(); - Generator.Response r = m.generate(session, builtPrompt, temperature, tokens == null ? m.getConfig().contextLength : tokens, makeOutHandler()); + Generator.Response r = m.generate( + session, + builtPrompt, + temperature, + tokens == null ? m.getConfig().contextLength : tokens, + makeOutHandler() + ); out.println( "\n\n" diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java index f9f25e2..134c243 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java @@ -19,7 +19,6 @@ import com.github.tjake.jlama.net.Worker; import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.util.PhysicalCoreExecutor; -import com.google.common.util.concurrent.Uninterruptibles; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.builder.SpringApplicationBuilder; @@ -30,24 +29,23 @@ import java.nio.file.Path; import java.util.Optional; -import java.util.concurrent.TimeUnit; @CommandLine.Command(name = "cluster-coordinator", description = "Starts a distributed rest api for a model using cluster workers", abbreviateSynopsis = true) -@SpringBootApplication(scanBasePackages = { "com.github.tjake.jlama.net.openai", "com.github.tjake.jlama.cli.commands", "com.github.tjake.jlama.net.grpc" }) +@SpringBootApplication(scanBasePackages = { "com.github.tjake.jlama.net.openai", "com.github.tjake.jlama.cli.commands", + "com.github.tjake.jlama.net.grpc" }) @SpringBootConfiguration @Configuration public class ClusterCoordinatorCommand extends ModelBaseCommand implements WebMvcConfigurer { - @CommandLine.Option(names = { - "--worker-count" }, paramLabel = "ARG", description = "signifies this instance is a coordinator") + @CommandLine.Option(names = { "--worker-count" }, paramLabel = "ARG", description = "signifies this instance is a coordinator") int workerCount = 1; @CommandLine.Option(names = { - "--split-heads" }, paramLabel = "ARG", description = "Should coordinator split work across attention heads (default: ${DEFAULT-VALUE})") + "--split-heads" }, paramLabel = "ARG", description = "Should coordinator split work across attention heads (default: ${DEFAULT-VALUE})") boolean splitHeads = true; @CommandLine.Option(names = { - "--split-layers" }, paramLabel = "ARG", description = "Should coordinator split work across layers (default: ${DEFAULT-VALUE})") + "--split-layers" }, paramLabel = "ARG", description = "Should coordinator split work across layers (default: ${DEFAULT-VALUE})") boolean splitLayers = false; @CommandLine.Option(names = { @@ -117,28 +115,29 @@ public void run() { if (includeWorker) { Worker w = new Worker( - model.toFile(), - SimpleBaseCommand.getOwner(modelName), - SimpleBaseCommand.getName(modelName), - modelType, - "localhost", - 
grpcPort, - grpcPort + 1, - workingDirectory, - advancedSection.workingMemoryType, - advancedSection.workingQuantizationType, - Optional.ofNullable(advancedSection.modelQuantization), - Optional.ofNullable("in-jvm-worker"), - Optional.ofNullable(downloadSection.authToken), - Optional.ofNullable(downloadSection.branch)); + model.toFile(), + SimpleBaseCommand.getOwner(modelName), + SimpleBaseCommand.getName(modelName), + modelType, + "localhost", + grpcPort, + grpcPort + 1, + workingDirectory, + advancedSection.workingMemoryType, + advancedSection.workingQuantizationType, + Optional.ofNullable(advancedSection.modelQuantization), + Optional.ofNullable("in-jvm-worker"), + Optional.ofNullable(downloadSection.authToken), + Optional.ofNullable(downloadSection.branch) + ); new Thread(() -> { - try { - w.run(); - } catch (Exception e) { - e.printStackTrace(); - } - }).start(); + try { + w.run(); + } catch (Exception e) { + e.printStackTrace(); + } + }).start(); } System.out.println("Chat UI: http://localhost:" + port); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index 6244b45..25a3004 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -263,10 +263,11 @@ public AbstractTensor batchForward( } public AbstractTensor forward( - AbstractTensor embedding, - int startPos, - KvBufferCache.KvBuffer kvbuf, - Optional>> tensorReducer) { + AbstractTensor embedding, + int startPos, + KvBufferCache.KvBuffer kvbuf, + Optional>> tensorReducer + ) { for (int i = c.dctx().layerStart; i < c.dctx().layerEnd; i++) { int relativeLayer = i - c.dctx().layerStart; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/DistributedContext.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/DistributedContext.java index 62c6e32..910c82f 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/DistributedContext.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/DistributedContext.java @@ -115,14 +115,20 @@ public int getShardLength(int length) { } public String toString() { - return "DistributedContext{" + - ", embeddingSegmentStart=" + embeddingSegmentStart + - ", embeddingSegmentEnd=" + embeddingSegmentEnd + - ", headStart=" + headStart + - ", headEnd=" + headEnd + - ", layerStart=" + layerStart + - ", layerEnd=" + layerEnd + - '}'; + return "DistributedContext{" + + ", embeddingSegmentStart=" + + embeddingSegmentStart + + ", embeddingSegmentEnd=" + + embeddingSegmentEnd + + ", headStart=" + + headStart + + ", headEnd=" + + headEnd + + ", layerStart=" + + layerStart + + ", layerEnd=" + + layerEnd + + '}'; } public static Builder builder(Config c) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/MLPBlock.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/MLPBlock.java index eab564d..dbd97e8 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/MLPBlock.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/MLPBlock.java @@ -122,8 +122,8 @@ public AbstractTensor forward(AbstractTensor lnemb, Optional TensorOperationsProvider.get().accumulate(buf, bias, dctx.hiddenSegmentStart, dctx.hiddenSegmentLength) ); - //Not using pfor because we can use all cores - IntStream.range(dctx.hiddenSegmentStart, dctx.hiddenSegmentEnd).parallel().forEach( i -> { + // Not using pfor because we can use all cores + 
IntStream.range(dctx.hiddenSegmentStart, dctx.hiddenSegmentEnd).parallel().forEach(i -> { for (int j = 0; j < batchSize; j++) { float w1 = buf.get(j, i); float w1a = ActivationFunction.eval(activationFunction, w1); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java index cf2eeb9..cc4ce93 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaModel.java @@ -23,7 +23,6 @@ import com.github.tjake.jlama.safetensors.WeightLoader; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import com.github.tjake.jlama.tensor.AbstractTensor; -import com.github.tjake.jlama.tensor.TensorCache; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; @@ -72,8 +71,7 @@ protected EmbedInput loadInputWeights() { return (inputToken, position) -> { - - AbstractTensor at = wte.slice(true, inputToken); + AbstractTensor at = wte.slice(true, inputToken); AbstractTensor embedding = at.copyShape(); // Always copy the entire embedding @@ -94,7 +92,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() { IntStream.range(c.dctx().layerStart, c.dctx().layerEnd).parallel().forEach(i -> { - int relativeLayer = i - c.dctx().layerStart; //FIXME: add a helper to the context + int relativeLayer = i - c.dctx().layerStart; // FIXME: add a helper to the context String base = "model.layers." + i + "."; String prefix = base + "self_attn."; @@ -135,10 +133,11 @@ protected SampleOutput loadOutputWeights() { DType qType = modelQType.orElse(this.modelDType); final LayerNorm outputLayerNorm = new RMSNorm(this, weights.load("model.norm.weight").quantize(qType)); - //Some llama models don't have a classification head + // Some llama models don't have a classification head AbstractTensor classificationWeights = weights.isWeightPresent("lm_head.weight") - ? weights.load("lm_head.weight").quantize(workingDType) - : wte == null ? wte = weights.load("model.embed_tokens.weight") : wte; + ? weights.load("lm_head.weight").quantize(workingDType) + : wte == null ? 
wte = weights.load("model.embed_tokens.weight") + : wte; return new SampleOutput() { @Override diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/HTTPSafeTensorLoader.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/HTTPSafeTensorLoader.java index d1ba9c3..a55e3d6 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/HTTPSafeTensorLoader.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/HTTPSafeTensorLoader.java @@ -136,20 +136,32 @@ public AbstractTensor load(String name, DistributedContext dctx, boolean sparseR long length = positionLimit - positionOffset; if (length > Integer.MAX_VALUE) { - //Make a segmented tensor + // Make a segmented tensor assert info.shape.length == 2 : "Only 2D tensors supported"; List tensors = new ArrayList<>(); int bytesPerColumn = info.dType.size() * info.shape[1]; long offset = positionOffset; - //Chunk size needs to be a multiple of the column size + // Chunk size needs to be a multiple of the column size long chunkSize = Integer.MAX_VALUE - (Integer.MAX_VALUE % bytesPerColumn); int chunkNum = 0; while (offset < positionLimit) { long chunkEnd = Math.min(offset + chunkSize, positionLimit); int numRowsInChunk = Ints.checkedCast((chunkEnd - offset) / bytesPerColumn); TensorShape chunkShape = TensorShape.of(numRowsInChunk, info.shape[1]); - tensors.add(downloadAndLoadTensor(name + ".part." + chunkNum++, weightFile, info, chunkShape, offset, chunkEnd, dctx, sparseRows, sparseColumns)); + tensors.add( + downloadAndLoadTensor( + name + ".part." + chunkNum++, + weightFile, + info, + chunkShape, + offset, + chunkEnd, + dctx, + sparseRows, + sparseColumns + ) + ); offset = chunkEnd; } @@ -165,20 +177,29 @@ public AbstractTensor load(String name, DistributedContext dctx, boolean sparseR } } - private AbstractTensor downloadAndLoadTensor(String name, String weightFile, TensorInfo info, TensorShape shape, long positionOffset, long positionLimit, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) throws IOException - { + private AbstractTensor downloadAndLoadTensor( + String name, + String weightFile, + TensorInfo info, + TensorShape shape, + long positionOffset, + long positionLimit, + DistributedContext dctx, + boolean sparseRows, + boolean sparseColumns + ) throws IOException { Path weightPath = modelRoot.resolve(weightFile + ".part." 
+ positionOffset + "_" + positionLimit); if (!weightPath.toFile().exists()) { logger.info("Downloading file: {} for {} {}MB", weightPath, name, (positionLimit - positionOffset) / 1024 / 1024); HttpSupport.downloadFile( - modelName, - weightFile, - branch, - authToken, - Optional.of(Pair.of(positionOffset, positionLimit)), - weightPath, - Optional.empty() + modelName, + weightFile, + branch, + authToken, + Optional.of(Pair.of(positionOffset, positionLimit)), + weightPath, + Optional.empty() ); } @@ -186,30 +207,30 @@ private AbstractTensor downloadAndLoadTensor(String name, String weightFile, Ten RandomAccessFile raf = new RandomAccessFile(weightPath.toFile(), "r"); ByteBuffer buf = raf.getChannel() - .map(FileChannel.MapMode.READ_ONLY, 0, raf.length()) - .duplicate() - .order(ByteOrder.LITTLE_ENDIAN) - .position(0) - .limit(length); + .map(FileChannel.MapMode.READ_ONLY, 0, raf.length()) + .duplicate() + .order(ByteOrder.LITTLE_ENDIAN) + .position(0) + .limit(length); if (raf.length() < length) { throw new RuntimeException( - "Failed to download the correct number of bytes: " + raf.length() + " != " + length + " for " + weightPath + "Failed to download the correct number of bytes: " + raf.length() + " != " + length + " for " + weightPath ); } logger.debug("Loading tensor: {} from {} with offsets: {} {}", name, weightPath, positionOffset, positionLimit); AbstractTensor tensor = Weights.loadTensorFromBuffer( - name, - info.dType, - modelDType, - shape, - buf, - sparseRows, - sparseColumns, - dctx, - this + name, + info.dType, + modelDType, + shape, + buf, + sparseRows, + sparseColumns, + dctx, + this ); layerFiles.put(name, Pair.of(raf, tensor)); @@ -262,8 +283,7 @@ public DType getModelDType() { public void close() { for (Pair pair : layerFiles.values()) { try { - if (pair.left() != null) - pair.left().close(); + if (pair.left() != null) pair.left().close(); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java index b94f5dc..25be502 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java @@ -157,10 +157,10 @@ private Map, List> computeMmapSplits(Map logger.debug("Adding tensor {} to split {}-{}", next.getKey(), info.dataOffsets[0], info.dataOffsets[1]); - //Used so fetch the tensor from the mmap + // Used so fetch the tensor from the mmap next = null; } else { - //Split large tensors up (they will be reassembled in the Weights class) + // Split large tensors up (they will be reassembled in the Weights class) if (tensors.size() == 0) { int bytesPerColumn = info.dType.size() * info.shape[1]; @@ -176,7 +176,7 @@ private Map, List> computeMmapSplits(Map long offset = info.dataOffsets[0]; long length = info.dataOffsets[1] - offset; - //Chunk size needs to be a multiple of the column size + // Chunk size needs to be a multiple of the column size long chunkSize = Integer.MAX_VALUE - (Integer.MAX_VALUE % bytesPerColumn); long offsetAdded = 0; int chunk = 0; @@ -184,16 +184,26 @@ private Map, List> computeMmapSplits(Map while (length > 0) { long chunkEnd = Math.min(offset + chunkSize, endOffset); String chunkName = next.getKey() + "-part-" + chunk++; - logger.debug("Adding chunk {} to split {}-{} {}", chunkName, offset, chunkEnd, Ints.checkedCast(chunkEnd - offset)); + logger.debug( + "Adding chunk 
{} to split {}-{} {}", + chunkName, + offset, + chunkEnd, + Ints.checkedCast(chunkEnd - offset) + ); splits.put(List.of(offset, chunkEnd), List.of(chunkName)); - //Add TensorInfo for the chunk + // Add TensorInfo for the chunk assert info.shape.length == 2 : "Only 2D tensors supported"; int numRowsInChunk = Ints.checkedCast((chunkEnd - offset) / bytesPerColumn); - //This tensorInfo is relative to the split which we know is at least the mmap limit + // This tensorInfo is relative to the split which we know is at least the mmap limit // We track the offsetAdded so we can make the offset relative to the current split - TensorInfo chunkInfo = new TensorInfo(info.dType, new long[]{numRowsInChunk, info.shape[1]}, new long[] { offset - offsetAdded, chunkEnd - offsetAdded }); + TensorInfo chunkInfo = new TensorInfo( + info.dType, + new long[] { numRowsInChunk, info.shape[1] }, + new long[] { offset - offsetAdded, chunkEnd - offsetAdded } + ); tensorInfoMap.put(chunkName, chunkInfo); added = true; offsetAdded += chunkEnd - offset; @@ -216,14 +226,11 @@ private Map, List> computeMmapSplits(Map logger.debug("Adding split {}-{} with {} tensors of {}", startOffset, endOffset, tensors.size(), tensorsSplit); // Add any sections that were split - if (!tensors.isEmpty()) - splits.put(List.of(startOffset, endOffset), new ArrayList<>(tensors)); + if (!tensors.isEmpty()) splits.put(List.of(startOffset, endOffset), new ArrayList<>(tensors)); - if (endOffset > lastSplitOffset) - lastSplitOffset = endOffset; + if (endOffset > lastSplitOffset) lastSplitOffset = endOffset; } - assert tensorsInFile == tensorsSplit : "Not all tensors were split: " + tensorsSplit + " != " + tensorsInFile; return splits; } @@ -248,7 +255,7 @@ public Map tensorInfoMap() { public AbstractTensor load(String name, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) { Weights w = weightMap.get(name); if (w == null) { - //Maybe assemble the tensor from segments + // Maybe assemble the tensor from segments List segments = new ArrayList<>(); int idx = 0; while (true) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSplitter.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSplitter.java index 4ee71a7..e1e041c 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSplitter.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSplitter.java @@ -1,6 +1,20 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ package com.github.tjake.jlama.safetensors; - import com.github.tjake.jlama.tensor.AbstractTensor; import com.github.tjake.jlama.util.Pair; @@ -18,31 +32,28 @@ /** Helper class to split a large model into pieces **/ public class SafeTensorSplitter { - //Limit chunk size to 20G + // Limit chunk size to 20G static long MAX_CHUNK_SIZE = 20L << 30; - static String getChunkFile(TensorInfo info, long fileSize) { - //Map Tensor to a chunk based on its location in the model + // Map Tensor to a chunk based on its location in the model long fileChunk = Math.floorDiv(info.dataOffsets[1], MAX_CHUNK_SIZE); long totalChunks = Math.floorDiv(fileSize, MAX_CHUNK_SIZE); return String.format("model-%05d-of-%05d.safetensor", fileChunk, totalChunks); } public static void main(String[] args) { - if (args.length == 0) - throw new IllegalArgumentException("Missing model name"); + if (args.length == 0) throw new IllegalArgumentException("Missing model name"); String modelDir = args[0]; - if (!new File(modelDir).isDirectory()) - throw new IllegalArgumentException("Not a directory"); + if (!new File(modelDir).isDirectory()) throw new IllegalArgumentException("Not a directory"); - if (Paths.get(modelDir, SafeTensorIndex.MODEL_INDEX_JSON).toFile().exists()) - throw new IllegalArgumentException("Already split"); + if (Paths.get(modelDir, SafeTensorIndex.MODEL_INDEX_JSON).toFile().exists()) throw new IllegalArgumentException("Already split"); - if (!Paths.get(modelDir, SafeTensorIndex.SINGLE_MODEL_NAME).toFile().exists()) - throw new IllegalArgumentException("Missing model file"); + if (!Paths.get(modelDir, SafeTensorIndex.SINGLE_MODEL_NAME).toFile().exists()) throw new IllegalArgumentException( + "Missing model file" + ); WeightLoader wl = SafeTensorSupport.loadWeights(new File(modelDir)); @@ -50,7 +61,6 @@ public static void main(String[] args) { Map info = wl.tensorInfoMap(); - // First split the metadata into N chunks and adjust the offsets Map tensorIndex = new LinkedHashMap<>(); Map> chunkFiles = new HashMap<>(); @@ -66,9 +76,9 @@ public static void main(String[] args) { Pair chunkFile = chunkFiles.computeIfAbsent(chunkName, n -> { try { - File tmp = File.createTempFile("jlama","chunk"); + File tmp = File.createTempFile("jlama", "chunk"); tmp.deleteOnExit(); - RandomAccessFile r = new RandomAccessFile(tmp, "rw"); + RandomAccessFile r = new RandomAccessFile(tmp, "rw"); FileChannel ch = r.getChannel(); return Pair.of(r, ch); @@ -81,14 +91,15 @@ public static void main(String[] args) { AbstractTensor t = wl.load(name); FileChannel ch = chunkFile.right; TensorInfo newInfo = t.save(ch); - System.out.println("Wrote " + name + " to " + chunkName + " at " + newInfo.dataOffsets[0] + " to " + newInfo.dataOffsets[1]); + System.out.println( + "Wrote " + name + " to " + chunkName + " at " + newInfo.dataOffsets[0] + " to " + newInfo.dataOffsets[1] + ); Map tensors = tensorsInChunk.computeIfAbsent(chunkName, n -> new LinkedHashMap<>()); tensors.put(name, newInfo); } - - //Now We have the data im place data, write the real file + // Now We have the data im place data, write the real file for (Map.Entry> entry : chunkFiles.entrySet()) { String chunkName = entry.getKey(); Pair chunkFile = entry.getValue(); @@ -98,21 +109,21 @@ public static void main(String[] args) { byte[] header = om.writeValueAsBytes(chunkTensors); System.out.println("Writing " + chunkName + " with " + chunkTensors.size() + " tensors"); - //System.out.println(new String(header)); + // System.out.println(new String(header)); byte[] hsize = new 
byte[Long.BYTES]; ByteBuffer.wrap(hsize).order(ByteOrder.LITTLE_ENDIAN).putLong(header.length); - try(RandomAccessFile raf = new RandomAccessFile(Paths.get(modelDir, chunkName).toFile(), "rw")) { + try (RandomAccessFile raf = new RandomAccessFile(Paths.get(modelDir, chunkName).toFile(), "rw")) { raf.write(hsize); raf.write(header); raf.seek(raf.length()); - System.out.println("Writing " + ch.size() + " bytes of data from " + raf.getChannel().position() ); + System.out.println("Writing " + ch.size() + " bytes of data from " + raf.getChannel().position()); ch.transferTo(0, ch.size(), raf.getChannel()); } } - //Write the index - try(RandomAccessFile raf = new RandomAccessFile(Paths.get(modelDir, SafeTensorIndex.MODEL_INDEX_JSON).toFile(), "rw")) { + // Write the index + try (RandomAccessFile raf = new RandomAccessFile(Paths.get(modelDir, SafeTensorIndex.MODEL_INDEX_JSON).toFile(), "rw")) { raf.write(om.writeValueAsBytes(Map.of("metadata", new HashMap<>(), "weight_map", tensorIndex))); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java index 052dc68..c12267c 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java @@ -33,7 +33,6 @@ import java.io.*; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.nio.channels.FileChannel; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -70,10 +69,10 @@ public static Map readTensorInfoMap(ByteBuffer buf, Optional } // Sort by value using a lambda expression - Map sortedMap = tensorInfoMap.entrySet().stream() - .sorted(Map.Entry.comparingByValue()) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new)); - + Map sortedMap = tensorInfoMap.entrySet() + .stream() + .sorted(Map.Entry.comparingByValue()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new)); final Map finalMetadata = metadata; saveMetadata.ifPresent(m -> m.putAll(finalMetadata)); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/TensorInfo.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/TensorInfo.java index d5414bb..ae506c6 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/TensorInfo.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/TensorInfo.java @@ -75,7 +75,7 @@ public int hashCode() { @Override public int compareTo(TensorInfo o) { - //In the case we are reading in order of dataOffsets + // In the case we are reading in order of dataOffsets return Long.compare(dataOffsets[0], o.dataOffsets[0]); } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java index 84118f8..8714e34 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java @@ -15,7 +15,6 @@ */ package com.github.tjake.jlama.safetensors; -import com.github.tjake.jlama.math.FloatConversions; import com.github.tjake.jlama.model.DistributedContext; import com.github.tjake.jlama.tensor.*; import com.github.tjake.jlama.util.Pair; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java 
b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java index 6b1a335..21044f3 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/FloatBufferTensor.java @@ -16,7 +16,6 @@ package com.github.tjake.jlama.tensor; import com.github.tjake.jlama.safetensors.DType; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.DebugSupport; import com.github.tjake.jlama.util.UnsafeDirectByteBuffer; import com.google.common.base.Preconditions; @@ -64,9 +63,9 @@ public FloatBufferTensor(TensorShape shape) { super(DType.F32, shape, true); this.name = "tmp"; this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer( - Ints.checkedCast(shape.size() * dType().size()), - UnsafeDirectByteBuffer.CACHE_LINE_SIZE - ).asFloatBuffer(); + Ints.checkedCast(shape.size() * dType().size()), + UnsafeDirectByteBuffer.CACHE_LINE_SIZE + ).asFloatBuffer(); this.segment = MemorySegment.ofBuffer(b); } @@ -137,8 +136,7 @@ public int getMemorySegmentOffset(int offset) { @Override public FloatVector getVector(VectorSpecies species, int... voffset) { int offset = getOffset(voffset); - if (b.hasArray()) - return FloatVector.fromArray(species, b.array(), offset); + if (b.hasArray()) return FloatVector.fromArray(species, b.array(), offset); return FloatVector.fromMemorySegment(species, segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN); } @@ -148,10 +146,8 @@ public void intoTensor(FloatVector vector, int... aoffset) { // Preconditions.checkArgument(!b.isReadOnly()); int offset = getOffset(aoffset); - if (b.hasArray()) - vector.intoArray(b.array(), offset); - else - vector.intoMemorySegment(segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN); + if (b.hasArray()) vector.intoArray(b.array(), offset); + else vector.intoMemorySegment(segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN); } @Override diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/SegmentedTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/SegmentedTensor.java index 43afd05..52df22a 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/SegmentedTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/SegmentedTensor.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ package com.github.tjake.jlama.tensor; import java.io.IOException; @@ -12,10 +27,8 @@ import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.VectorSpecies; -public class SegmentedTensor extends BFloat16BufferTensor -{ - public static SegmentedTensor wrap(List ft) - { +public class SegmentedTensor extends BFloat16BufferTensor { + public static SegmentedTensor wrap(List ft) { Preconditions.checkArgument(ft.size() > 1, "Must have at least two tensor to segment"); Preconditions.checkArgument(ft.get(0).shape().dims() == 2, "First tensor must be 2D"); @@ -25,15 +38,14 @@ public static SegmentedTensor wrap(List ft) int secondDim = ft0.shape().last(); int[] splitPoints = new int[ft.size()]; splitPoints[0] = firstDim; - for (int i = 1; i < ft.size(); i++) - { + for (int i = 1; i < ft.size(); i++) { AbstractTensor t = ft.get(i); Preconditions.checkArgument(t.shape().last() == secondDim, "All tensors must have the same second dimension"); firstDim += t.shape().first(); splitPoints[i] = firstDim; } - SegmentedTensor st = new SegmentedTensor(TensorShape.of(firstDim, secondDim), splitPoints, ft.toArray(new AbstractTensor[0])); + SegmentedTensor st = new SegmentedTensor(TensorShape.of(firstDim, secondDim), splitPoints, ft.toArray(new AbstractTensor[0])); return st; } @@ -41,8 +53,7 @@ public static SegmentedTensor wrap(List ft) private final AbstractTensor[] tensors; private final int[] splitPoints; - protected SegmentedTensor(TensorShape shape, int[] splitPoints, AbstractTensor ...tensors) - { + protected SegmentedTensor(TensorShape shape, int[] splitPoints, AbstractTensor... tensors) { super("segmented-tensor", ShortBuffer.allocate(0), shape, false); this.splitPoints = splitPoints; this.tensors = tensors; @@ -64,14 +75,11 @@ public TensorInfo save(FileChannel out) throws IOException { } @Override - public AbstractTensor slice(int... dims) - { + public AbstractTensor slice(int... dims) { Preconditions.checkArgument(dims.length == 1, "Must slice on first dimension"); int index = dims[0]; - for (int i = 0; i < splitPoints.length; i++) - { - if (index < splitPoints[i]) - { + for (int i = 0; i < splitPoints.length; i++) { + if (index < splitPoints[i]) { return tensors[i].slice(index - (i == 0 ? 0 : splitPoints[i - 1])); } } @@ -79,80 +87,64 @@ public AbstractTensor slice(int... dims) throw new IllegalArgumentException("Index out of range"); } - - ////////////////////// Everything below this line is not supported ////////////////////// + ////////////////////// Everything below this line is not supported ////////////////////// @Override - public AbstractTensor slice(boolean cacheInnerSlice, int... dims) - { + public AbstractTensor slice(boolean cacheInnerSlice, int... dims) { return super.slice(dims); } @Override - protected AbstractTensor make(TensorShape shape) - { + protected AbstractTensor make(TensorShape shape) { throw new UnsupportedOperationException("Not supported"); } @Override - protected AbstractTensor make(int heapOffset, int heapLength, TensorShape shape, boolean cacheSlices) - { + protected AbstractTensor make(int heapOffset, int heapLength, TensorShape shape, boolean cacheSlices) { throw new UnsupportedOperationException("Not supported"); } @Override - public float get(int... dims) - { + public float get(int... dims) { throw new UnsupportedOperationException("Not supported"); } @Override - public void set(float v, int... dims) - { + public void set(float v, int... 
dims) { throw new UnsupportedOperationException("Not supported"); } @Override - public ShortVector getVector(VectorSpecies species, int... voffset) - { + public ShortVector getVector(VectorSpecies species, int... voffset) { throw new UnsupportedOperationException("Not supported"); } @Override - public void intoTensor(ShortVector vector, int... aoffset) - { + public void intoTensor(ShortVector vector, int... aoffset) { throw new UnsupportedOperationException("Not supported"); } @Override - public String toString() - { - return "SegmentedBF16Tensor{" + - "shape=" + shape + - ", tensors=" + tensors.length + - '}'; + public String toString() { + return "SegmentedBF16Tensor{" + "shape=" + shape + ", tensors=" + tensors.length + '}'; } @Override - public MemorySegment getMemorySegment() - { + public MemorySegment getMemorySegment() { throw new UnsupportedOperationException("Not supported"); } @Override - public int getMemorySegmentOffset(int offset) - { + public int getMemorySegmentOffset(int offset) { throw new UnsupportedOperationException("Not supported"); } @Override - public void copyFrom(AbstractTensor src, int srcOffset, int destOffset, int length) - { + public void copyFrom(AbstractTensor src, int srcOffset, int destOffset, int length) { throw new UnsupportedOperationException("Not supported"); } @Override - public void clear() - { + public void clear() { throw new UnsupportedOperationException("Not supported"); } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java index 582442d..2faf5ce 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/PanamaTensorOperations.java @@ -212,10 +212,14 @@ protected BiIntConsumer initMatmul1x1() { var b0hi = b0.lanewise(VectorOperators.LSHR, Q4_BYTE_SHIFT_128).sub(Q4_BYTE_SUB_128); // BLOCK_SIZE Floats - var af0 = a.getVector(FloatVector.SPECIES_256, i, aoffset).mul(b0lo.castShape(FloatVector.SPECIES_256, 0));; - var af1 = a.getVector(FloatVector.SPECIES_256, i, aoffset + 8).mul(b0lo.castShape(FloatVector.SPECIES_256, 1));; - var af2 = a.getVector(FloatVector.SPECIES_256, i, aoffset + Q4ByteBufferTensor.HALF_BLOCK).mul(b0hi.castShape(FloatVector.SPECIES_256, 0)); - var af3 = a.getVector(FloatVector.SPECIES_256, i, aoffset + Q4ByteBufferTensor.HALF_BLOCK + 8).mul(b0hi.castShape(FloatVector.SPECIES_256, 1)); + var af0 = a.getVector(FloatVector.SPECIES_256, i, aoffset).mul(b0lo.castShape(FloatVector.SPECIES_256, 0)); + ; + var af1 = a.getVector(FloatVector.SPECIES_256, i, aoffset + 8).mul(b0lo.castShape(FloatVector.SPECIES_256, 1)); + ; + var af2 = a.getVector(FloatVector.SPECIES_256, i, aoffset + Q4ByteBufferTensor.HALF_BLOCK) + .mul(b0hi.castShape(FloatVector.SPECIES_256, 0)); + var af3 = a.getVector(FloatVector.SPECIES_256, i, aoffset + Q4ByteBufferTensor.HALF_BLOCK + 8) + .mul(b0hi.castShape(FloatVector.SPECIES_256, 1)); acc = af0.add(af1).add(af2).add(af3).fma(scale, acc); } @@ -238,7 +242,6 @@ protected BiIntConsumer initMatmul1x4() { FloatVector scale0 = FloatVector.broadcast(FloatVector.SPECIES_256, b.getFactorForIndex(j + 0, boffset)); FloatVector scale1 = FloatVector.broadcast(FloatVector.SPECIES_256, b.getFactorForIndex(j + 1, boffset)); - // BLOCK_SIZE Floats var af0 = a.getVector(FloatVector.SPECIES_256, i, aoffset); var af1 = a.getVector(FloatVector.SPECIES_256, i, 
aoffset + 8); @@ -273,7 +276,6 @@ protected BiIntConsumer initMatmul1x4() { acc1 = af0l.add(af1l).add(af2l).add(af3l).fma(scale1, acc1); } - } c.set(acc0.reduceLanes(VectorOperators.ADD), i, j + 0 + rOffset); @@ -729,8 +731,10 @@ protected BiIntConsumer initMatmul1x1() { // First take the scaling factors of both tensors and multiply them in SIMD for (int bi = 0; bi < blocksNeeded; bi += FloatVector.SPECIES_256.length()) { - final var ablock = a.getBlockF().getVector(FloatVector.SPECIES_256, i, (int) (Q8ByteBufferTensor.I_BLOCK_SIZE * aoffset)); - final var bblock = b.getBlockF().getVector(FloatVector.SPECIES_256, j, (int) (Q8ByteBufferTensor.I_BLOCK_SIZE * boffset)); + final var ablock = a.getBlockF() + .getVector(FloatVector.SPECIES_256, i, (int) (Q8ByteBufferTensor.I_BLOCK_SIZE * aoffset)); + final var bblock = b.getBlockF() + .getVector(FloatVector.SPECIES_256, j, (int) (Q8ByteBufferTensor.I_BLOCK_SIZE * boffset)); final var scales = ablock.mul(bblock); // Now for each scalar fetch the corresponding block of data and dot product them @@ -2175,12 +2179,10 @@ public void accumulate(AbstractTensor aBatch, AbstractTensor bBatch, int offset, case BF16: switch (vectorType) { case AVX_512: - accumulateBF16_512( - (BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit); + accumulateBF16_512((BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit); break; case AVX_256: - accumulateBF16_256( - (BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit); + accumulateBF16_256((BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit); break; default: throw new UnsupportedOperationException(); @@ -2212,7 +2214,6 @@ void accumulateF32(FloatBufferTensor a, FloatBufferTensor b, int offset, int lim } } - void accumulateF32Q4_256(FloatBufferTensor a, Q4ByteBufferTensor b, int offset, int limit) { int aoffset = offset; int boffset = offset; @@ -2221,7 +2222,7 @@ void accumulateF32Q4_256(FloatBufferTensor a, Q4ByteBufferTensor b, int offset, int slen = Q4ByteBufferTensor.BLOCK_SIZE; for (; aoffset < alim; aoffset += slen, boffset += slen) { - FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_256, b.getFactorForIndex(0, boffset)); + FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_256, b.getFactorForIndex(0, boffset)); // Make 8 bytes -> 16 4bit -> 16 bytes -> 16 32F var wBytes = b.getVector(ByteVector.SPECIES_128, 0, boffset); @@ -2231,8 +2232,10 @@ void accumulateF32Q4_256(FloatBufferTensor a, Q4ByteBufferTensor b, int offset, // BLOCK_SIZE Floats var af0 = a.getVector(FloatVector.SPECIES_256, 0, aoffset).add(loBytes.castShape(FloatVector.SPECIES_256, 0).mul(scale)); var af1 = a.getVector(FloatVector.SPECIES_256, 0, aoffset + 8).add(loBytes.castShape(FloatVector.SPECIES_256, 1).mul(scale)); - var af2 = a.getVector(FloatVector.SPECIES_256, 0, aoffset + Q4ByteBufferTensor.HALF_BLOCK).add(hiBytes.castShape(FloatVector.SPECIES_256, 0).mul(scale)); - var af3 = a.getVector(FloatVector.SPECIES_256, 0, aoffset + Q4ByteBufferTensor.HALF_BLOCK + 8).add(hiBytes.castShape(FloatVector.SPECIES_256, 1).mul(scale)); + var af2 = a.getVector(FloatVector.SPECIES_256, 0, aoffset + Q4ByteBufferTensor.HALF_BLOCK) + .add(hiBytes.castShape(FloatVector.SPECIES_256, 0).mul(scale)); + var af3 = a.getVector(FloatVector.SPECIES_256, 0, aoffset + Q4ByteBufferTensor.HALF_BLOCK + 8) + .add(hiBytes.castShape(FloatVector.SPECIES_256, 1).mul(scale)); a.intoTensor(af0, 0, aoffset); a.intoTensor(af1, 0, aoffset + 8); diff --git 
a/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java index c82d6bb..68eafe9 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java @@ -42,7 +42,6 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ThreadLocalRandom; import java.util.function.BiConsumer; import org.slf4j.Logger; @@ -74,7 +73,10 @@ public Coordinator( Optional authToken, Optional branch ) { - Preconditions.checkArgument(workerCount > 0 && (workerCount == 1 || workerCount % 2 == 0), "worker count must be a positive even number"); + Preconditions.checkArgument( + workerCount > 0 && (workerCount == 1 || workerCount % 2 == 0), + "worker count must be a positive even number" + ); Function weightLoaderFunction = SafeTensorSupport.isModelLocal(modelPath.toPath()) ? b -> SafeTensorSupport.loadWeights(modelPath) diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java index cffe366..2c4fa5b 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java @@ -20,7 +20,6 @@ import com.github.tjake.jlama.model.AbstractModel; import com.github.tjake.jlama.model.DistributedContext; import com.github.tjake.jlama.net.grpc.JlamaRingWorkerService; -import com.github.tjake.jlama.net.grpc.JlamaService; import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.safetensors.HTTPSafeTensorLoader; import com.github.tjake.jlama.safetensors.SafeTensorSupport; @@ -43,7 +42,6 @@ import java.util.Optional; import java.util.UUID; import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -62,7 +60,6 @@ public class Worker implements Closeable { } } - private static final Logger logger = LoggerFactory.getLogger(Worker.class); private final UUID workerId; private final KvBufferCache kvBufferCache; @@ -98,32 +95,31 @@ public Worker( Optional authToken, Optional branch ) { - Channel channel = ManagedChannelBuilder.forAddress(host, coordinatorPort).usePlaintext() - .maxInboundMessageSize(MESSAGE_SIZE) - .build(); + Channel channel = ManagedChannelBuilder.forAddress(host, coordinatorPort) + .usePlaintext() + .maxInboundMessageSize(MESSAGE_SIZE) + .build(); - //Start the ring service + // Start the ring service this.peerService = new JlamaRingWorkerService(this); this.peerServer = ServerBuilder.forPort(peerPort).addService(peerService).maxInboundMessageSize(MESSAGE_SIZE).build(); - try{ + try { this.peerServer.start(); } catch (IOException e) { throw new RuntimeException(e); } - //Setup via coordinator + // Setup via coordinator this.workerId = optionalWorkerId.map(s -> new UUID(s.hashCode(), s.hashCode())).orElse(UUID.randomUUID()); this.client = JlamaServiceGrpc.newStub(channel).withMaxInboundMessageSize(MESSAGE_SIZE).withMaxOutboundMessageSize(MESSAGE_SIZE); - this.blockingClient = JlamaServiceGrpc.newBlockingStub(channel).withMaxInboundMessageSize(MESSAGE_SIZE).withMaxOutboundMessageSize(MESSAGE_SIZE); + this.blockingClient = JlamaServiceGrpc.newBlockingStub(channel) + .withMaxInboundMessageSize(MESSAGE_SIZE) + 
.withMaxOutboundMessageSize(MESSAGE_SIZE); this.workerIdBytes = ByteString.copyFrom( ByteBuffer.allocate(128).putLong(workerId.getMostSignificantBits()).putLong(workerId.getLeastSignificantBits()).flip() ); - RegisterRequest rr = RegisterRequest.newBuilder() - .setWorkerid(workerIdBytes) - .setHostname(HOSTNAME) - .setPeerPort(peerPort) - .build(); + RegisterRequest rr = RegisterRequest.newBuilder().setWorkerid(workerIdBytes).setHostname(HOSTNAME).setPeerPort(peerPort).build(); this.registerResponse = blockingClient.register(rr); @@ -136,11 +132,13 @@ public Worker( registerResponse.getNumLayerShards() ); - //Setup peer + // Setup peer this.peerInfo = registerResponse.getNumLayerShards() == 1 ? null : blockingClient.discover(rr); - this.peerClient = peerInfo == null || peerInfo.getIsCoordinator() ? null : JlamaWorkerRingGrpc.newStub( - ManagedChannelBuilder.forAddress(peerInfo.getHostname(), peerInfo.getPeerPort()).usePlaintext().build() - ); + this.peerClient = peerInfo == null || peerInfo.getIsCoordinator() + ? null + : JlamaWorkerRingGrpc.newStub( + ManagedChannelBuilder.forAddress(peerInfo.getHostname(), peerInfo.getPeerPort()).usePlaintext().build() + ); this.peerStream = peerInfo == null || peerInfo.getIsCoordinator() ? null : peerClient.pass(new StreamObserver<>() { @Override public void onNext(Empty empty) {} @@ -156,9 +154,9 @@ public void onCompleted() { } }); - this.combineStreams = new ConcurrentHashMap<>(); + this.combineStreams = new ConcurrentHashMap<>(); - //Load the model + // Load the model Function weightLoaderFunction = SafeTensorSupport.isModelLocal(modelPath.toPath()) ? b -> SafeTensorSupport.loadWeights(modelPath) : b -> new HTTPSafeTensorLoader(modelPath.toPath(), modelOwner, modelName, modelDType, authToken, branch); @@ -201,32 +199,25 @@ public void pass(ByteString sessionBytes, int startPosition, AbstractTensor tens ByteBuffer bb = sessionBytes.asReadOnlyByteBuffer(); UUID session = new UUID(bb.getLong(), bb.getLong()); - //logger.info("From Peer: {} token(s) from position {} for session {}", tensor.shape().first(), startPosition, session); + // logger.info("From Peer: {} token(s) from position {} for session {}", tensor.shape().first(), startPosition, session); - Consumer> combineCallback = registerResponse.getNumModelShards() == 1 - ? t -> {} - : t -> { + Consumer> combineCallback = registerResponse.getNumModelShards() == 1 ? 
t -> {} : t -> { CombineRequest.Builder nrb = CombineRequest.newBuilder() - .setUuid(sessionBytes) - .setWorkerid(workerIdBytes) - .setLayerShard(registerResponse.getLayerShard()) - .setModelShard(registerResponse.getModelShard()); + .setUuid(sessionBytes) + .setWorkerid(workerIdBytes) + .setLayerShard(registerResponse.getLayerShard()) + .setModelShard(registerResponse.getModelShard()); for (int i = 0; i < t.size(); i++) nrb = nrb.addTensor(getTensorBytes(t.get(i))); - //logger.info("1)Sending combine request for session {}", session); - CombineResponse combineResponse = getCombineResponseStream(session) - .request(nrb.build()) - .join(); + // logger.info("1)Sending combine request for session {}", session); + CombineResponse combineResponse = getCombineResponseStream(session).request(nrb.build()).join(); for (int i = 0; i < t.size(); i++) t.get(i) - .getMemorySegment() - .copyFrom(MemorySegment.ofBuffer(combineResponse - .getTensor(i) - .asReadOnlyByteBuffer() - .order(ByteOrder.LITTLE_ENDIAN))); + .getMemorySegment() + .copyFrom(MemorySegment.ofBuffer(combineResponse.getTensor(i).asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN))); }; AbstractTensor output = model.forward(tensor, startPosition, kvBufferCache.getKvBuffer(session), Optional.of(combineCallback)); @@ -236,19 +227,21 @@ public void pass(ByteString sessionBytes, int startPosition, AbstractTensor tens public void processOutput(ByteString session, int startPosition, int batchSize, AbstractTensor output) { if (peerInfo == null || peerInfo.getIsCoordinator()) { - outputStream.onNext(GenerateRequest.newBuilder() + outputStream.onNext( + GenerateRequest.newBuilder() .setSession(session) .setWorkerid(workerIdBytes) .setTensor(getTensorBytes(output.slice(output.shape().first() - 1))) // keep only the last token - .build()); + .build() + ); } else { // Send the last token to the next worker PassRecord peerRequest = PassRecord.newBuilder() - .setSession(session) - .setStartPosition(startPosition) - .setBatchSize(batchSize) - .setTensor(getTensorBytes(output)) - .build(); + .setSession(session) + .setStartPosition(startPosition) + .setBatchSize(batchSize) + .setTensor(getTensorBytes(output)) + .build(); peerStream.onNext(peerRequest); } @@ -317,8 +310,6 @@ private GenerateObserver(CountDownLatch finishedLatch) { this.finishedLatch = finishedLatch; } - - @Override public void onNext(GenerateResponse generateResponse) { int[] tokens = generateResponse.getTokensList().stream().mapToInt(Integer::intValue).toArray(); @@ -329,37 +320,37 @@ public void onNext(GenerateResponse generateResponse) { // logger.info("From Coordinator: {} token(s) from position {} for session {}", tokens.length, // startPosition, session); - Consumer> combineCallback = registerResponse.getNumModelShards() == 1 - ? 
t -> {} - : t -> { - CombineRequest.Builder nrb = CombineRequest.newBuilder() - .setUuid(generateResponse.getSession()) - .setWorkerid(workerIdBytes) - .setLayerShard(registerResponse.getLayerShard()) - .setModelShard(registerResponse.getModelShard()); - for (int i = 0; i < t.size(); i++) nrb = nrb.addTensor(getTensorBytes(t.get(i))); - - //logger.info("2){} Sending combine request for session {}", registerResponse.getWorkerOrd(), session); - - CombineResponse combineResponse = getCombineResponseStream(session) - .request(nrb.build()) - .join(); - - for (int i = 0; i < t.size(); i++) - t.get(i) - .getMemorySegment() - .copyFrom(MemorySegment.ofBuffer(combineResponse - .getTensor(i) - .asReadOnlyByteBuffer() - .order(ByteOrder.LITTLE_ENDIAN))); - }; - - AbstractTensor output = model.batchForward(tokens, startPosition, kvBufferCache.getKvBuffer(session), Optional.of(combineCallback)); + Consumer> combineCallback = registerResponse.getNumModelShards() == 1 ? t -> {} : t -> { + CombineRequest.Builder nrb = CombineRequest.newBuilder() + .setUuid(generateResponse.getSession()) + .setWorkerid(workerIdBytes) + .setLayerShard(registerResponse.getLayerShard()) + .setModelShard(registerResponse.getModelShard()); + for (int i = 0; i < t.size(); i++) + nrb = nrb.addTensor(getTensorBytes(t.get(i))); + + // logger.info("2){} Sending combine request for session {}", registerResponse.getWorkerOrd(), session); + + CombineResponse combineResponse = getCombineResponseStream(session).request(nrb.build()).join(); + + for (int i = 0; i < t.size(); i++) + t.get(i) + .getMemorySegment() + .copyFrom( + MemorySegment.ofBuffer(combineResponse.getTensor(i).asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN)) + ); + }; + + AbstractTensor output = model.batchForward( + tokens, + startPosition, + kvBufferCache.getKvBuffer(session), + Optional.of(combineCallback) + ); processOutput(generateResponse.getSession(), startPosition, tokens.length, output); } - @Override public void onError(Throwable throwable) { logger.error("Error in generate", throwable); @@ -386,7 +377,7 @@ public void run() { Uninterruptibles.awaitUninterruptibly(finishedLatch); - //Cleanup + // Cleanup if (peerStream != null) peerStream.onCompleted(); if (peerClient != null) ((ManagedChannel) peerClient.getChannel()).shutdown(); peerServer.shutdown(); diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/JlamaRingWorkerService.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/JlamaRingWorkerService.java index f105b73..a2f0d1e 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/JlamaRingWorkerService.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/JlamaRingWorkerService.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ package com.github.tjake.jlama.net.grpc; import com.github.tjake.jlama.net.*; @@ -9,16 +24,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; -import java.util.UUID; public class JlamaRingWorkerService extends JlamaWorkerRingGrpc.JlamaWorkerRingImplBase { private static final Logger logger = LoggerFactory.getLogger(JlamaRingWorkerService.class); private final Worker worker; + public JlamaRingWorkerService(Worker worker) { this.worker = worker; } @@ -29,10 +43,14 @@ public StreamObserver pass(StreamObserver responseObserver) { return new StreamObserver<>() { @Override public void onNext(PassRecord value) { - //logger.info("Recieved pass record from peer"); + // logger.info("Recieved pass record from peer"); int startPosition = value.getStartPosition(); FloatBuffer buffer = value.getTensor().asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); - AbstractTensor tensor = new FloatBufferTensor(buffer, TensorShape.of(value.getBatchSize(), worker.model.getConfig().embeddingLength), true); + AbstractTensor tensor = new FloatBufferTensor( + buffer, + TensorShape.of(value.getBatchSize(), worker.model.getConfig().embeddingLength), + true + ); ByteString sessionBytes = value.getSession(); worker.pass(sessionBytes, startPosition, tensor); diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/JlamaService.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/JlamaService.java index 20d0326..032d0a9 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/JlamaService.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/JlamaService.java @@ -15,12 +15,10 @@ */ package com.github.tjake.jlama.net.grpc; -import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.model.AbstractModel; import com.github.tjake.jlama.net.*; import com.github.tjake.jlama.safetensors.Config; import com.github.tjake.jlama.tensor.AbstractTensor; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.Pair; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; @@ -83,7 +81,7 @@ public JlamaService(AbstractModel model, int workerCount, boolean splitHeads, bo // Calculate the number of parameters per layer and use it to determine the number of heads to split per worker if (splitLayers && splitHeads) { - //throw new RuntimeException("Not yet supporting splitting layers and heads together"); + // throw new RuntimeException("Not yet supporting splitting layers and heads together"); long queryParams = (long) c.embeddingLength * c.embeddingLength; long keyValueParams = 2L * c.numberOfKeyValueHeads * c.embeddingLength * c.embeddingLength; @@ -92,8 +90,7 @@ public JlamaService(AbstractModel model, int workerCount, boolean splitHeads, bo long attentionParams = queryParams + keyValueParams; // Calculate the parameters for the feedforward network - long feedforwardParams = - 2L * ((long) c.embeddingLength * c.hiddenLength + (long) c.hiddenLength * c.embeddingLength); + long feedforwardParams = 2L * ((long) c.embeddingLength * c.hiddenLength + (long) c.hiddenLength * c.embeddingLength); // Calculate the parameters for layer normalization (2 * hiddenSize for scaling and shifting) long layerNormParams = 2L * c.embeddingLength; @@ -107,16 +104,16 @@ public JlamaService(AbstractModel model, int workerCount, boolean splitHeads, bo long idealParamsPerWorker = 
idealBillionParamsPerWorker * 1_000_000_000L; long paramsPerWorker = tmpLayersPerShard * paramsPerLayer; - if (paramsPerWorker > idealParamsPerWorker) - { - tmpHeadsPerLayerShard = Math.min(Math.min(workerCount, c.numberOfKeyValueHeads), (int) Math.ceilDivExact(paramsPerLayer, idealParamsPerWorker)); - //Round up to the nearest power of 2 + if (paramsPerWorker > idealParamsPerWorker) { + tmpHeadsPerLayerShard = Math.min( + Math.min(workerCount, c.numberOfKeyValueHeads), + (int) Math.ceilDivExact(paramsPerLayer, idealParamsPerWorker) + ); + // Round up to the nearest power of 2 tmpHeadsPerLayerShard = nextPowerOfTwo(tmpHeadsPerLayerShard); tmpHeadsPerLayerShard = c.numberOfKeyValueHeads / tmpHeadsPerLayerShard; tmpLayersPerShard = tmpLayersPerShard * (c.numberOfKeyValueHeads / tmpHeadsPerLayerShard); - } - else - { + } else { tmpHeadsPerLayerShard = c.numberOfKeyValueHeads; } } @@ -131,7 +128,7 @@ public JlamaService(AbstractModel model, int workerCount, boolean splitHeads, bo this.ordinalCombinations = new ArrayList<>(workerCount); for (int i = 0; i < numLayerShards; i++) { for (int j = 0; j < numHeadShards; j++) { - ordinalCombinations.add(new int[]{i, j}); + ordinalCombinations.add(new int[] { i, j }); } } } @@ -199,14 +196,14 @@ public void register(RegisterRequest request, StreamObserver r int workerNum = workers.size(); RegisterResponse r = RegisterResponse.newBuilder() - .setHostname(request.getHostname()) - .setPeerPort(request.getPeerPort()) - .setModelShard(ordinalCombinations.get(workerNum)[HEAD_IDX]) - .setNumModelShards(numHeadShards) - .setLayerShard(ordinalCombinations.get(workerNum)[LAYER_IDX]) - .setNumLayerShards(numLayerShards) - .setWorkerOrd(workerNum) - .build(); + .setHostname(request.getHostname()) + .setPeerPort(request.getPeerPort()) + .setModelShard(ordinalCombinations.get(workerNum)[HEAD_IDX]) + .setNumModelShards(numHeadShards) + .setLayerShard(ordinalCombinations.get(workerNum)[LAYER_IDX]) + .setNumLayerShards(numLayerShards) + .setWorkerOrd(workerNum) + .build(); workers.put(wid, r); logger.info("Registered worker {} with workerNum {} of {} with {}", wid, workerNum, workerCount, r); @@ -226,7 +223,7 @@ public void register(RegisterRequest request, StreamObserver r public void discover(RegisterRequest request, StreamObserver responseObserver) { ByteBuffer bb = request.getWorkerid().asReadOnlyByteBuffer(); UUID wid = new UUID(bb.getLong(), bb.getLong()); - //Register should have been called before this + // Register should have been called before this if (!workers.containsKey(wid)) { responseObserver.onError(new RuntimeException("Worker not registered")); } else { @@ -243,24 +240,28 @@ public void discover(RegisterRequest request, StreamObserver responseO // If this is the last worker in the layer, then it should connect to the coordinator if (thisWorkersLayerShard == numLayerShards - 1) { - responseObserver.onNext(PeerInfo.newBuilder() + responseObserver.onNext( + PeerInfo.newBuilder() .setWorkerid(request.getWorkerid()) .setHostname(request.getHostname()) .setPeerPort(request.getPeerPort()) .setIsCoordinator(true) - .build()); + .build() + ); responseObserver.onCompleted(); } else { for (RegisterResponse r : workers.values()) { // If this worker is the next layer shard and the same head shard, then connect to it if (r.getLayerShard() == thisWorkersLayerShard + 1 && r.getModelShard() == thisWorkersHeadShard) { - responseObserver.onNext(PeerInfo.newBuilder() + responseObserver.onNext( + PeerInfo.newBuilder() .setWorkerid(r.getHostnameBytes()) 
.setIsCoordinator(false) .setHostname(r.getHostname()) .setPeerPort(r.getPeerPort()) - .build()); + .build() + ); responseObserver.onCompleted(); return; @@ -313,7 +314,7 @@ public void onNext(CombineRequest request) { k -> new MpmcArrayQueue<>(workerCount + 1) ); members.add(Pair.of(request, responseObserver)); - //logger.info("GOT COMBINE REQUEST {} {}", key, members.size()); + // logger.info("GOT COMBINE REQUEST {} {}", key, members.size()); // If we have all the workers, then we can calculate the result and send it back if (members.size() == numHeadShards && combinations.remove(key, members)) { MemorySegment[] tensors = null; @@ -350,7 +351,7 @@ public void onNext(CombineRequest request) { for (Pair> f : members) { f.right.onNext(response); } - //logger.info("Sent response to {} members", members.size()); + // logger.info("Sent response to {} members", members.size()); members.clear(); } } @@ -413,13 +414,11 @@ public AbstractTensor generateNextOutput(UUID session, List tokenIds, i .build(); for (Generator g : generators) { if (splitLayers) { - //The last layer shard sends back to coordinator from ring - if (g.workerAssignment.getLayerShard() == numLayerShards - 1) - g.registerLatch(session); + // The last layer shard sends back to coordinator from ring + if (g.workerAssignment.getLayerShard() == numLayerShards - 1) g.registerLatch(session); // The first layer shard gets the request from the coordinator - if (g.workerAssignment.getLayerShard() == 0) - g.responseObserver.onNext(gr); + if (g.workerAssignment.getLayerShard() == 0) g.responseObserver.onNext(gr); } else { g.registerLatch(session); g.responseObserver.onNext(gr); @@ -442,7 +441,7 @@ public AbstractTensor generateNextOutput(UUID session, List tokenIds, i throw new RuntimeException("No output received from workers"); } - //logger.info("Received output from worker {}", TensorOperationsProvider.get().sum(output)); + // logger.info("Received output from worker {}", TensorOperationsProvider.get().sum(output)); return output; } diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/TopologyService.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/TopologyService.java index d9d2d59..0d2b9df 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/TopologyService.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/grpc/TopologyService.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ package com.github.tjake.jlama.net.grpc; import java.util.ArrayList; @@ -8,23 +23,18 @@ import com.github.tjake.jlama.model.functions.Generator; import com.github.tjake.jlama.net.Coordinator; import com.github.tjake.jlama.net.RegisterResponse; -import com.github.tjake.jlama.net.openai.model.CreateChatCompletionRequest; import com.github.tjake.jlama.safetensors.Config; -import jakarta.validation.Valid; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.validation.annotation.Validated; -import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestHeader; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.RestController; @RestController @Validated -public class TopologyService -{ +public class TopologyService { @Autowired private Generator model; @@ -35,8 +45,7 @@ public class TopologyService * @return OK (status code 200) */ @RequestMapping(method = RequestMethod.GET, value = "/cluster/topology", produces = { "application/json" }) - public Object getTopology() - { + public Object getTopology() { if (!(model instanceof Coordinator)) { return new ResponseEntity<>(HttpStatus.BAD_GATEWAY); } @@ -49,22 +58,37 @@ public Object getTopology() String id = entry.getKey().toString(); RegisterResponse w = entry.getValue(); - workerList.add(Map.of( - "id", id, - "address", w.getHostname(), - "layer_shard", Integer.toString(w.getLayerShard()), - "head_shard", Integer.toString(w.getModelShard()), - "layer_shard_total", Integer.toString(w.getNumLayerShards()), - "head_shard_total", Integer.toString(w.getNumModelShards()), - "ordinal", Integer.toString(w.getWorkerOrd()))); + workerList.add( + Map.of( + "id", + id, + "address", + w.getHostname(), + "layer_shard", + Integer.toString(w.getLayerShard()), + "head_shard", + Integer.toString(w.getModelShard()), + "layer_shard_total", + Integer.toString(w.getNumLayerShards()), + "head_shard_total", + Integer.toString(w.getNumModelShards()), + "ordinal", + Integer.toString(w.getWorkerOrd()) + ) + ); } Map topology = Map.of( - "num_layers", config.numberOfLayers, - "num_heads", config.numberOfKeyValueHeads, - "num_workers", workerList.size(), - "workers", workerList); + "num_layers", + config.numberOfLayers, + "num_heads", + config.numberOfKeyValueHeads, + "num_workers", + workerList.size(), + "workers", + workerList + ); return new ResponseEntity<>(topology, HttpStatus.OK); } -} \ No newline at end of file +} diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java index 917cf0d..9343d58 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java @@ -107,27 +107,24 @@ Object createChatCompletion(@RequestHeader Map headers, @Valid @ AtomicInteger index = new AtomicInteger(0); if (request.getStream() != null && request.getStream()) { SseEmitter emitter = new SseEmitter(-1L); - CompletableFuture.supplyAsync(() -> model.generate(sessionId, builder.build(), temperature, maxTokens, (t, f) -> - CompletableFuture.supplyAsync(() -> { - try - { + CompletableFuture.supplyAsync( + () -> model.generate(sessionId, builder.build(), 
temperature, maxTokens, (t, f) -> CompletableFuture.supplyAsync(() -> { + try { emitter.send( - new CreateChatCompletionStreamResponse().id(sessionId.toString()) - .choices( - List.of( - new CreateChatCompletionStreamResponseChoicesInner().index(index.getAndIncrement()) - .delta(new ChatCompletionStreamResponseDelta().content(t)) - ) - ) + new CreateChatCompletionStreamResponse().id(sessionId.toString()) + .choices( + List.of( + new CreateChatCompletionStreamResponseChoicesInner().index(index.getAndIncrement()) + .delta(new ChatCompletionStreamResponseDelta().content(t)) + ) + ) ); - } - catch (IOException e) - { + } catch (IOException e) { emitter.completeWithError(e); } return null; - }) - )).handle((r, ex) -> { + })) + ).handle((r, ex) -> { try { emitter.send( new CreateChatCompletionStreamResponse().id(sessionId.toString()) @@ -142,9 +139,11 @@ Object createChatCompletion(@RequestHeader Map headers, @Valid @ emitter.complete(); - logger.info("Stats: {} ms/tok (prompt), {} ms/tok (gen)", - Math.round(r.promptTimeMs / (double) r.promptTokens), - Math.round(r.generateTimeMs / (double) r.generatedTokens)); + logger.info( + "Stats: {} ms/tok (prompt), {} ms/tok (gen)", + Math.round(r.promptTimeMs / (double) r.promptTokens), + Math.round(r.generateTimeMs / (double) r.generatedTokens) + ); } catch (IOException e) { emitter.completeWithError(e); diff --git a/jlama-net/src/test/java/com/github/tjake/jlama/net/DistributedServiceTest.java b/jlama-net/src/test/java/com/github/tjake/jlama/net/DistributedServiceTest.java index 32f8011..9db5e17 100644 --- a/jlama-net/src/test/java/com/github/tjake/jlama/net/DistributedServiceTest.java +++ b/jlama-net/src/test/java/com/github/tjake/jlama/net/DistributedServiceTest.java @@ -92,27 +92,27 @@ private void startWorker(Path modelRoot, String modelOwner, String modelName, in new Thread(() -> { try { Worker worker = new Worker( - modelRoot.toFile(), - modelOwner, - modelName, - DType.Q4, - "localhost", - 8888, - 8888 + workerNumber, - null, - DType.F32, - DType.I8, - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty() + modelRoot.toFile(), + modelOwner, + modelName, + DType.Q4, + "localhost", + 8888, + 8888 + workerNumber, + null, + DType.F32, + DType.I8, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty() ); worker.run(); } catch (Exception e) { e.printStackTrace(); } finally { - //worker.close(); + // worker.close(); } }).start(); } diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/BatchBench.java b/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/BatchBench.java index f82f46d..dfb115f 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/BatchBench.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/BatchBench.java @@ -19,7 +19,6 @@ import com.github.tjake.jlama.tensor.*; import com.github.tjake.jlama.tensor.operations.NaiveTensorOperations; import com.github.tjake.jlama.tensor.operations.TensorOperations; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import java.util.Collection; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; @@ -36,7 +35,7 @@ @Fork(warmups = 1, value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector", "--enable-preview", "-Djlama.force_panama_tensor_operations=true" }) public class BatchBench { - private static final TensorOperations ops = new NaiveTensorOperations(); //TensorOperationsProvider.get(); + private static final TensorOperations 
ops = new NaiveTensorOperations(); // TensorOperationsProvider.get(); private static final int BATCH_SIZE = 1024; private static final int SIZE = 1024; diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/TensorBench.java b/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/TensorBench.java index 06a8106..d980585 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/TensorBench.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/microbench/TensorBench.java @@ -25,21 +25,11 @@ import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import com.github.tjake.jlama.util.MachineSpec; -import java.lang.foreign.MemorySegment; -import java.lang.reflect.Field; -import java.nio.ByteOrder; -import java.nio.file.Path; -import java.util.Arrays; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; -import jdk.incubator.vector.ByteVector; -import jdk.incubator.vector.FloatVector; -import jdk.incubator.vector.VectorOperators; -import jdk.incubator.vector.VectorSpecies; import org.openjdk.jmh.annotations.*; import org.openjdk.jmh.infra.Blackhole; -import sun.misc.Unsafe; @Warmup(iterations = 1, time = 5) @Measurement(iterations = 3, time = 5) @@ -58,7 +48,7 @@ public static class Parameters { final FloatBufferTensor f = new FloatBufferTensor(SIZE); final FloatBufferTensor f2 = new FloatBufferTensor(SIZE); - final FloatBufferTensor r = new FloatBufferTensor(1,1); + final FloatBufferTensor r = new FloatBufferTensor(1, 1); final BFloat16BufferTensor bf; final Q8ByteBufferTensor q81; final Q8ByteBufferTensor q82; @@ -88,14 +78,14 @@ public Parameters() { } } - /* @Benchmark + /* @Benchmark @OutputTimeUnit(TimeUnit.MILLISECONDS) @BenchmarkMode(Mode.Throughput) @Threads(8) public void a_aq8dotq4(Parameters p, Blackhole bh) { bh.consume(nops.dotProduct(p.q81, p.q4, 0, 0, SIZE)); } - + @Benchmark @OutputTimeUnit(TimeUnit.MILLISECONDS) @BenchmarkMode(Mode.Throughput) @@ -103,7 +93,7 @@ public void a_aq8dotq4(Parameters p, Blackhole bh) { public void a_pq8dotq4(Parameters p, Blackhole bh) { bh.consume(ops.dotProduct(p.q81, p.q4, 0, 0, SIZE)); } - + @Benchmark @OutputTimeUnit(TimeUnit.MILLISECONDS) @BenchmarkMode(Mode.Throughput) @@ -128,14 +118,14 @@ public void panama_f32dotq4(Parameters p, Blackhole bh) { bh.consume(ops.dotProduct(p.f, p.q4, 0, 0, SIZE)); } - /* @Benchmark + /* @Benchmark @OutputTimeUnit(TimeUnit.MILLISECONDS) @BenchmarkMode(Mode.Throughput) @Threads(8) public void a_f32dotq8(Parameters p, Blackhole bh) { bh.consume(ops.dotProduct(p.f, p.q82, 0, 0, SIZE)); } - + @Benchmark @OutputTimeUnit(TimeUnit.MILLISECONDS) @BenchmarkMode(Mode.Throughput) @@ -143,7 +133,7 @@ public void a_f32dotq8(Parameters p, Blackhole bh) { public void f32dotf32(Parameters p, Blackhole bh) { bh.consume(ops.dotProduct(p.f, p.f2, 0, 0, SIZE)); } - + @Benchmark @OutputTimeUnit(TimeUnit.MILLISECONDS) @BenchmarkMode(Mode.Throughput) @@ -151,8 +141,8 @@ public void f32dotf32(Parameters p, Blackhole bh) { public void f32dotf32nops(Parameters p, Blackhole bh) { bh.consume(nops.dotProduct(p.f, p.f2, 0, 0, SIZE)); } - - + + @Benchmark @OutputTimeUnit(TimeUnit.MILLISECONDS) @BenchmarkMode(Mode.Throughput) diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/safetensors/TestParser.java b/jlama-tests/src/test/java/com/github/tjake/jlama/safetensors/TestParser.java index 85ef574..4fd691a 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/safetensors/TestParser.java +++ 
b/jlama-tests/src/test/java/com/github/tjake/jlama/safetensors/TestParser.java @@ -202,7 +202,7 @@ public void testSegmentedTensor() throws IOException { Assert.assertEquals(orig.shape[0], t.shape().dim(0)); Assert.assertEquals(orig.shape[1], t.shape().dim(1)); - //Make sure we can slice the last row + // Make sure we can slice the last row AbstractTensor s = t.slice(orig.shape[0] - 1); } diff --git a/pom.xml b/pom.xml index 3b67d76..43ae3e1 100644 --- a/pom.xml +++ b/pom.xml @@ -42,7 +42,7 @@ UTF-8 - 0.5.1 + 0.6.0 2.0.7 1.5.6
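
The reformatted JlamaRingWorkerService#pass above wraps the incoming activation bytes in a FloatBufferTensor without copying, relying on the asReadOnlyByteBuffer().order(LITTLE_ENDIAN).asFloatBuffer() chain. A minimal, self-contained sketch of that byte-order handling, using made-up values and a plain float array in place of FloatBufferTensor and the generated PassRecord message:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;

public class PassRecordDecodeSketch {
    public static void main(String[] args) {
        // Stand-in for value.getTensor().asReadOnlyByteBuffer(): three made-up
        // float32 activations serialized little-endian, as the wire format expects.
        ByteBuffer wire = ByteBuffer.allocate(3 * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
        wire.putFloat(1.0f).putFloat(-2.5f).putFloat(0.125f);
        wire.flip();

        // Same idea as the service: fix the byte order, then view the bytes as floats.
        FloatBuffer floats = wire.order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
        float[] out = new float[floats.remaining()];
        floats.get(out);

        // The real service wraps the FloatBuffer in a FloatBufferTensor shaped
        // (batchSize, embeddingLength) instead of copying into an array.
        System.out.println(Arrays.toString(out)); // [1.0, -2.5, 0.125]
    }
}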
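
JlamaService sizes shards by estimating the parameters in one transformer layer (attention, feed-forward, layer norm) and only splits attention heads when a layer-sharded worker would exceed the per-worker parameter budget, rounding the head-shard count up to a power of two. The sketch below restates that heuristic in isolation; the model dimensions, worker count, and one-billion-parameter budget are assumed values for illustration, and Math.ceil stands in for the exact integer division used in the service:

public class ShardSizingSketch {
    // Assumed model dimensions and budget, chosen only so the example runs;
    // the real values come from the model Config and CLI options.
    static final int EMBEDDING_LENGTH = 4096;
    static final int HIDDEN_LENGTH = 11008;
    static final int KV_HEADS = 32;
    static final int LAYERS = 32;
    static final long IDEAL_PARAMS_PER_WORKER = 1_000_000_000L;

    static int nextPowerOfTwo(int n) {
        int p = 1;
        while (p < n) p <<= 1;
        return p;
    }

    public static void main(String[] args) {
        int workerCount = 4;

        // Parameter estimate per layer, mirroring the terms in the diff:
        // attention (query + key/value), feed-forward, and layer-norm weights.
        long queryParams = (long) EMBEDDING_LENGTH * EMBEDDING_LENGTH;
        long keyValueParams = 2L * KV_HEADS * EMBEDDING_LENGTH * EMBEDDING_LENGTH;
        long feedforwardParams = 2L * ((long) EMBEDDING_LENGTH * HIDDEN_LENGTH + (long) HIDDEN_LENGTH * EMBEDDING_LENGTH);
        long layerNormParams = 2L * EMBEDDING_LENGTH;
        long paramsPerLayer = queryParams + keyValueParams + feedforwardParams + layerNormParams;

        int layersPerShard = (int) Math.ceil(LAYERS / (double) workerCount);
        long paramsPerWorker = (long) layersPerShard * paramsPerLayer;

        int headsPerLayerShard;
        if (paramsPerWorker > IDEAL_PARAMS_PER_WORKER) {
            // Split heads: pick a head-shard count, round it up to a power of two,
            // then convert it into the number of heads each shard owns.
            int headShards = Math.min(
                Math.min(workerCount, KV_HEADS),
                (int) Math.ceil(paramsPerLayer / (double) IDEAL_PARAMS_PER_WORKER)
            );
            headShards = nextPowerOfTwo(headShards);
            headsPerLayerShard = KV_HEADS / headShards;
            layersPerShard = layersPerShard * headShards;
        } else {
            headsPerLayerShard = KV_HEADS;
        }

        System.out.println("params/layer=" + paramsPerLayer
            + " layers/shard=" + layersPerShard
            + " heads/shard=" + headsPerLayerShard);
    }
}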
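
The discover RPC builds the ring: a worker on the last layer shard is told to connect back to the coordinator, while every other worker is pointed at the registered worker holding the next layer shard for the same head shard. A simplified sketch of that selection rule, using a plain record instead of the generated RegisterRequest/RegisterResponse/PeerInfo messages (names and values here are illustrative only):

import java.util.List;
import java.util.Optional;

public class RingDiscoverySketch {
    record Peer(String hostname, int layerShard, int headShard) {}

    // Empty result means "connect to the coordinator", as the last layer shard does.
    static Optional<Peer> nextHop(Peer self, List<Peer> registered, int numLayerShards) {
        if (self.layerShard() == numLayerShards - 1) {
            return Optional.empty();
        }
        return registered.stream()
            .filter(p -> p.layerShard() == self.layerShard() + 1
                      && p.headShard() == self.headShard())
            .findFirst();
    }

    public static void main(String[] args) {
        List<Peer> peers = List.of(
            new Peer("w0", 0, 0), new Peer("w1", 1, 0), new Peer("w2", 2, 0));
        System.out.println(nextHop(peers.get(0), peers, 3)); // next hop is w1
        System.out.println(nextHop(peers.get(2), peers, 3)); // empty -> coordinator
    }
}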