Range-based HTTP loader for distributed inference and CLI improvements (#59)
tjake authored Sep 25, 2024
1 parent 2c3f080 commit 823deca
Showing 36 changed files with 1,100 additions and 443 deletions.
README.md (30 changes: 18 additions & 12 deletions)
@@ -52,18 +52,14 @@ curl -Ls https://sh.jbang.dev | bash -s - app setup

#Install Jlama CLI (will ask if you trust the source)
jbang app install --force jlama@tjake

```

Now that you have jlama installed, you can download a model from Hugging Face and chat with it.
Note that pre-quantized models are available at https://hf.co/tjake

```shell
# Download a small model (defaults to ./models)
jlama download tjake/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4

-# Run the openai chat api and UI on this model
-jlama restapi models/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4
+# Run the openai chat api and UI on a model
+jlama restapi tjake/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4 --auto-download
```

open browser to http://localhost:8080/
@@ -74,19 +70,29 @@


```shell
-Usage: jlama [COMMAND]
-Jlama is a modern LLM inference engine for Java!
+Usage:
+
+jlama [COMMAND]
+
+Description:
+
+Jlama is a modern LLM inference engine for Java!
Quantized models are maintained at https://hf.co/tjake

-Commands:
-  download              Downloads a HuggingFace model - use owner/name format
-  quantize              Quantize the specified model
+Choose from the available commands:
+
+Inference:
  chat                  Interact with the specified model
-  complete              Completes a prompt using the specified model
  restapi               Starts a openai compatible rest api for interacting with this model
+  complete              Completes a prompt using the specified model
+
+Distributed Inference:
+  cluster-coordinator   Starts a distributed rest api for a model using cluster workers
+  cluster-worker        Connects to a cluster coordinator to perform distributed inference
+
+Other:
+  download              Downloads a HuggingFace model - use owner/name format
+  quantize              Quantize the specified model
```
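
For example, once a model has been fetched with `jlama download`, the first command in the Inference group starts an interactive session. A minimal sketch, assuming `chat` accepts the same owner/name model argument that `restapi` takes in this commit:

```shell
# Illustrative invocation; model coordinates as used elsewhere in this README
jlama chat tjake/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4
```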


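Since the `restapi` flow shown earlier serves an OpenAI-compatible API, the endpoint can also be exercised without the UI. A hedged sketch with curl, assuming the server follows the standard OpenAI chat-completions route (the exact path and payload fields are assumptions, not shown in this diff):

```shell
# Hypothetical request against the server started by `jlama restapi`
curl -s http://localhost:8080/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "tjake/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4", "messages": [{"role": "user", "content": "Hello!"}]}'
```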
docker-compose.yaml (6 changes: 3 additions & 3 deletions)
@@ -17,7 +17,7 @@ services:
      - cluster-coordinator
      - --threads=2
      - --worker-count=8
-     - /models/Llama-2-7b-chat-hf-jlama-Q4/
+     - tjake/Mistral-7B-Instruct-v0.3-jlama-Q4
    volumes:
      - "./models:/models"
    healthcheck:
@@ -41,7 +41,7 @@
    command:
      - cluster-worker
      - --threads=1
-     - --host=jlama-coordinator
-     - /models/Llama-2-7b-chat-hf-jlama-Q4/
+     - --coordinator=jlama-coordinator
+     - tjake/Mistral-7B-Instruct-v0.3-jlama-Q4
    volumes:
      - "./models:/models"
jlama-cli/src/main/java/com/github/tjake/jlama/cli/JlamaCli.java (163 changes: 151 additions & 12 deletions)
@@ -15,44 +15,183 @@
*/
package com.github.tjake.jlama.cli;

+import ch.qos.logback.classic.LoggerContext;
+import ch.qos.logback.classic.encoder.PatternLayoutEncoder;
+import ch.qos.logback.core.ConsoleAppender;
import com.github.tjake.jlama.cli.commands.*;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.Logger;
import org.slf4j.LoggerFactory;
import picocli.CommandLine;
import picocli.CommandLine.*;

-@Command(name = "jlama", mixinStandardHelpOptions = true, requiredOptionMarker = '*', usageHelpAutoWidth = true, sortOptions = true, description = "Jlama is a modern LLM inference engine for Java!\n\n"
-    + "Quantized models are maintained at https://hf.co/tjake\n")
+import java.util.*;
+
+import static java.util.Arrays.asList;
+import static picocli.CommandLine.Help.Column.Overflow.*;
+import static picocli.CommandLine.Model.UsageMessageSpec.*;
+
+@Command(name = "jlama", sortOptions = false, headerHeading = "Usage:%n", synopsisHeading = "%n", descriptionHeading = "%nDescription:%n%n", parameterListHeading = "%nParameters:%n", optionListHeading = "%nCommand Options:%n", mixinStandardHelpOptions = true, usageHelpAutoWidth = true, requiredOptionMarker = '*', description = "Jlama is a modern LLM inference engine for Java!\nQuantized models are maintained at https://hf.co/tjake\n\nChoose from the available commands:", defaultValueProvider = PropertiesDefaultProvider.class)
public class JlamaCli implements Runnable {
    static {
        System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0");
        TensorOperationsProvider.get();
+       setupLogging();
    }

    @Option(names = { "-h", "--help" }, usageHelp = true, hidden = true)
    boolean helpRequested = false;

    public static void main(String[] args) {
-       Logger root = (Logger) LoggerFactory.getLogger(org.slf4j.Logger.ROOT_LOGGER_NAME);
-       root.setLevel(Level.INFO);
-
        CommandLine cli = new CommandLine(new JlamaCli());
-       cli.addSubcommand("download", new DownloadCommand());
-       cli.addSubcommand("quantize", new QuantizeCommand());
        cli.addSubcommand("chat", new ChatCommand());
-       cli.addSubcommand("complete", new CompleteCommand());
        cli.addSubcommand("restapi", new ApiServiceCommand());
+       cli.addSubcommand("complete", new CompleteCommand());
+       cli.addSubcommand("download", new DownloadCommand());
+       cli.addSubcommand("quantize", new QuantizeCommand());
+       cli.addSubcommand("cluster-coordinator", new ClusterCoordinatorCommand());
+       cli.addSubcommand("cluster-worker", new ClusterWorkerCommand());

        cli.setUsageHelpLongOptionsMaxWidth(256);
+       cli.getHelpSectionMap().remove(SECTION_KEY_COMMAND_LIST_HEADING);
+       cli.getHelpSectionMap().put(SECTION_KEY_COMMAND_LIST, getCommandRenderer());

        String[] pargs = args.length == 0 ? new String[] { "-h" } : args;
        cli.parseWithHandler(new RunLast(), pargs);
    }

    @Override
    public void run() {}

    /** Shamelessly stolen from jbang */
    public static CommandGroupRenderer getCommandRenderer() {
        Map<String, List<String>> sections = new LinkedHashMap<>();
        sections.put("Inference", asList("chat", "restapi", "complete"));
        sections.put("Distributed Inference", asList("cluster-coordinator", "cluster-worker"));
        sections.put("Other", asList("download", "quantize"));
        CommandGroupRenderer renderer = new CommandGroupRenderer(sections);
        return renderer;
    }

    public static class CommandGroupRenderer implements CommandLine.IHelpSectionRenderer {
        private final Map<String, List<String>> sections;

        public CommandGroupRenderer(Map<String, List<String>> sections) {
            this.sections = sections;
        }

        /**
         * Validates that every command in the help output is covered by a section
         * and that every command named in a section exists in the help output.
         *
         * @param help the picocli help model to validate against
         */
        public void validate(CommandLine.Help help) {
            Set<String> cmds = new HashSet<>();
            sections.forEach((key, value) -> cmds.addAll(value));

            Set<String> actualcmds = new HashSet<>(help.subcommands().keySet());

            actualcmds.removeAll(cmds);

            cmds.removeAll(help.subcommands().keySet());

            if (cmds.size() > 0) {
                throw new IllegalStateException("Section help defined for non-existent commands: " + cmds);
            }

            if (actualcmds.size() > 0) {
                throw new IllegalStateException("Commands found with no assigned section: " + actualcmds);
            }
        }

        @Override
        public String render(CommandLine.Help help) {
            if (help.commandSpec().subcommands().isEmpty()) {
                return "";
            }

            StringBuilder result = new StringBuilder();

            sections.forEach((key, value) -> result.append(renderSection(key, value, help)));
            return result.toString();
        }

        private String renderSection(String sectionHeading, List<String> cmdNames, CommandLine.Help help) {
            Help.TextTable textTable = createTextTable(help);

            for (String name : cmdNames) {
                Model.CommandSpec sub = help.commandSpec().subcommands().get(name).getCommandSpec();

                // create a comma-separated list of the command name and its aliases
                String names = sub.names().toString();
                names = names.substring(1, names.length() - 1); // remove leading '[' and trailing ']'

                // the description may contain line separators; use Text::splitLines to handle them
                String description = description(sub.usageMessage());
                CommandLine.Help.Ansi.Text[] lines = help.colorScheme().text(description).splitLines();

                for (int i = 0; i < lines.length; i++) {
                    CommandLine.Help.Ansi.Text cmdNamesText = help.colorScheme().commandText(i == 0 ? names : "");
                    textTable.addRowValues(cmdNamesText, lines[i]);
                }
            }
            return help.createHeading("%n" + sectionHeading + ":%n") + textTable.toString();
        }

        private static Help.TextTable createTextTable(CommandLine.Help help) {
            Model.CommandSpec spec = help.commandSpec();
            // prepare the layout: two columns;
            // the left column overflows, the right column wraps if text is too long
            int commandLength = maxLength(spec.subcommands(), 37);
            Help.TextTable textTable = Help.TextTable.forColumns(
                help.colorScheme(),
                new CommandLine.Help.Column(commandLength + 2, 2, SPAN),
                new CommandLine.Help.Column(spec.usageMessage().width() - (commandLength + 2), 2, WRAP)
            );
            textTable.setAdjustLineBreaksForWideCJKCharacters(spec.usageMessage().adjustLineBreaksForWideCJKCharacters());
            return textTable;
        }

        private static int maxLength(Map<String, CommandLine> subcommands, int max) {
            int result = subcommands.values()
                .stream()
                .map(cmd -> cmd.getCommandSpec().names().toString().length() - 2)
                .max(Integer::compareTo)
                .get();
            return Math.min(max, result);
        }

        private String description(Model.UsageMessageSpec usageMessage) {
            if (usageMessage.header().length > 0) {
                return usageMessage.header()[0];
            }
            if (usageMessage.description().length > 0) {
                return usageMessage.description()[0];
            }
            return "";
        }
    }

    private static void setupLogging() {
        Logger root = (Logger) LoggerFactory.getLogger(org.slf4j.Logger.ROOT_LOGGER_NAME);
        LoggerContext logCtx = root.getLoggerContext();

        logCtx.reset();

        PatternLayoutEncoder logEncoder = new PatternLayoutEncoder();
        logEncoder.setContext(logCtx);
        logEncoder.setPattern("%msg%n");
        logEncoder.start();

        ConsoleAppender logConsoleAppender = new ConsoleAppender();
        logConsoleAppender.setContext(logCtx);
        logConsoleAppender.setName("console");
        logConsoleAppender.setEncoder(logEncoder);
        logConsoleAppender.start();

        root.addAppender(logConsoleAppender);
        root.setAdditive(false);
        root.setLevel(Level.INFO);
    }
}
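
One side effect of the new annotation worth noting: `defaultValueProvider = PropertiesDefaultProvider.class` lets users pin their own option defaults in a properties file. In picocli, that provider looks for a `.jlama.properties` file in the user's home directory and matches keys to long option names; a sketch under that assumption (the option values are illustrative):

```shell
# Hypothetical defaults file read by picocli's PropertiesDefaultProvider
cat > ~/.jlama.properties <<'EOF'
threads=4
working-qtype=I8
EOF
```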
jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java
@@ -17,6 +17,7 @@

import static com.github.tjake.jlama.model.ModelSupport.loadModel;

+import java.nio.file.Path;
import java.util.Optional;

import com.github.tjake.jlama.model.functions.Generator;
@@ -31,14 +32,15 @@
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import picocli.CommandLine;

@CommandLine.Command(name = "restapi", description = "Starts a openai compatible rest api for interacting with this model")
@CommandLine.Command(name = "restapi", description = "Starts a openai compatible rest api for interacting with this model", abbreviateSynopsis = true)
@SpringBootApplication(scanBasePackages = { "com.github.tjake.jlama.net.openai", "com.github.tjake.jlama.cli.commands" })
@SpringBootConfiguration
@Configuration
public class ApiServiceCommand extends BaseCommand implements WebMvcConfigurer {
    private static final Logger logger = LoggerFactory.getLogger(ApiServiceCommand.class);

-   @CommandLine.Option(names = { "-p", "--port" }, description = "http port (default: ${DEFAULT-VALUE})", defaultValue = "8080")
+   @CommandLine.Option(names = {
+       "--port" }, paramLabel = "ARG", description = "http port (default: ${DEFAULT-VALUE})", defaultValue = "8080")
    int port = 8080;

    protected static volatile Generator m;
@@ -56,13 +58,21 @@ public void addResourceHandlers(ResourceHandlerRegistry registry) {
    @Override
    public void run() {
        try {
+           Path modelPath = SimpleBaseCommand.getModel(
+               modelName,
+               modelDirectory,
+               downloadSection.autoDownload,
+               downloadSection.branch,
+               downloadSection.authToken
+           );
+
            m = loadModel(
-               model,
+               modelPath.toFile(),
                workingDirectory,
-               workingMemoryType,
-               workingQuantizationType,
-               Optional.ofNullable(modelQuantization),
-               Optional.ofNullable(threadCount)
+               advancedSection.workingMemoryType,
+               advancedSection.workingQuantizationType,
+               Optional.ofNullable(advancedSection.modelQuantization),
+               Optional.ofNullable(advancedSection.threadCount)
            );

            System.out.println("Chat UI: http://localhost:" + port);
jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/BaseCommand.java
@@ -20,20 +20,26 @@
import picocli.CommandLine;

public class BaseCommand extends SimpleBaseCommand {
-   @CommandLine.Option(names = { "-d", "--working-directory" }, description = "Working directory for attention cache")
+   @CommandLine.Option(names = { "--work-directory" }, paramLabel = "ARG", description = "Working directory for attention cache")
    protected File workingDirectory = null;

-   @CommandLine.Option(names = { "-wm",
-       "--working-dtype" }, description = "Working memory data type (default: ${DEFAULT-VALUE})", defaultValue = "F32")
-   protected DType workingMemoryType = DType.F32;
+   @CommandLine.ArgGroup(exclusive = false, heading = "Advanced Options:%n")
+   protected AdvancedSection advancedSection = new AdvancedSection();

-   @CommandLine.Option(names = { "-wq",
-       "--working-qtype" }, description = "Working memory quantization data type (default: ${DEFAULT-VALUE})", defaultValue = "I8")
-   protected DType workingQuantizationType = DType.I8;
+   static class AdvancedSection {
+       @CommandLine.Option(names = {
+           "--working-dtype" }, paramLabel = "ARG", description = "Working memory data type (default: ${DEFAULT-VALUE})", defaultValue = "F32")
+       protected DType workingMemoryType = DType.F32;

-   @CommandLine.Option(names = { "-tc", "--threads" }, description = "Number of threads to use (default: number of cores)")
-   protected Integer threadCount = null;
+       @CommandLine.Option(names = {
+           "--working-qtype" }, paramLabel = "ARG", description = "Working memory quantization data type (default: ${DEFAULT-VALUE})", defaultValue = "I8")
+       protected DType workingQuantizationType = DType.I8;

-   @CommandLine.Option(names = { "-q", "--quantization" }, description = "Model quantization type")
-   protected DType modelQuantization;
+       @CommandLine.Option(names = {
+           "--threads" }, paramLabel = "ARG", description = "Number of threads to use (default: number of physical cores)")
+       protected Integer threadCount = null;
+
+       @CommandLine.Option(names = { "--quantize-to" }, paramLabel = "ARG", description = "Runtime Model quantization type")
+       protected DType modelQuantization;
+   }
}
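
Putting the regrouped options together, a `restapi` invocation using the renamed flags might look like the following. The combination is illustrative, but every flag name is taken from the diff above (`Q4` as a runtime quantization target is an assumption):

```shell
# Illustrative use of the renamed advanced options from this commit
jlama restapi \
  --auto-download \
  --threads 4 \
  --working-qtype I8 \
  --quantize-to Q4 \
  tjake/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4
```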