Skip to content

Commit

Permalink
Working with row and columnar based splitting of tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Sep 10, 2024
1 parent e0878fb commit 50e4296
Show file tree
Hide file tree
Showing 42 changed files with 626 additions and 638 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="TestModels.MistralRun" type="JUnit" factoryName="JUnit" nameIsGenerated="true">
<configuration default="false" name="TestModels.MixtralRun" type="JUnit" factoryName="JUnit" nameIsGenerated="true">
<classpathModifications>
<entry exclude="true" path="$PROJECT_DIR$/jlama-native/target/classes" />
<entry path="$PROJECT_DIR$/jlama-native/target/jlama-native-0.3.1-linux-x86_64.jar" />
Expand All @@ -15,7 +15,7 @@
<option name="ALTERNATIVE_JRE_PATH" value="graalvm-22" />
<option name="PACKAGE_NAME" value="com.github.tjake.jlama.model" />
<option name="MAIN_CLASS_NAME" value="com.github.tjake.jlama.model.TestModels" />
<option name="METHOD_NAME" value="MistralRun" />
<option name="METHOD_NAME" value="MixtralRun" />
<option name="TEST_OBJECT" value="method" />
<option name="VM_PARAMETERS" value="-ea --add-modules=jdk.incubator.vector -Djava.library.path=../jlama-native/target/native-lib-only" />
<method v="2">
Expand Down
46 changes: 39 additions & 7 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,40 @@
FROM openjdk:21-slim as builder
FROM ubuntu:22.04 as builder
RUN apt-get update
RUN apt-get install -y build-essential
RUN apt-get install -y build-essential git zip curl zlib1g-dev

ENV SDKMAN_DIR /root/.sdkman
ENV JAVA_VERSION_20 20.0.2-graalce
ENV JAVA_VERSION_21 21.0.2-graalce
ENV JAVA_VERSION_22 22.0.2-graalce

RUN ["mkdir", "-p", "/build"]

RUN rm /bin/sh && ln -s /bin/bash /bin/sh
RUN curl -s "https://get.sdkman.io" | bash
RUN chmod a+x "$SDKMAN_DIR/bin/sdkman-init.sh"

RUN set -x \
&& echo "sdkman_auto_answer=true" > $SDKMAN_DIR/etc/config \
&& echo "sdkman_auto_selfupdate=false" >> $SDKMAN_DIR/etc/config \
&& echo "sdkman_insecure_ssl=false" >> $SDKMAN_DIR/etc/config

WORKDIR $SDKMAN_DIR
RUN [[ -s "$SDKMAN_DIR/bin/sdkman-init.sh" ]] && source "$SDKMAN_DIR/bin/sdkman-init.sh" && exec "$@"

RUN source /root/.bashrc
RUN source "$SDKMAN_DIR/bin/sdkman-init.sh" && sdk install java $JAVA_VERSION_20
RUN source "$SDKMAN_DIR/bin/sdkman-init.sh" && sdk install java $JAVA_VERSION_21
RUN source "$SDKMAN_DIR/bin/sdkman-init.sh" && sdk install java $JAVA_VERSION_22

WORKDIR /build
RUN git clone https://github.com/tjake/sdkman-for-toolchains.git
WORKDIR /build/sdkman-for-toolchains
RUN source "$SDKMAN_DIR/bin/sdkman-init.sh" && ./mvnw -Pnative clean package
RUN ["mkdir", "-p", "/root/.m2"]
RUN printf "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<toolchains>\n</toolchains>" > /root/.m2/toolchains.xml
RUN source "$SDKMAN_DIR/bin/sdkman-init.sh" && target/toolchains generate
RUN cat /root/.m2/toolchains.xml
WORKDIR /build
COPY jlama-cli jlama-cli
COPY jlama-core jlama-core
COPY jlama-native jlama-native
Expand All @@ -11,20 +44,19 @@ COPY .mvn .mvn
COPY pom.xml .
COPY mvnw .

