
Commit

tjake committed Oct 16, 2024
1 parent 865c9c7 commit 6d3ab99
Showing 26 changed files with 445 additions and 385 deletions.
@@ -89,7 +89,13 @@ public void run() {
builder.addUserMessage(prompt);
PromptContext builtPrompt = builder.build();

Generator.Response r = m.generate(session, builtPrompt, temperature, tokens == null ? m.getConfig().contextLength : tokens, makeOutHandler());
Generator.Response r = m.generate(
session,
builtPrompt,
temperature,
tokens == null ? m.getConfig().contextLength : tokens,
makeOutHandler()
);

out.println(
"\n\n"
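Aside (illustrative, not part of this commit): the hunk above only reformats the generate(...) call, but it also shows the call's full shape. A minimal sketch of how such a call is typically driven is below, assuming the model object exposes the same generate(...) and getConfig() members seen in the hunk; the helper name, the UUID session id, the streaming handler, and the responseText field are assumptions, not code from the repository.

import java.util.UUID;

// Hypothetical helper, not from the diff: runs one prompt through a jlama Generator.
static String askOnce(Generator m, PromptContext builtPrompt, float temperature, Integer tokens) {
    UUID session = UUID.randomUUID(); // assumption: each conversation is keyed by a UUID
    Generator.Response r = m.generate(
        session,
        builtPrompt,
        temperature,
        tokens == null ? m.getConfig().contextLength : tokens, // fall back to the model's context length when no limit is given
        (token, time) -> System.out.print(token) // assumption: the handler streams tokens as they are produced
    );
    return r.responseText; // assumption: the response exposes the generated text
}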
[next changed file]
@@ -19,7 +19,6 @@
import com.github.tjake.jlama.net.Worker;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.util.PhysicalCoreExecutor;
import com.google.common.util.concurrent.Uninterruptibles;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.builder.SpringApplicationBuilder;
@@ -30,24 +29,23 @@

import java.nio.file.Path;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

@CommandLine.Command(name = "cluster-coordinator", description = "Starts a distributed rest api for a model using cluster workers", abbreviateSynopsis = true)
@SpringBootApplication(scanBasePackages = { "com.github.tjake.jlama.net.openai", "com.github.tjake.jlama.cli.commands", "com.github.tjake.jlama.net.grpc" })
@SpringBootApplication(scanBasePackages = { "com.github.tjake.jlama.net.openai", "com.github.tjake.jlama.cli.commands",
"com.github.tjake.jlama.net.grpc" })
@SpringBootConfiguration
@Configuration
public class ClusterCoordinatorCommand extends ModelBaseCommand implements WebMvcConfigurer {

@CommandLine.Option(names = {
"--worker-count" }, paramLabel = "ARG", description = "signifies this instance is a coordinator")
@CommandLine.Option(names = { "--worker-count" }, paramLabel = "ARG", description = "signifies this instance is a coordinator")
int workerCount = 1;

@CommandLine.Option(names = {
"--split-heads" }, paramLabel = "ARG", description = "Should coordinator split work across attention heads (default: ${DEFAULT-VALUE})")
"--split-heads" }, paramLabel = "ARG", description = "Should coordinator split work across attention heads (default: ${DEFAULT-VALUE})")
boolean splitHeads = true;

@CommandLine.Option(names = {
"--split-layers" }, paramLabel = "ARG", description = "Should coordinator split work across layers (default: ${DEFAULT-VALUE})")
"--split-layers" }, paramLabel = "ARG", description = "Should coordinator split work across layers (default: ${DEFAULT-VALUE})")
boolean splitLayers = false;

@CommandLine.Option(names = {
@@ -117,28 +115,29 @@ public void run() {

if (includeWorker) {
Worker w = new Worker(
model.toFile(),
SimpleBaseCommand.getOwner(modelName),
SimpleBaseCommand.getName(modelName),
modelType,
"localhost",
grpcPort,
grpcPort + 1,
workingDirectory,
advancedSection.workingMemoryType,
advancedSection.workingQuantizationType,
Optional.ofNullable(advancedSection.modelQuantization),
Optional.ofNullable("in-jvm-worker"),
Optional.ofNullable(downloadSection.authToken),
Optional.ofNullable(downloadSection.branch));
model.toFile(),
SimpleBaseCommand.getOwner(modelName),
SimpleBaseCommand.getName(modelName),
modelType,
"localhost",
grpcPort,
grpcPort + 1,
workingDirectory,
advancedSection.workingMemoryType,
advancedSection.workingQuantizationType,
Optional.ofNullable(advancedSection.modelQuantization),
Optional.ofNullable("in-jvm-worker"),
Optional.ofNullable(downloadSection.authToken),
Optional.ofNullable(downloadSection.branch)
);

new Thread(() -> {
try {
w.run();
} catch (Exception e) {
e.printStackTrace();
}
}).start();
try {
w.run();
} catch (Exception e) {
e.printStackTrace();
}
}).start();
}

System.out.println("Chat UI: http://localhost:" + port);
[next changed file]
@@ -263,10 +263,11 @@ public AbstractTensor batchForward(
}

public AbstractTensor forward(
AbstractTensor embedding,
int startPos,
KvBufferCache.KvBuffer kvbuf,
Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
AbstractTensor embedding,
int startPos,
KvBufferCache.KvBuffer kvbuf,
Optional<Consumer<List<AbstractTensor>>> tensorReducer
) {

for (int i = c.dctx().layerStart; i < c.dctx().layerEnd; i++) {
int relativeLayer = i - c.dctx().layerStart;
[next changed file]
@@ -115,14 +115,20 @@ public int getShardLength(int length) {
}

public String toString() {
return "DistributedContext{" +
", embeddingSegmentStart=" + embeddingSegmentStart +
", embeddingSegmentEnd=" + embeddingSegmentEnd +
", headStart=" + headStart +
", headEnd=" + headEnd +
", layerStart=" + layerStart +
", layerEnd=" + layerEnd +
'}';
return "DistributedContext{"
+ ", embeddingSegmentStart="
+ embeddingSegmentStart
+ ", embeddingSegmentEnd="
+ embeddingSegmentEnd
+ ", headStart="
+ headStart
+ ", headEnd="
+ headEnd
+ ", layerStart="
+ layerStart
+ ", layerEnd="
+ layerEnd
+ '}';
}

public static Builder builder(Config c) {
[next changed file]
@@ -122,8 +122,8 @@ public AbstractTensor forward(AbstractTensor lnemb, Optional<Consumer<List<Abstr
bias -> TensorOperationsProvider.get().accumulate(buf, bias, dctx.hiddenSegmentStart, dctx.hiddenSegmentLength)
);

//Not using pfor because we can use all cores
IntStream.range(dctx.hiddenSegmentStart, dctx.hiddenSegmentEnd).parallel().forEach( i -> {
// Not using pfor because we can use all cores
IntStream.range(dctx.hiddenSegmentStart, dctx.hiddenSegmentEnd).parallel().forEach(i -> {
for (int j = 0; j < batchSize; j++) {
float w1 = buf.get(j, i);
float w1a = ActivationFunction.eval(activationFunction, w1);
[next changed file]
@@ -23,7 +23,6 @@
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.TensorCache;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
@@ -72,8 +71,7 @@ protected EmbedInput loadInputWeights() {

return (inputToken, position) -> {


AbstractTensor at = wte.slice(true, inputToken);
AbstractTensor at = wte.slice(true, inputToken);
AbstractTensor embedding = at.copyShape();

// Always copy the entire embedding
@@ -94,7 +92,7 @@ protected TransformerBlock[] loadTransformerBlockWeights() {

IntStream.range(c.dctx().layerStart, c.dctx().layerEnd).parallel().forEach(i -> {

int relativeLayer = i - c.dctx().layerStart; //FIXME: add a helper to the context
int relativeLayer = i - c.dctx().layerStart; // FIXME: add a helper to the context

String base = "model.layers." + i + ".";
String prefix = base + "self_attn.";
@@ -135,10 +133,11 @@ protected SampleOutput loadOutputWeights() {
DType qType = modelQType.orElse(this.modelDType);
final LayerNorm outputLayerNorm = new RMSNorm(this, weights.load("model.norm.weight").quantize(qType));

//Some llama models don't have a classification head
// Some llama models don't have a classification head
AbstractTensor classificationWeights = weights.isWeightPresent("lm_head.weight")
? weights.load("lm_head.weight").quantize(workingDType)
: wte == null ? wte = weights.load("model.embed_tokens.weight") : wte;
? weights.load("lm_head.weight").quantize(workingDType)
: wte == null ? wte = weights.load("model.embed_tokens.weight")
: wte;

return new SampleOutput() {
@Override
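Aside (illustrative, not part of this commit): the ternary above encodes a weight-tying fallback, i.e. checkpoints without an explicit lm_head reuse the token-embedding matrix as the classification head. Read as a standalone helper it would look roughly like the sketch below; extracting a method and skipping the wte caching are simplifications, while the weight names and the isWeightPresent/load/quantize calls are the ones used in the hunk.

// Hypothetical helper mirroring the fallback logic above (WeightLoader, AbstractTensor
// and DType are the jlama types imported earlier in this file).
static AbstractTensor classificationHead(WeightLoader weights, DType workingDType) {
    if (weights.isWeightPresent("lm_head.weight")) {
        return weights.load("lm_head.weight").quantize(workingDType); // explicit classification head
    }
    return weights.load("model.embed_tokens.weight"); // weight tying: reuse the embedding matrix
}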
[next changed file]
@@ -136,20 +136,32 @@ public AbstractTensor load(String name, DistributedContext dctx, boolean sparseR
long length = positionLimit - positionOffset;

if (length > Integer.MAX_VALUE) {
//Make a segmented tensor
// Make a segmented tensor
assert info.shape.length == 2 : "Only 2D tensors supported";

List<AbstractTensor> tensors = new ArrayList<>();
int bytesPerColumn = info.dType.size() * info.shape[1];
long offset = positionOffset;
//Chunk size needs to be a multiple of the column size
// Chunk size needs to be a multiple of the column size
long chunkSize = Integer.MAX_VALUE - (Integer.MAX_VALUE % bytesPerColumn);
int chunkNum = 0;
while (offset < positionLimit) {
long chunkEnd = Math.min(offset + chunkSize, positionLimit);
int numRowsInChunk = Ints.checkedCast((chunkEnd - offset) / bytesPerColumn);
TensorShape chunkShape = TensorShape.of(numRowsInChunk, info.shape[1]);
tensors.add(downloadAndLoadTensor(name + ".part." + chunkNum++, weightFile, info, chunkShape, offset, chunkEnd, dctx, sparseRows, sparseColumns));
tensors.add(
downloadAndLoadTensor(
name + ".part." + chunkNum++,
weightFile,
info,
chunkShape,
offset,
chunkEnd,
dctx,
sparseRows,
sparseColumns
)
);
offset = chunkEnd;
}

@@ -165,51 +177,60 @@ public AbstractTensor load(String name, DistributedContext dctx, boolean sparseR
}
}

private AbstractTensor downloadAndLoadTensor(String name, String weightFile, TensorInfo info, TensorShape shape, long positionOffset, long positionLimit, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) throws IOException
{
private AbstractTensor downloadAndLoadTensor(
String name,
String weightFile,
TensorInfo info,
TensorShape shape,
long positionOffset,
long positionLimit,
DistributedContext dctx,
boolean sparseRows,
boolean sparseColumns
) throws IOException {
Path weightPath = modelRoot.resolve(weightFile + ".part." + positionOffset + "_" + positionLimit);

if (!weightPath.toFile().exists()) {
logger.info("Downloading file: {} for {} {}MB", weightPath, name, (positionLimit - positionOffset) / 1024 / 1024);
HttpSupport.downloadFile(
modelName,
weightFile,
branch,
authToken,
Optional.of(Pair.of(positionOffset, positionLimit)),
weightPath,
Optional.empty()
modelName,
weightFile,
branch,
authToken,
Optional.of(Pair.of(positionOffset, positionLimit)),
weightPath,
Optional.empty()
);
}

int length = Ints.checkedCast(positionLimit - positionOffset);

RandomAccessFile raf = new RandomAccessFile(weightPath.toFile(), "r");
ByteBuffer buf = raf.getChannel()
.map(FileChannel.MapMode.READ_ONLY, 0, raf.length())
.duplicate()
.order(ByteOrder.LITTLE_ENDIAN)
.position(0)
.limit(length);
.map(FileChannel.MapMode.READ_ONLY, 0, raf.length())
.duplicate()
.order(ByteOrder.LITTLE_ENDIAN)
.position(0)
.limit(length);

if (raf.length() < length) {
throw new RuntimeException(
"Failed to download the correct number of bytes: " + raf.length() + " != " + length + " for " + weightPath
"Failed to download the correct number of bytes: " + raf.length() + " != " + length + " for " + weightPath
);
}

logger.debug("Loading tensor: {} from {} with offsets: {} {}", name, weightPath, positionOffset, positionLimit);

AbstractTensor tensor = Weights.loadTensorFromBuffer(
name,
info.dType,
modelDType,
shape,
buf,
sparseRows,
sparseColumns,
dctx,
this
name,
info.dType,
modelDType,
shape,
buf,
sparseRows,
sparseColumns,
dctx,
this
);

layerFiles.put(name, Pair.of(raf, tensor));
@@ -262,8 +283,7 @@ public DType getModelDType() {
public void close() {
for (Pair<RandomAccessFile, AbstractTensor> pair : layerFiles.values()) {
try {
if (pair.left() != null)
pair.left().close();
if (pair.left() != null) pair.left().close();
} catch (IOException e) {
throw new RuntimeException(e);
}
[diffs for the remaining changed files were not rendered]
