Add download command to cli, rm shell script
tjake committed Feb 18, 2024
1 parent 887d879 commit 1ea1866
Showing 9 changed files with 257 additions and 137 deletions.
15 changes: 8 additions & 7 deletions README.md
@@ -39,7 +39,7 @@ Add LLM Inference directly to your Java application.
Jlama includes a simple UI if you just want to chat with an llm.

```
./download-hf-model.sh tjake/llama2-7b-chat-hf-jlama-Q4
./run-cli.sh download tjake/llama2-7b-chat-hf-jlama-Q4
./run-cli.sh serve models/llama2-7b-chat-hf-jlama-Q4
```
@@ -53,21 +53,22 @@ open browser to http://localhost:8080/ui/index.html
Jlama includes a cli tool to run models via the `run-cli.sh` command.
Before you do that, first download one or more models from huggingface.

Use the `download-hf-models.sh` script in the data directory to download models from huggingface.
Use the `./run-cli.sh download` command to download models from huggingface.

```shell
./download-hf-model.sh gpt2-medium
./download-hf-model.sh -a XXXXXXXX meta-llama/Llama-2-7b-chat-hf
./download-hf-model.sh intfloat/e5-small-v2
./run-cli.sh download gpt2-medium
./run-cli.sh download -a XXXXXXXX meta-llama/Llama-2-7b-chat-hf
./run-cli.sh download intfloat/e5-small-v2
```

Then run the cli:
Then run the cli tool to chat with the model or complete a prompt.

```shell
./run-cli.sh complete -p "The best part of waking up is " -t 0.7 -tc 16 -q Q4 -wq I8 models/Llama-2-7b-chat-hf
./run-cli.sh chat -p "Tell me a joke about cats." -t 0.7 -tc 16 -q Q4 -wq I8 models/Llama-2-7b-chat-hf
```
## 🧪 Examples

## 🧪 Examples
### Llama 2 7B

```
113 changes: 0 additions & 113 deletions download-hf-model.sh

This file was deleted.

5 changes: 5 additions & 0 deletions jlama-cli/pom.xml
@@ -22,6 +22,11 @@
  <artifactId>picocli</artifactId>
  <version>4.7.5</version>
</dependency>
<dependency>
  <groupId>me.tongfei</groupId>
  <artifactId>progressbar</artifactId>
  <version>0.10.0</version>
</dependency>
<dependency>
  <groupId>org.jboss.resteasy</groupId>
  <artifactId>resteasy-jaxrs</artifactId>
JlamaCli.java
@@ -1,11 +1,6 @@
package com.github.tjake.jlama.cli;

import com.github.tjake.jlama.cli.commands.ChatCommand;
import com.github.tjake.jlama.cli.commands.ClusterCoordinatorCommand;
import com.github.tjake.jlama.cli.commands.ClusterWorkerCommand;
import com.github.tjake.jlama.cli.commands.CompleteCommand;
import com.github.tjake.jlama.cli.commands.QuantizeCommand;
import com.github.tjake.jlama.cli.commands.ServeCommand;
import com.github.tjake.jlama.cli.commands.*;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import picocli.CommandLine;
import picocli.CommandLine.*;
@@ -23,6 +18,7 @@ public class JlamaCli implements Runnable {

    public static void main(String[] args) {
        CommandLine cli = new CommandLine(new JlamaCli());
        cli.addSubcommand("download", new DownloadCommand());
        cli.addSubcommand("quantize", new QuantizeCommand());
        cli.addSubcommand("chat", new ChatCommand());
        cli.addSubcommand("complete", new CompleteCommand());
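For readers unfamiliar with picocli, the sketch below shows how a registered subcommand such as `download` gets dispatched. It is illustrative only: the `execute` call and exit handling are assumptions, since the rest of JlamaCli's `main` method is truncated in this diff.

```java
// Sketch only (not the actual tail of JlamaCli.main, which is elided above).
// picocli matches the first token ("download") against the registered subcommands and
// binds the remaining arguments to DownloadCommand's @Option/@Parameters fields
// before invoking its run() method.
import picocli.CommandLine;

import com.github.tjake.jlama.cli.JlamaCli;
import com.github.tjake.jlama.cli.commands.DownloadCommand;

public class DispatchSketch {
    public static void main(String[] args) {
        CommandLine cli = new CommandLine(new JlamaCli());
        cli.addSubcommand("download", new DownloadCommand());
        // e.g. args = {"download", "-d", "models", "tjake/llama2-7b-chat-hf-jlama-Q4"}
        int exitCode = cli.execute(args); // parses args, runs the matched command's run(), returns an exit code
        System.exit(exitCode);
    }
}
```

With wiring like this, the README's `./run-cli.sh download ...` examples map directly onto DownloadCommand's options and positional model parameter, shown in the new file below.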
DownloadCommand.java
@@ -0,0 +1,78 @@
package com.github.tjake.jlama.cli.commands;

import com.github.tjake.jlama.cli.JlamaCli;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.google.common.util.concurrent.Uninterruptibles;
import me.tongfei.progressbar.ProgressBar;
import me.tongfei.progressbar.ProgressBarBuilder;
import me.tongfei.progressbar.ProgressBarStyle;
import picocli.CommandLine;

import java.io.File;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

@CommandLine.Command(name = "download", description = "Downloads the specified model", mixinStandardHelpOptions = true)
public class DownloadCommand extends JlamaCli {
    @CommandLine.Option(names={"-d", "--model-directory"}, description = "The directory to download the model to", defaultValue = "models")
    protected File modelDirectory = new File("models");

    @CommandLine.Option(names={"-t", "--auth-token"}, description = "The auth token to use for downloading the model (if required)")
    protected String authToken = null;


    @CommandLine.Parameters(index = "0", arity = "1", description = "The model owner/name pair to download")
    protected String model;

    @Override
    public void run() {
        AtomicReference<ProgressBar> progressRef = new AtomicReference<>();

        String[] parts = model.split("/");
        if (parts.length == 0 || parts.length > 2) {
            System.err.println("Model must be in the form owner/name");
            System.exit(1);
        }

        String owner;
        String name;

        if (parts.length == 1) {
            owner = null;
            name = model;
        } else {
            owner = parts[0];
            name = parts[1];
        }


        try {
            SafeTensorSupport.maybeDownloadModel(modelDirectory.getAbsolutePath(), Optional.ofNullable(owner), name, Optional.ofNullable(authToken), Optional.of((n,c,t) -> {
                if (progressRef.get() == null || !progressRef.get().getTaskName().equals(n)) {
                    ProgressBarBuilder builder = new ProgressBarBuilder()
                            .setTaskName(n)
                            .setInitialMax(t)
                            .setStyle(ProgressBarStyle.ASCII);

                    if (t > 1000000) {
                        builder.setUnit("MB", 1000000);
                    } else if (t > 1000) {
                        builder.setUnit("KB", 1000);
                    } else {
                        builder.setUnit("B", 1);
                    }

                    progressRef.set(builder.build());
                }

                progressRef.get().stepTo(c);
                Uninterruptibles.sleepUninterruptibly(150, TimeUnit.MILLISECONDS);
            }));
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(1);
        }
    }
}
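As an aside, here is a minimal, self-contained sketch of the progress-reporting pattern used above with the newly added me.tongfei progressbar dependency: one bar per task name, scaled to B/KB/MB, and advanced with `stepTo`. The file name and byte counts are hypothetical; only the builder and `stepTo` calls mirror what the command's callback does.

```java
// Illustrative only: simulates a download of one (hypothetical) file to show how
// DownloadCommand's callback drives the bar with absolute byte counts via stepTo().
import me.tongfei.progressbar.ProgressBar;
import me.tongfei.progressbar.ProgressBarBuilder;
import me.tongfei.progressbar.ProgressBarStyle;

import java.util.concurrent.TimeUnit;

public class ProgressSketch {
    public static void main(String[] args) throws InterruptedException {
        long totalBytes = 4_000_000L; // hypothetical file size
        try (ProgressBar bar = new ProgressBarBuilder()
                .setTaskName("model.safetensors") // hypothetical file name
                .setInitialMax(totalBytes)
                .setStyle(ProgressBarStyle.ASCII)
                .setUnit("MB", 1_000_000) // same unit DownloadCommand picks for files over 1 MB
                .build()) {
            for (long downloaded = 0; downloaded <= totalBytes; downloaded += 400_000) {
                bar.stepTo(downloaded); // absolute position, like the (n, c, t) callback's c argument
                TimeUnit.MILLISECONDS.sleep(150);
            }
        }
    }
}
```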
ModelSupport.java
@@ -22,16 +22,18 @@
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.util.Pair;
import com.github.tjake.jlama.util.PhysicalCoreExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.*;
import java.lang.reflect.InvocationTargetException;
import java.nio.file.Path;
import java.util.Objects;
import java.util.Optional;
import java.util.*;

public class ModelSupport {

    private static final Logger logger = LoggerFactory.getLogger(ModelSupport.class);

    private static final ObjectMapper om = new ObjectMapper()
            .configure(DeserializationFeature.FAIL_ON_IGNORED_PROPERTIES, false)
            .configure(DeserializationFeature.FAIL_ON_TRAILING_TOKENS, false)