RUN --mount=type=cache,target=/root/.m2 ./mvnw clean package
RUN --mount=type=cache,target=/root/.m2/repository source "$SDKMAN_DIR/bin/sdkman-init.sh" && sdk use java $JAVA_VERSION_22 && ./mvnw clean package -DskipTests

FROM openjdk:21-slim
RUN apt-get update
RUN apt-get install -y procps curl
RUN apt-get install -y procps curl gzip

LABEL org.opencontainers.image.source=https://github.com/tjake/Jlama

COPY inlinerules.json inlinerules.json
COPY run-cli.sh run-cli.sh
COPY conf/logback.xml logback.xml
COPY --from=builder jlama-cli/target/jlama-cli.jar ./jlama-cli.jar

RUN curl -s -L https://github.com/async-profiler/async-profiler/releases/download/v3.0/async-profiler-3.0-linux-x64.tar.gz | tar xvz - -C /profiler
COPY --from=builder /build/jlama-cli/target/jlama-cli.jar ./jlama-cli.jar
RUN mkdir -p /profiler && curl -s -L https://github.com/async-profiler/async-profiler/releases/download/v3.0/async-profiler-3.0-linux-x64.tar.gz | tar zxvf - -C /profiler

ENV JLAMA_PREINSTALLED_JAR=/jlama-cli.jar
ENV JLAMA_JVM_ARGS="-Dlogback.configurationFile=./logback.xml"
Expand Down
4 changes: 2 additions & 2 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ services:
resources:
limits:
memory: 1g
cpus: 1
cpus: "1"
command:
- cluster-coordinator
- --threads=2
Expand All @@ -36,7 +36,7 @@ services:
replicas: 8
resources:
limits:
cpus: 1
cpus: "1"
memory: 500m
command:
- cluster-worker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import com.github.tjake.jlama.model.AbstractModel;
import java.util.Optional;

