Add batch support to distributed inference
tjake committed Sep 16, 2024
1 parent 9bfc22f commit 555494c
Showing 6 changed files with 78 additions and 51 deletions.
@@ -233,12 +233,15 @@ protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, KvBuffe
         return last;
     }
 
-    protected AbstractTensor batchForward(int[] token_ids, int startPos, KvBufferCache.KvBuffer kvbuf) {
+    public AbstractTensor batchForward(int[] token_ids, int startPos, KvBufferCache.KvBuffer kvbuf) {
+        return batchForward(token_ids, startPos, kvbuf, Optional.empty());
+    }
+
+    public AbstractTensor batchForward(int[] token_ids, int startPos, KvBufferCache.KvBuffer kvbuf, Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
         AbstractTensor embedding = embedInput.batchInputsToEmbeddings(token_ids, startPos);
         for (int i = c.dctx().layerStart; i < c.dctx().layerEnd; i++) {
             AbstractTensor ref = embedding; // reference so we can free
-            embedding = transformerBlocks[i].forward(embedding, startPos, kvbuf, Optional.empty());
+            embedding = transformerBlocks[i].forward(embedding, startPos, kvbuf, tensorReducer);
             ref.close();
         }
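The effect of this change is that callers outside the package can now push a whole prompt through the model in one pass and hook the per-layer partial results. A minimal sketch of how the new overload might be driven — the package names in the imports and the `forwardPrompt` helper are assumptions for illustration, not part of this diff:

    import java.util.List;
    import java.util.Optional;
    import java.util.function.Consumer;

    // Sketch only: AbstractModel, AbstractTensor and KvBufferCache come from
    // jlama-core; the exact packages below are assumed.
    import com.github.tjake.jlama.model.AbstractModel;
    import com.github.tjake.jlama.model.KvBufferCache;
    import com.github.tjake.jlama.tensor.AbstractTensor;

    class BatchForwardSketch {
        // Runs the prompt tokens through the model in a single batched pass.
        // The reducer is invoked once per transformer layer with that layer's
        // partial tensors; a distributed caller combines them across workers there.
        static AbstractTensor forwardPrompt(AbstractModel model, KvBufferCache.KvBuffer kvBuffer, int[] promptTokens) {
            Optional<Consumer<List<AbstractTensor>>> reducer = Optional.of(
                tensors -> System.out.println("combining " + tensors.size() + " partial tensors")
            );
            return model.batchForward(promptTokens, 0, kvBuffer, reducer);
        }
    }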
27 changes: 15 additions & 12 deletions jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java
@@ -31,15 +31,20 @@
 import io.grpc.ServerBuilder;
 import java.io.File;
 import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Optional;
 import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.function.BiConsumer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 public class Coordinator implements Generator {
     private static final Logger logger = LoggerFactory.getLogger(Coordinator.class);
+    private static final ConcurrentMap<UUID, Integer> sessionPositions = new ConcurrentHashMap<>();
     private final int port;
     private final int workerCount;
     private final Server server;
@@ -108,38 +113,34 @@ public Generator.Response generate(
         StringBuilder responseBuilder = new StringBuilder();
         StringBuilder responseWithSpecialTokens = new StringBuilder();
 
+        int startPos = sessionPositions.computeIfAbsent(session, s -> 0);
+
         FinishReason finishReason = FinishReason.MAX_TOKENS;
         long[] encoded = model.getTokenizer().encode(promptContext.getPrompt());
         Preconditions.checkArgument(encoded.length < model.getConfig().contextLength);
 
         AbstractTensor logits = model.makeDenseTensor(model.getConfig().vocabularySize);
 
-        int[] promptTokens = new int[1 + encoded.length];
+        Integer[] promptTokens = new Integer[1 + encoded.length];
 
         promptTokens[0] = model.getConfig().bosToken;
-        for (int i = 1; i < encoded.length; i++)
-            promptTokens[i] = Ints.checkedCast(encoded[i]);
+        for (int i = 1; i <= encoded.length; i++)
+            promptTokens[i] = Ints.checkedCast(encoded[i - 1]);
 
         int promptLength = encoded.length;
 
         long start = System.currentTimeMillis();
 
-        AbstractTensor output = null;
-        for (int i = 0; i < promptLength; i++) {
-            if (output != null) output.close();
-            logger.debug("Generating token {}", i);
-            output = service.generateNextOutput(session, promptTokens[i], i);
-        }
+        AbstractTensor output = service.generateNextOutput(session, Arrays.asList(promptTokens), startPos);
 
         long promptTime = System.currentTimeMillis();
+        int lastPosition = startPos + promptLength;
         int tokensGenerated = 0;
 
+        sessionPositions.put(session, lastPosition);
         for (int i = promptLength; i < ntokens; i++) {
             int next = model.sample(output, temperature, ThreadLocalRandom.current().nextFloat(), logits);
             output.close();
 
-            tokensGenerated++;
-
             // Model may tell us it's done
             if (model.getConfig().eosTokens.contains(next)) {
                 finishReason = FinishReason.STOP_TOKEN;
@@ -160,6 +161,8 @@
             }
 
             output = service.generateNextOutput(session, next, i);
+            tokensGenerated++;
+            sessionPositions.put(session, lastPosition++);
         }
 
         return new Generator.Response(
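The bookkeeping above is what lets a session span multiple generate() calls: the first call batch-forwards the whole prompt starting at position 0, and every later sampled token advances the stored position by one. A toy sketch of just that arithmetic, standalone with no jlama types:

    import java.util.UUID;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentMap;

    class SessionPositionSketch {
        // Mirrors the coordinator's sessionPositions map: each session resumes
        // from wherever its KV cache left off.
        static final ConcurrentMap<UUID, Integer> positions = new ConcurrentHashMap<>();

        public static void main(String[] args) {
            UUID session = UUID.randomUUID();
            int promptLength = 5; // e.g. BOS + 4 prompt tokens

            int startPos = positions.computeIfAbsent(session, s -> 0); // 0 on first use
            int lastPosition = startPos + promptLength; // prompt fills [startPos, lastPosition)
            positions.put(session, lastPosition);

            // Each sampled token is then forwarded at the next position.
            positions.put(session, ++lastPosition);
            System.out.println("session resumes at position " + positions.get(session)); // 6
        }
    }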
50 changes: 44 additions & 6 deletions jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java
@@ -179,10 +179,50 @@ private ByteString getTensorBytes(AbstractTensor tensor) {
         return UnsafeByteOperations.unsafeWrap(tensor.getMemorySegment().asByteBuffer());
     }
 
+    public void onNextBatch(GenerateResponse generateResponse) {
+        int[] tokens = generateResponse.getTokensList().stream().mapToInt(Integer::intValue).toArray();
+        int startPosition = generateResponse.getStartPosition();
+        ByteBuffer bb = generateResponse.getSession().asReadOnlyByteBuffer();
+        UUID session = new UUID(bb.getLong(), bb.getLong());
+
+        logger.info("Processing batch of {} starting at position {} for session {}", tokens, startPosition, session);
+
+        AbstractTensor output = model.batchForward(tokens, startPosition, kvBufferCache.getKvBuffer(session), Optional.of(t -> {
+            CombineRequest.Builder nrb = CombineRequest.newBuilder()
+                .setUuid(generateResponse.getSession())
+                .setWorkerid(workerIdBytes)
+                .setLayer(getNextRequestCount(session));
+            for (int i = 0; i < t.size(); i++)
+                nrb = nrb.addTensor(getTensorBytes(t.get(i)));
+
+            CombineResponse combineResponse = getCombineResponseStream(session).request(nrb.build()).join();
+
+            for (int i = 0; i < t.size(); i++)
+                t.get(i).getMemorySegment().copyFrom(
+                    MemorySegment.ofBuffer(combineResponse.getTensor(i).asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN))
+                );
+        }));
+
+        outputStream.onNext(
+            GenerateRequest.newBuilder()
+                .setSession(generateResponse.getSession())
+                .setWorkerid(workerIdBytes)
+                .setTensor(getTensorBytes(output.slice(output.shape().first() - 1))) // keep only the last token
+                .build()
+        );
+
+        output.close();
+    }
+
     @Override
     public void onNext(GenerateResponse generateResponse) {
-        int token = generateResponse.getToken();
-        int position = generateResponse.getPosition();
+        if (generateResponse.getTokensCount() > 1) {
+            onNextBatch(generateResponse);
+            return;
+        }
+
+        int token = generateResponse.getTokens(0);
+        int position = generateResponse.getStartPosition();
         ByteBuffer bb = generateResponse.getSession().asReadOnlyByteBuffer();
         UUID session = new UUID(bb.getLong(), bb.getLong());
 
@@ -199,11 +239,9 @@ public void onNext(GenerateResponse generateResponse) {
             CombineResponse combineResponse = getCombineResponseStream(session).request(nrb.build()).join();
 
             for (int i = 0; i < t.size(); i++)
-                t.get(i)
-                    .getMemorySegment()
-                    .copyFrom(
+                t.get(i).getMemorySegment().copyFrom(
                     MemorySegment.ofBuffer(combineResponse.getTensor(i).asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN))
-                    );
+                );
         }));
 
         outputStream.onNext(
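One detail worth calling out: batchForward returns one output row per input token, so onNextBatch ships only the final row back to the coordinator — the earlier rows exist only to populate the KV cache. The same idea on a plain 2-D array, as a toy stand-in for output.slice(output.shape().first() - 1):

    class LastRowSketch {
        // Given one hidden-state row per prompt token, only the last row is
        // needed to sample the next token; the rest served to fill the KV cache.
        static float[] lastRow(float[][] batchOutput) {
            return batchOutput[batchOutput.length - 1];
        }

        public static void main(String[] args) {
            float[][] rows = { {0.1f, 0.2f}, {0.3f, 0.4f}, {0.5f, 0.6f} };
            System.out.println(lastRow(rows)[0]); // 0.5
        }
    }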
JlamaService.java
@@ -46,7 +46,6 @@ public class JlamaService extends JlamaServiceGrpc.JlamaServiceImplBase {
     private final GeneratorGroup generatorGroup;
 
     private final ConcurrentMap<String, MpmcArrayQueue<Pair<CombineRequest, StreamObserver<CombineResponse>>>> combinations;
-    private final int headCount;
 
     public JlamaService(AbstractModel model, int workerCount) {
         Preconditions.checkArgument(
@@ -58,8 +57,6 @@ public JlamaService(AbstractModel model, int workerCount) {
         this.workers = new ConcurrentHashMap<>();
         this.combinations = new ConcurrentHashMap<>();
         this.generatorGroup = new GeneratorGroup();
-
-        this.headCount = model.getConfig().numberOfKeyValueHeads;
     }
 
     public void waitForReady() {
@@ -119,6 +116,10 @@ public void register(RegisterRequest request, StreamObserver<RegisterResponse> r
         }
     }
 
+    public AbstractTensor generateNextOutput(UUID session, List<Integer> tokenIds, int startPosition) {
+        return generatorGroup.generateNextOutput(session, tokenIds, startPosition);
+    }
+
     public AbstractTensor generateNextOutput(UUID session, int tokenId, int position) {
         return generatorGroup.generateNextOutput(session, tokenId, position);
     }
@@ -146,8 +147,6 @@ public void onNext(CombineRequest request) {
 
         // If we have all the workers, then we can calculate the result and send it back
         if (members.size() == workerCount && combinations.remove(key, members)) {
-            float sumSq = 0;
-            float sum = 0;
             MemorySegment[] tensors = null;
             for (Pair<CombineRequest, StreamObserver<CombineResponse>> f : members) {
                 if (f.left.getTensorCount() > 0) {
@@ -230,32 +229,29 @@ public void waitForReady() {
         }
 
         public AbstractTensor generateNextOutput(UUID session, int tokenId, int position) {
+            return generateNextOutput(session, Collections.singletonList(tokenId), position);
+        }
+
+        public AbstractTensor generateNextOutput(UUID session, List<Integer> tokenIds, int startPosition) {
             Preconditions.checkArgument(generators.size() == workerCount, "Missing workers %d", workers.size());
             ByteString sid = ByteString.copyFrom(
                 ByteBuffer.allocate(128).putLong(session.getMostSignificantBits()).putLong(session.getLeastSignificantBits()).flip()
             );
-            GenerateResponse gr = GenerateResponse.newBuilder().setSession(sid).setToken(tokenId).setPosition(position).build();
+            GenerateResponse gr = GenerateResponse.newBuilder().setSession(sid).addAllTokens(tokenIds).setStartPosition(startPosition).build();
             for (Generator g : generators) {
                 g.registerLatch(session);
                 g.responseObserver.onNext(gr);
             }
 
             AbstractTensor output = model.makeDenseTensor(model.getConfig().embeddingLength);
 
-
             for (int j = 0; j < workerCount; j++) {
                 Generator g = generators.get(j);
                 ByteString v = g.waitForOutput(session);
                 RegisterResponse r = workers.get(g.workerId);
 
                 if (j == 0) {
-                    FloatBuffer f = v.asReadOnlyByteBuffer()
-                        .order(ByteOrder.LITTLE_ENDIAN)
-                        .asFloatBuffer();
-                    int len = f.remaining();
-                    for (int i = 0; i < len; i++) {
-                        output.set(f.get(), 0, i);
-                    }
+                    output.getMemorySegment().copyFrom(MemorySegment.ofBuffer(v.asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN)));
                 }
             }
 
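The last hunk also swaps a float-by-float copy through a FloatBuffer for a single bulk MemorySegment copy. A standalone sketch of the same pattern (java.lang.foreign, finalized in JDK 22; the values are made up):

    import java.lang.foreign.Arena;
    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.ValueLayout;
    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;

    class BulkCopySketch {
        public static void main(String[] args) {
            // A little-endian buffer of four floats, standing in for the
            // worker's serialized tensor bytes.
            ByteBuffer src = ByteBuffer.allocateDirect(4 * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
            for (int i = 0; i < 4; i++) src.putFloat(i * Float.BYTES, i + 0.5f);

            try (Arena arena = Arena.ofConfined()) {
                MemorySegment dst = arena.allocate(ValueLayout.JAVA_FLOAT, 4);
                // One bulk copy instead of a per-element FloatBuffer loop.
                dst.copyFrom(MemorySegment.ofBuffer(src));
                System.out.println(dst.getAtIndex(ValueLayout.JAVA_FLOAT, 3)); // 3.5
            }
        }
    }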
16 changes: 4 additions & 12 deletions jlama-net/src/main/proto/JlamaService.proto
@@ -6,19 +6,11 @@ message CombineRequest {
   bytes workerid = 1;
   bytes uuid = 2;
   int32 layer = 3;
-  repeated bytes tensor = 6;
-}
-
-message TensorData {
-  int32 dim0 = 1; // Rows
-  int32 dim1 = 2; // Columns
-  int32 columnOffset = 3; // Offset into the column (if sparse)
-  int32 columnLength = 4; // Length of the column (if sparse)
-  bytes data = 5;
+  repeated bytes tensor = 4;
 }
 
 message CombineResponse {
-  repeated bytes tensor = 3;
+  repeated bytes tensor = 1;
 }
 
 message GenerateRequest {
@@ -29,8 +21,8 @@
 
 message GenerateResponse {
   bytes session = 1;
-  int32 token = 2;
-  int32 position = 3;
+  int32 startPosition = 2;
+  repeated int32 tokens = 3;
 }
 
 /**
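The schema change folds the old token/position pair into a repeated tokens list plus a startPosition, so a single-token step is just a batch of size one. A hedged sketch of what building each shape looks like from the generated Java classes — sid is a placeholder session id and the token values are made up:

    // Batch of one == the old single-token message.
    GenerateResponse single = GenerateResponse.newBuilder()
        .setSession(sid)
        .setStartPosition(42)     // position of this one token
        .addTokens(278)           // the sampled token id
        .build();

    // A whole prompt in one message.
    GenerateResponse prompt = GenerateResponse.newBuilder()
        .setSession(sid)
        .setStartPosition(0)
        .addAllTokens(List.of(1, 15043, 29892)) // hypothetical token ids
        .build();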
@@ -94,11 +94,6 @@ void manyWorkerTestLLama() throws Exception {
         startWorker(modelRoot);
         startWorker(modelRoot);
 
-        // startWorker(modelRoot);
-        // startWorker(modelRoot);
-        // startWorker(modelRoot);
-        // startWorker(modelRoot);
-
         coordinator.generate(
             UUID.randomUUID(),
             PromptContext.of("Simply put, the theory of relativity states that"),