import com.github.tjake.jlama.model.functions.Generator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.SpringBootConfiguration;
Expand All @@ -40,10 +42,10 @@ public class ApiServiceCommand extends BaseCommand implements WebMvcConfigurer {
@CommandLine.Option(names = { "-p", "--port" }, description = "http port (default: ${DEFAULT-VALUE})", defaultValue = "8080")
int port = 8080;

static volatile AbstractModel m;
protected static volatile Generator m;

@Bean
public AbstractModel getModelBean() {
public Generator getModelBean() {
return m;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,23 @@
*/
package com.github.tjake.jlama.cli.commands;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.net.Coordinator;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import picocli.CommandLine;

@CommandLine.Command(name = "cluster-coordinator", description = "Starts a distributed rest api for a model using cluster workers")
public class ClusterCoordinatorCommand extends BaseCommand {
@SpringBootApplication(scanBasePackages = { "com.github.tjake.jlama.net.openai", "com.github.tjake.jlama.cli.commands" })
@SpringBootConfiguration
@Configuration
public class ClusterCoordinatorCommand extends BaseCommand implements WebMvcConfigurer {

@CommandLine.Option(names = { "-w", "--worker-count" }, description = "signifies this instance is a coordinator", required = true)
int workerCount = 1;
Expand All @@ -32,11 +44,19 @@ public class ClusterCoordinatorCommand extends BaseCommand {
"--port" }, description = "http port to listen on (default: ${DEFAULT-VALUE})", defaultValue = "8080")
int port = 8080;

@Override
public void addResourceHandlers(ResourceHandlerRegistry registry) {
registry.addResourceHandler("/ui/**").addResourceLocations("classpath:/static/ui/");
}

@Override
public void run() {
try {
Coordinator c = new Coordinator(model, workingDirectory, grpcPort, workerCount);

//This wires up the bean for the rest api
ApiServiceCommand.m = c;

new Thread(() -> {
try {
c.start();
Expand All @@ -45,16 +65,13 @@ public void run() {
}
}).start();

/*UndertowJaxrsServer ut = new UndertowJaxrsServer();
ut.deploy(new JlamaRestApi(c), APPLICATION_PATH);
ut.addResourcePrefixPath(
"/ui",
resource(new ClassPathResourceManager(ServeCommand.class.getClassLoader()))
.setDirectoryListingEnabled(true)
.addWelcomeFiles("index.html"));
System.out.println("Chat UI: http://localhost:" + port + "/ui/index.html");
ut.start(Undertow.builder().addHttpListener(port, "0.0.0.0"));*/
System.out.println("Chat UI: http://localhost:" + port);
System.out.println("OpenAI Chat API: http://localhost:" + port + "/chat/completions");

new SpringApplicationBuilder(ClusterCoordinatorCommand.class).lazyInitialization(true)
.properties("server.port", "" + port, "logging.level.org.springframework.web", "info")
.build()
.run();

} catch (Exception e) {
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ protected AbstractModel(

if (workingMemoryQType != workingMemoryDType) {
boolean supportsQType;
AbstractTensor tmp = makeTensor(Q8ByteBufferTensor.BLOCK_SIZE);
AbstractTensor tmp = makeDenseTensor(Q8ByteBufferTensor.BLOCK_SIZE);
try (AbstractTensor tmp2 = TensorOperationsProvider.get().quantize(tmp, workingMemoryQType, 0, Q8ByteBufferTensor.BLOCK_SIZE)) {
supportsQType = tmp2.dType() == workingMemoryQType;
if (!supportsQType) {
Expand Down Expand Up @@ -176,14 +176,11 @@ public Optional<PromptSupport> promptSupport() {
}

public AbstractTensor makeTensor(int... shape) {
TensorShape s;
if (c.offset().isPresent() && shape[shape.length - 1] == c.embeddingLength) s = TensorShape.sparse(shape, c.offset().get());
else s = TensorShape.of(shape);

TensorShape s = TensorShape.of(shape);
return c.tensorCache.get(workingDType, s);
}

public AbstractTensor makeFullTensor(int... shape) {
public AbstractTensor makeDenseTensor(int... shape) {
return c.tensorCache.get(workingDType, TensorShape.of(shape));
}

Expand Down Expand Up @@ -219,7 +216,7 @@ public AbstractTensor forward(
debug("EMBEDDING TOKEN", token_id);
debug("TOKEN POSITION", pos);

for (int i = c.layerStart(); i < c.layerEnd(); i++) {
for (int i = c.dctx().layerStart; i < c.dctx().layerEnd; i++) {
AbstractTensor kvlayer = kvbuf.slice(true, i);
AbstractTensor ref = embedding; // reference so we can free
embedding = transformerBlocks[i].forward(embedding, pos, kvlayer, normReducer, tensorReducer);
Expand All @@ -242,7 +239,7 @@ protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, Abstrac
protected AbstractTensor batchForward(int[] token_ids, int startPos, AbstractTensor kvbuf) {

AbstractTensor embedding = embedInput.batchInputsToEmbeddings(token_ids, startPos);
for (int i = c.layerStart(); i < c.layerEnd(); i++) {
for (int i = c.dctx().layerStart; i < c.dctx().layerEnd; i++) {
AbstractTensor kvlayer = kvbuf.slice(true, i);
AbstractTensor ref = embedding; // reference so we can free
embedding = transformerBlocks[i].forward(embedding, startPos, kvlayer, Optional.empty(), Optional.empty());
Expand Down Expand Up @@ -323,7 +320,7 @@ public Response generate(
StringBuilder responseText = new StringBuilder();
StringBuilder responseTextWithSpecialTokens = new StringBuilder();

try (AbstractTensor logits = makeTensor(c.vocabularySize)) {
try (AbstractTensor logits = makeDenseTensor(c.vocabularySize)) {
int[] promptTokens = new int[(1 + encoded.length)];

promptTokens[0] = c.bosToken;
Expand Down
Loading

0 comments on commit 50e4296

Please sign in to comment.