From 823deca6ed6bd6299baf946f871fde1178cb20a4 Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Tue, 24 Sep 2024 23:09:57 -0400 Subject: [PATCH] Range based http loader for distributed inf and cli improvements (#59) * Range based http loader for distributed inf and cli improvements --- README.md | 30 ++- docker-compose.yaml | 6 +- .../com/github/tjake/jlama/cli/JlamaCli.java | 163 +++++++++++- .../jlama/cli/commands/ApiServiceCommand.java | 24 +- .../tjake/jlama/cli/commands/BaseCommand.java | 28 +- .../tjake/jlama/cli/commands/ChatCommand.java | 49 ++-- .../commands/ClusterCoordinatorCommand.java | 46 +++- .../cli/commands/ClusterWorkerCommand.java | 46 +++- .../jlama/cli/commands/CompleteCommand.java | 27 +- .../jlama/cli/commands/DownloadCommand.java | 83 ++---- .../jlama/cli/commands/ModelBaseCommand.java | 15 +- .../jlama/cli/commands/QuantizeCommand.java | 17 +- .../jlama/cli/commands/SimpleBaseCommand.java | 125 ++++++++- .../github/tjake/jlama/math/VectorMath.java | 2 +- .../tjake/jlama/model/AbstractModel.java | 38 ++- .../tjake/jlama/model/ModelSupport.java | 32 ++- .../tjake/jlama/model/bert/BertModel.java | 36 +-- .../jlama/model/functions/ClassifyOutput.java | 16 ++ .../jlama/model/functions/Generator.java | 1 - .../jlama/model/functions/PoolingLayer.java | 16 ++ .../tjake/jlama/safetensors/Config.java | 60 +++-- .../safetensors/HTTPSafeTensorLoader.java | 243 ++++++++++++++++++ .../jlama/safetensors/SafeTensorIndex.java | 28 +- .../jlama/safetensors/SafeTensorSupport.java | 58 ++++- .../tjake/jlama/safetensors/WeightLoader.java | 1 + .../tjake/jlama/safetensors/Weights.java | 60 +++-- .../tjake/jlama/tensor/AbstractTensor.java | 12 - .../github/tjake/jlama/util/HttpSupport.java | 28 +- .../github/tjake/jlama/net/Coordinator.java | 26 +- .../com/github/tjake/jlama/net/Worker.java | 22 +- .../jlama/net/DistributedServiceTest.java | 72 +++--- .../github/tjake/jlama/model/TestModels.java | 87 +++---- .../github/tjake/jlama/model/TestSample.java | 13 +- jlama.java | 21 -- pom.xml | 2 +- run-cli.sh | 10 +- 36 files changed, 1100 insertions(+), 443 deletions(-) create mode 100644 jlama-core/src/main/java/com/github/tjake/jlama/safetensors/HTTPSafeTensorLoader.java delete mode 100755 jlama.java diff --git a/README.md b/README.md index 159a766..0925bcf 100644 --- a/README.md +++ b/README.md @@ -52,18 +52,14 @@ curl -Ls https://sh.jbang.dev | bash -s - app setup #Install Jlama CLI (will ask if you trust the source) jbang app install --force jlama@tjake - ``` Now that you have jlama installed you can download a model from huggingface and chat with it. Note I have pre-quantized models available at https://hf.co/tjake ```shell -# Download a small model (defaults to ./models) -jlama download tjake/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4 - -# Run the openai chat api and UI on this model -jlama restapi models/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4 +# Run the openai chat api and UI on a model +jlama restapi tjake/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4 --auto-download ``` open browser to http://localhost:8080/ @@ -74,19 +70,29 @@ open browser to http://localhost:8080/ ```shell -Usage: jlama [COMMAND] -Jlama is a modern LLM inference engine for Java! +Usage: + +jlama [COMMAND] +Description: + +Jlama is a modern LLM inference engine for Java! Quantized models are maintained at https://hf.co/tjake -Commands: - download Downloads a HuggingFace model - use owner/name format - quantize Quantize the specified model +Choose from the available commands: + +Inference: chat Interact with the specified model - complete Completes a prompt using the specified model restapi Starts a openai compatible rest api for interacting with this model + complete Completes a prompt using the specified model + +Distributed Inference: cluster-coordinator Starts a distributed rest api for a model using cluster workers cluster-worker Connects to a cluster coordinator to perform distributed inference + +Other: + download Downloads a HuggingFace model - use owner/name format + quantize Quantize the specified model ``` diff --git a/docker-compose.yaml b/docker-compose.yaml index 638bab6..fd1e046 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -17,7 +17,7 @@ services: - cluster-coordinator - --threads=2 - --worker-count=8 - - /models/Llama-2-7b-chat-hf-jlama-Q4/ + - tjake/Mistral-7B-Instruct-v0.3-jlama-Q4 volumes: - "./models:/models" healthcheck: @@ -41,7 +41,7 @@ services: command: - cluster-worker - --threads=1 - - --host=jlama-coordinator - - /models/Llama-2-7b-chat-hf-jlama-Q4/ + - --coordinator=jlama-coordinator + - tjake/Mistral-7B-Instruct-v0.3-jlama-Q4 volumes: - "./models:/models" diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/JlamaCli.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/JlamaCli.java index 84d7707..802dcc1 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/JlamaCli.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/JlamaCli.java @@ -15,39 +15,43 @@ */ package com.github.tjake.jlama.cli; +import ch.qos.logback.classic.LoggerContext; +import ch.qos.logback.classic.encoder.PatternLayoutEncoder; +import ch.qos.logback.core.ConsoleAppender; import com.github.tjake.jlama.cli.commands.*; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import ch.qos.logback.classic.Level; import ch.qos.logback.classic.Logger; import org.slf4j.LoggerFactory; import picocli.CommandLine; import picocli.CommandLine.*; -@Command(name = "jlama", mixinStandardHelpOptions = true, requiredOptionMarker = '*', usageHelpAutoWidth = true, sortOptions = true, description = "Jlama is a modern LLM inference engine for Java!\n\n" - + "Quantized models are maintained at https://hf.co/tjake\n") +import java.util.*; + +import static java.util.Arrays.asList; +import static picocli.CommandLine.Help.Column.Overflow.*; +import static picocli.CommandLine.Model.UsageMessageSpec.*; + +@Command(name = "jlama", sortOptions = false, headerHeading = "Usage:%n", synopsisHeading = "%n", descriptionHeading = "%nDescription:%n%n", parameterListHeading = "%nParameters:%n", optionListHeading = "%nCommand Options:%n", mixinStandardHelpOptions = true, usageHelpAutoWidth = true, requiredOptionMarker = '*', description = "Jlama is a modern LLM inference engine for Java!\nQuantized models are maintained at https://hf.co/tjake\n\nChoose from the available commands:", defaultValueProvider = PropertiesDefaultProvider.class) public class JlamaCli implements Runnable { static { - System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0"); - TensorOperationsProvider.get(); + setupLogging(); } @Option(names = { "-h", "--help" }, usageHelp = true, hidden = true) boolean helpRequested = false; public static void main(String[] args) { - Logger root = (Logger) LoggerFactory.getLogger(org.slf4j.Logger.ROOT_LOGGER_NAME); - root.setLevel(Level.INFO); - CommandLine cli = new CommandLine(new JlamaCli()); - cli.addSubcommand("download", new DownloadCommand()); - cli.addSubcommand("quantize", new QuantizeCommand()); cli.addSubcommand("chat", new ChatCommand()); - cli.addSubcommand("complete", new CompleteCommand()); cli.addSubcommand("restapi", new ApiServiceCommand()); + cli.addSubcommand("complete", new CompleteCommand()); + cli.addSubcommand("download", new DownloadCommand()); + cli.addSubcommand("quantize", new QuantizeCommand()); cli.addSubcommand("cluster-coordinator", new ClusterCoordinatorCommand()); cli.addSubcommand("cluster-worker", new ClusterWorkerCommand()); - cli.setUsageHelpLongOptionsMaxWidth(256); + cli.getHelpSectionMap().remove(SECTION_KEY_COMMAND_LIST_HEADING); + cli.getHelpSectionMap().put(SECTION_KEY_COMMAND_LIST, getCommandRenderer()); String[] pargs = args.length == 0 ? new String[] { "-h" } : args; cli.parseWithHandler(new RunLast(), pargs); @@ -55,4 +59,139 @@ public static void main(String[] args) { @Override public void run() {} + + /** Shamelessly stolen from jbang */ + public static CommandGroupRenderer getCommandRenderer() { + Map> sections = new LinkedHashMap<>(); + sections.put("Inference", asList("chat", "restapi", "complete")); + sections.put("Distributed Inference", asList("cluster-coordinator", "cluster-worker")); + sections.put("Other", asList("download", "quantize")); + CommandGroupRenderer renderer = new CommandGroupRenderer(sections); + return renderer; + } + + public static class CommandGroupRenderer implements CommandLine.IHelpSectionRenderer { + private final Map> sections; + + public CommandGroupRenderer(Map> sections) { + this.sections = sections; + } + + /** + * validate all commands in Help is covered by section and each section command + * exist in help. + * + * @param help + */ + public void validate(CommandLine.Help help) { + Set cmds = new HashSet<>(); + sections.forEach((key, value) -> cmds.addAll(value)); + + Set actualcmds = new HashSet<>(help.subcommands().keySet()); + + actualcmds.removeAll(cmds); + + cmds.removeAll(help.subcommands().keySet()); + + if (cmds.size() > 0) { + throw new IllegalStateException("Section help defined for non existent commands" + cmds); + } + + if (actualcmds.size() > 0) { + throw new IllegalStateException(("Commands found with no assigned section" + actualcmds)); + } + + sections.forEach((key, value) -> cmds.addAll(value)); + + } + + // @Override + public String render(CommandLine.Help help) { + if (help.commandSpec().subcommands().isEmpty()) { + return ""; + } + + StringBuilder result = new StringBuilder(); + + sections.forEach((key, value) -> result.append(renderSection(key, value, help))); + return result.toString(); + } + + private String renderSection(String sectionHeading, List cmdNames, CommandLine.Help help) { + Help.TextTable textTable = createTextTable(help); + + for (String name : cmdNames) { + Model.CommandSpec sub = help.commandSpec().subcommands().get(name).getCommandSpec(); + + // create comma-separated list of command name and aliases + String names = sub.names().toString(); + names = names.substring(1, names.length() - 1); // remove leading '[' and trailing ']' + + // description may contain line separators; use Text::splitLines to handle this + String description = description(sub.usageMessage()); + CommandLine.Help.Ansi.Text[] lines = help.colorScheme().text(description).splitLines(); + + for (int i = 0; i < lines.length; i++) { + CommandLine.Help.Ansi.Text cmdNamesText = help.colorScheme().commandText(i == 0 ? names : ""); + textTable.addRowValues(cmdNamesText, lines[i]); + } + } + return help.createHeading("%n" + sectionHeading + ":%n") + textTable.toString(); + } + + private static Help.TextTable createTextTable(CommandLine.Help help) { + Model.CommandSpec spec = help.commandSpec(); + // prepare layout: two columns + // the left column overflows, the right column wraps if text is too long + int commandLength = maxLength(spec.subcommands(), 37); + Help.TextTable textTable = Help.TextTable.forColumns( + help.colorScheme(), + new CommandLine.Help.Column(commandLength + 2, 2, SPAN), + new CommandLine.Help.Column(spec.usageMessage().width() - (commandLength + 2), 2, WRAP) + ); + textTable.setAdjustLineBreaksForWideCJKCharacters(spec.usageMessage().adjustLineBreaksForWideCJKCharacters()); + return textTable; + } + + private static int maxLength(Map subcommands, int max) { + int result = subcommands.values() + .stream() + .map(cmd -> cmd.getCommandSpec().names().toString().length() - 2) + .max(Integer::compareTo) + .get(); + return Math.min(max, result); + } + + private String description(Model.UsageMessageSpec usageMessage) { + if (usageMessage.header().length > 0) { + return usageMessage.header()[0]; + } + if (usageMessage.description().length > 0) { + return usageMessage.description()[0]; + } + return ""; + } + } + + private static void setupLogging() { + Logger root = (Logger) LoggerFactory.getLogger(org.slf4j.Logger.ROOT_LOGGER_NAME); + LoggerContext logCtx = root.getLoggerContext(); + + logCtx.reset(); + + PatternLayoutEncoder logEncoder = new PatternLayoutEncoder(); + logEncoder.setContext(logCtx); + logEncoder.setPattern("%msg%n"); + logEncoder.start(); + + ConsoleAppender logConsoleAppender = new ConsoleAppender(); + logConsoleAppender.setContext(logCtx); + logConsoleAppender.setName("console"); + logConsoleAppender.setEncoder(logEncoder); + logConsoleAppender.start(); + + root.addAppender(logConsoleAppender); + root.setAdditive(false); + root.setLevel(Level.INFO); + } } diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java index 096cc68..0b5d40c 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ApiServiceCommand.java @@ -17,6 +17,7 @@ import static com.github.tjake.jlama.model.ModelSupport.loadModel; +import java.nio.file.Path; import java.util.Optional; import com.github.tjake.jlama.model.functions.Generator; @@ -31,14 +32,15 @@ import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; import picocli.CommandLine; -@CommandLine.Command(name = "restapi", description = "Starts a openai compatible rest api for interacting with this model") +@CommandLine.Command(name = "restapi", description = "Starts a openai compatible rest api for interacting with this model", abbreviateSynopsis = true) @SpringBootApplication(scanBasePackages = { "com.github.tjake.jlama.net.openai", "com.github.tjake.jlama.cli.commands" }) @SpringBootConfiguration @Configuration public class ApiServiceCommand extends BaseCommand implements WebMvcConfigurer { private static final Logger logger = LoggerFactory.getLogger(ApiServiceCommand.class); - @CommandLine.Option(names = { "-p", "--port" }, description = "http port (default: ${DEFAULT-VALUE})", defaultValue = "8080") + @CommandLine.Option(names = { + "--port" }, paramLabel = "ARG", description = "http port (default: ${DEFAULT-VALUE})", defaultValue = "8080") int port = 8080; protected static volatile Generator m; @@ -56,13 +58,21 @@ public void addResourceHandlers(ResourceHandlerRegistry registry) { @Override public void run() { try { + Path modelPath = SimpleBaseCommand.getModel( + modelName, + modelDirectory, + downloadSection.autoDownload, + downloadSection.branch, + downloadSection.authToken + ); + m = loadModel( - model, + modelPath.toFile(), workingDirectory, - workingMemoryType, - workingQuantizationType, - Optional.ofNullable(modelQuantization), - Optional.ofNullable(threadCount) + advancedSection.workingMemoryType, + advancedSection.workingQuantizationType, + Optional.ofNullable(advancedSection.modelQuantization), + Optional.ofNullable(advancedSection.threadCount) ); System.out.println("Chat UI: http://localhost:" + port); diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/BaseCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/BaseCommand.java index 867a469..7796ca3 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/BaseCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/BaseCommand.java @@ -20,20 +20,26 @@ import picocli.CommandLine; public class BaseCommand extends SimpleBaseCommand { - @CommandLine.Option(names = { "-d", "--working-directory" }, description = "Working directory for attention cache") + @CommandLine.Option(names = { "--work-directory" }, paramLabel = "ARG", description = "Working directory for attention cache") protected File workingDirectory = null; - @CommandLine.Option(names = { "-wm", - "--working-dtype" }, description = "Working memory data type (default: ${DEFAULT-VALUE})", defaultValue = "F32") - protected DType workingMemoryType = DType.F32; + @CommandLine.ArgGroup(exclusive = false, heading = "Advanced Options:%n") + protected AdvancedSection advancedSection = new AdvancedSection(); - @CommandLine.Option(names = { "-wq", - "--working-qtype" }, description = "Working memory quantization data type (default: ${DEFAULT-VALUE})", defaultValue = "I8") - protected DType workingQuantizationType = DType.I8; + static class AdvancedSection { + @CommandLine.Option(names = { + "--working-dtype" }, paramLabel = "ARG", description = "Working memory data type (default: ${DEFAULT-VALUE})", defaultValue = "F32") + protected DType workingMemoryType = DType.F32; - @CommandLine.Option(names = { "-tc", "--threads" }, description = "Number of threads to use (default: number of cores)") - protected Integer threadCount = null; + @CommandLine.Option(names = { + "--working-qtype" }, paramLabel = "ARG", description = "Working memory quantization data type (default: ${DEFAULT-VALUE})", defaultValue = "I8") + protected DType workingQuantizationType = DType.I8; - @CommandLine.Option(names = { "-q", "--quantization" }, description = "Model quantization type") - protected DType modelQuantization; + @CommandLine.Option(names = { + "--threads" }, paramLabel = "ARG", description = "Number of threads to use (default: number of physical cores)") + protected Integer threadCount = null; + + @CommandLine.Option(names = { "--quantize-to" }, paramLabel = "ARG", description = "Runtime Model quantization type") + protected DType modelQuantization; + } } diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java index e8ab1e1..55bd7e8 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java @@ -20,40 +20,43 @@ import com.diogonunes.jcolor.AnsiFormat; import com.diogonunes.jcolor.Attribute; import com.github.tjake.jlama.model.AbstractModel; +import com.github.tjake.jlama.model.functions.Generator; import com.github.tjake.jlama.safetensors.prompt.PromptContext; import com.github.tjake.jlama.safetensors.prompt.PromptSupport; + import java.io.PrintWriter; +import java.nio.file.Path; import java.util.Optional; import java.util.Scanner; import java.util.UUID; import java.util.function.BiConsumer; + import picocli.CommandLine.*; -@Command(name = "chat", description = "Interact with the specified model") -public class ChatCommand extends BaseCommand { +@Command(name = "chat", description = "Interact with the specified model", abbreviateSynopsis = true) +public class ChatCommand extends ModelBaseCommand { private static final AnsiFormat chatText = new AnsiFormat(Attribute.CYAN_TEXT()); private static final AnsiFormat statsColor = new AnsiFormat(Attribute.BLUE_TEXT()); - @Option(names = { "-s", "--system-prompt" }, description = "Change the default system prompt for this model") + @Option(names = { "--system-prompt" }, paramLabel = "ARG", description = "Change the default system prompt for this model") String systemPrompt = null; - @Option(names = { "-t", - "--temperature" }, description = "Temperature of response [0,1] (default: ${DEFAULT-VALUE})", defaultValue = "0.6") - protected Float temperature; - - @Option(names = { - "--top-p" }, description = "Controls how many different words the model considers per token [0,1] (default: ${DEFAULT-VALUE})", defaultValue = ".9") - protected Float topp; - @Override public void run() { + Path modelPath = SimpleBaseCommand.getModel( + modelName, + modelDirectory, + downloadSection.autoDownload, + downloadSection.branch, + downloadSection.authToken + ); AbstractModel m = loadModel( - model, + modelPath.toFile(), workingDirectory, - workingMemoryType, - workingQuantizationType, - Optional.ofNullable(modelQuantization), - Optional.ofNullable(threadCount) + advancedSection.workingMemoryType, + advancedSection.workingQuantizationType, + Optional.ofNullable(advancedSection.modelQuantization), + Optional.ofNullable(advancedSection.threadCount) ); if (m.promptSupport().isEmpty()) { @@ -65,7 +68,7 @@ public void run() { PromptSupport promptSupport = m.promptSupport().get(); PrintWriter out = System.console().writer(); - out.println("Chatting with " + model + "...\n"); + out.println("\nChatting with " + modelName + "...\n"); out.flush(); Scanner sc = new Scanner(System.in); boolean first = true; @@ -86,7 +89,17 @@ public void run() { builder.addUserMessage(prompt); PromptContext builtPrompt = builder.build(); - m.generate(session, builtPrompt, temperature, Integer.MAX_VALUE, makeOutHandler()); + Generator.Response r = m.generate(session, builtPrompt, temperature, Integer.MAX_VALUE, makeOutHandler()); + + out.println( + "\n\n" + + statsColor.format( + Math.round(r.promptTimeMs / (double) r.promptTokens) + + " ms/tok (prompt), " + + Math.round(r.generateTimeMs / (double) r.generatedTokens) + + " ms/tok (gen)" + ) + ); first = false; } diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java index 18470fb..32424ca 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterCoordinatorCommand.java @@ -16,6 +16,7 @@ package com.github.tjake.jlama.cli.commands; import com.github.tjake.jlama.net.Coordinator; +import com.github.tjake.jlama.safetensors.DType; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.builder.SpringApplicationBuilder; @@ -24,23 +25,31 @@ import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; import picocli.CommandLine; -@CommandLine.Command(name = "cluster-coordinator", description = "Starts a distributed rest api for a model using cluster workers") +import java.nio.file.Path; +import java.util.Optional; + +@CommandLine.Command(name = "cluster-coordinator", description = "Starts a distributed rest api for a model using cluster workers", abbreviateSynopsis = true) @SpringBootApplication(scanBasePackages = { "com.github.tjake.jlama.net.openai", "com.github.tjake.jlama.cli.commands" }) @SpringBootConfiguration @Configuration -public class ClusterCoordinatorCommand extends BaseCommand implements WebMvcConfigurer { +public class ClusterCoordinatorCommand extends ModelBaseCommand implements WebMvcConfigurer { - @CommandLine.Option(names = { "-w", "--worker-count" }, description = "signifies this instance is a coordinator", required = true) + @CommandLine.Option(names = { + "--worker-count" }, paramLabel = "ARG", description = "signifies this instance is a coordinator", required = true) int workerCount = 1; - @CommandLine.Option(names = { "-g", - "--grpc-port" }, description = "grpc port to listen on (default: ${DEFAULT-VALUE})", defaultValue = "9777") + @CommandLine.Option(names = { + "--grpc-port" }, paramLabel = "ARG", description = "grpc port to listen on (default: ${DEFAULT-VALUE})", defaultValue = "9777") int grpcPort = 9777; - @CommandLine.Option(names = { "-p", - "--port" }, description = "http port to listen on (default: ${DEFAULT-VALUE})", defaultValue = "8080") + @CommandLine.Option(names = { + "--port" }, paramLabel = "ARG", description = "http port to listen on (default: ${DEFAULT-VALUE})", defaultValue = "8080") int port = 8080; + @CommandLine.Option(names = { + "--model-type" }, paramLabel = "ARG", description = "The models base type F32/BF16 (default: ${DEFAULT-VALUE})", defaultValue = "F32") + DType modelType = DType.F32; + @Override public void addResourceHandlers(ResourceHandlerRegistry registry) { registry.addResourceHandler("/ui/**").addResourceLocations("classpath:/static/ui/"); @@ -49,7 +58,28 @@ public void addResourceHandlers(ResourceHandlerRegistry registry) { @Override public void run() { try { - Coordinator c = new Coordinator(model, workingDirectory, grpcPort, workerCount); + + // Download the model metadata + Path model = SimpleBaseCommand.getModel( + modelName, + modelDirectory, + true, + downloadSection.branch, + downloadSection.authToken, + false + ); + + Coordinator c = new Coordinator( + model.toFile(), + SimpleBaseCommand.getOwner(modelName), + SimpleBaseCommand.getName(modelName), + modelType, + workingDirectory, + grpcPort, + workerCount, + Optional.ofNullable(downloadSection.authToken), + Optional.ofNullable(downloadSection.branch) + ); // This wires up the bean for the rest api ApiServiceCommand.m = c; diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterWorkerCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterWorkerCommand.java index d47162e..55c9f29 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterWorkerCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ClusterWorkerCommand.java @@ -16,40 +16,64 @@ package com.github.tjake.jlama.cli.commands; import com.github.tjake.jlama.net.Worker; + +import java.nio.file.Path; import java.util.Optional; + +import com.github.tjake.jlama.safetensors.DType; import picocli.CommandLine; -@CommandLine.Command(name = "cluster-worker", description = "Connects to a cluster coordinator to perform distributed inference") +@CommandLine.Command(name = "cluster-worker", description = "Connects to a cluster coordinator to perform distributed inference", abbreviateSynopsis = true) public class ClusterWorkerCommand extends BaseCommand { private static final Boolean useHostnameAsWorkerId = Boolean.getBoolean("jlama.use_hostname_as_workerid"); private static final String HOSTNAME = System.getenv("HOSTNAME"); - @CommandLine.Option(names = { "-o", "--host" }, description = "hostname of coordinator", required = true) + @CommandLine.Option(names = { "--coordinator" }, paramLabel = "ARG", description = "hostname/ip of coordinator", required = true) String host; - @CommandLine.Option(names = { "-g", - "--grpc-port" }, description = "grpc port to listen on (default: ${DEFAULT-VALUE})", defaultValue = "9777") + @CommandLine.Option(names = { + "--grpc-port" }, description = "grpc port to listen on (default: ${DEFAULT-VALUE})", paramLabel = "ARG", defaultValue = "9777") int grpcPort = 9777; - @CommandLine.Option(names = { "-w", - "--worker-id" }, description = "consistent name to use when register this worker with the coordinator") + @CommandLine.Option(names = { + "--worker-id" }, paramLabel = "ARG", description = "consistent name to use when register this worker with the coordinator") String workerId = useHostnameAsWorkerId ? HOSTNAME : null; + @CommandLine.Option(names = { + "--model-type" }, paramLabel = "ARG", description = "The models base type Q4/F32/BF16 (default: ${DEFAULT-VALUE})", defaultValue = "Q4") + DType modelType = DType.Q4; + @Override public void run() { try { if (workerId != null) System.out.println("Using " + workerId + " as worker id"); + + Path model = SimpleBaseCommand.getModel( + modelName, + modelDirectory, + true, + downloadSection.branch, + downloadSection.authToken, + false + ); + Worker w = new Worker( - model, + model.toFile(), + SimpleBaseCommand.getOwner(modelName), + SimpleBaseCommand.getName(modelName), + modelType, host, grpcPort, workingDirectory, - workingMemoryType, - workingQuantizationType, - Optional.ofNullable(modelQuantization), - Optional.ofNullable(workerId) + advancedSection.workingMemoryType, + advancedSection.workingQuantizationType, + Optional.ofNullable(advancedSection.modelQuantization), + Optional.ofNullable(workerId), + Optional.ofNullable(downloadSection.authToken), + Optional.ofNullable(downloadSection.branch) ); + w.run(); } catch (Exception e) { e.printStackTrace(); diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/CompleteCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/CompleteCommand.java index 51be872..862c5cc 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/CompleteCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/CompleteCommand.java @@ -19,23 +19,38 @@ import com.github.tjake.jlama.model.AbstractModel; import com.github.tjake.jlama.safetensors.prompt.PromptContext; + +import java.nio.file.Path; import java.util.Optional; import java.util.UUID; + import picocli.CommandLine.*; -@Command(name = "complete", description = "Completes a prompt using the specified model") +@Command(name = "complete", description = "Completes a prompt using the specified model", abbreviateSynopsis = true) public class CompleteCommand extends ModelBaseCommand { + @Option(names = { "--prompt" }, description = "Text to complete", required = true) + protected String prompt; + @Override public void run() { + Path modelPath = SimpleBaseCommand.getModel( + modelName, + modelDirectory, + downloadSection.autoDownload, + downloadSection.branch, + downloadSection.authToken + ); + AbstractModel m = loadModel( - model, + modelPath.toFile(), workingDirectory, - workingMemoryType, - workingQuantizationType, - Optional.ofNullable(modelQuantization), - Optional.ofNullable(threadCount) + advancedSection.workingMemoryType, + advancedSection.workingQuantizationType, + Optional.ofNullable(advancedSection.modelQuantization), + Optional.ofNullable(advancedSection.threadCount) ); + m.generate(UUID.randomUUID(), PromptContext.of(prompt), temperature, tokens, makeOutHandler()); } } diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/DownloadCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/DownloadCommand.java index e96beb1..15a978b 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/DownloadCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/DownloadCommand.java @@ -16,87 +16,34 @@ package com.github.tjake.jlama.cli.commands; import com.github.tjake.jlama.cli.JlamaCli; -import com.github.tjake.jlama.safetensors.SafeTensorSupport; -import com.google.common.util.concurrent.Uninterruptibles; import java.io.File; -import java.io.IOException; -import java.net.URLEncoder; -import java.util.Optional; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import me.tongfei.progressbar.ProgressBar; -import me.tongfei.progressbar.ProgressBarBuilder; -import me.tongfei.progressbar.ProgressBarStyle; import picocli.CommandLine; -@CommandLine.Command(name = "download", description = "Downloads a HuggingFace model - use owner/name format") +import static com.github.tjake.jlama.cli.commands.SimpleBaseCommand.getName; +import static com.github.tjake.jlama.cli.commands.SimpleBaseCommand.getOwner; + +@CommandLine.Command(name = "download", description = "Downloads a HuggingFace model - use owner/name format", abbreviateSynopsis = true) public class DownloadCommand extends JlamaCli { - @CommandLine.Option(names = { "-d", - "--model-directory" }, description = "The directory to download the model to (default: ${DEFAULT-VALUE})", defaultValue = "models") + @CommandLine.Option(names = { + "--model-cache" }, paramLabel = "ARG", description = "The local directory for all downloaded models (default: ${DEFAULT-VALUE})", defaultValue = "models") protected File modelDirectory = new File("models"); - @CommandLine.Option(names = { "-b", - "--branch" }, description = "The branch to download from (default: ${DEFAULT-VALUE})", defaultValue = "main") + @CommandLine.Option(names = { + "--branch" }, paramLabel = "ARG", description = "The branch to download from (default: ${DEFAULT-VALUE})", defaultValue = "main") protected String branch = "main"; - @CommandLine.Option(names = { "-t", "--auth-token" }, description = "The auth token to use for downloading the model (if required)") + @CommandLine.Option(names = { + "--auth-token" }, paramLabel = "ARG", description = "The auth token to use for downloading the model (if required)") protected String authToken = null; - @CommandLine.Parameters(index = "0", arity = "1", description = "The model owner/name pair to download") - protected String model; + @CommandLine.Parameters(index = "0", arity = "1", paramLabel = "", description = "The huggingface model owner/name pair") + protected String modelName; @Override public void run() { - AtomicReference progressRef = new AtomicReference<>(); - - String[] parts = model.split("/"); - if (parts.length == 0 || parts.length > 2) { - System.err.println("Model must be in the form owner/name"); - System.exit(1); - } - - String owner; - String name; - - if (parts.length == 1) { - owner = null; - name = model; - } else { - owner = parts[0]; - name = parts[1]; - } - - try { - SafeTensorSupport.maybeDownloadModel( - modelDirectory.getAbsolutePath(), - Optional.ofNullable(owner), - name, - Optional.ofNullable(URLEncoder.encode(branch)), - Optional.ofNullable(authToken), - Optional.of((n, c, t) -> { - if (progressRef.get() == null || !progressRef.get().getTaskName().equals(n)) { - ProgressBarBuilder builder = new ProgressBarBuilder().setTaskName(n) - .setInitialMax(t) - .setStyle(ProgressBarStyle.ASCII); - - if (t > 1000000) { - builder.setUnit("MB", 1000000); - } else if (t > 1000) { - builder.setUnit("KB", 1000); - } else { - builder.setUnit("B", 1); - } - - progressRef.set(builder.build()); - } + String owner = getOwner(modelName); + String name = getName(modelName); - progressRef.get().stepTo(c); - Uninterruptibles.sleepUninterruptibly(150, TimeUnit.MILLISECONDS); - }) - ); - } catch (IOException e) { - e.printStackTrace(); - System.exit(1); - } + SimpleBaseCommand.downloadModel(owner, name, modelDirectory, branch, authToken, true); } } diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ModelBaseCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ModelBaseCommand.java index b23c769..13f224c 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ModelBaseCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ModelBaseCommand.java @@ -22,20 +22,19 @@ import picocli.CommandLine.*; public class ModelBaseCommand extends BaseCommand { - @Option(names = { "-p", "--prompt" }, description = "Text to complete", required = true) - protected String prompt; - @Option(names = { "-t", - "--temperature" }, description = "Temperature of response [0,1] (default: ${DEFAULT-VALUE})", defaultValue = "0.6") + @Option(names = { + "--temperature" }, paramLabel = "ARG", description = "Temperature of response [0,1] (default: ${DEFAULT-VALUE})", defaultValue = "0.6") protected Float temperature; @Option(names = { - "--top-p" }, description = "Controls how many different words the model considers per token [0,1] (default: ${DEFAULT-VALUE})", defaultValue = ".9") - protected Float topp; - - @Option(names = { "-n", "--tokens" }, description = "Number of tokens to generate (default: ${DEFAULT-VALUE})", defaultValue = "256") + "--tokens" }, paramLabel = "ARG", description = "Number of tokens to generate (default: ${DEFAULT-VALUE})", defaultValue = "256") protected Integer tokens; + @Option(names = { + "--top-p" }, paramLabel = "ARG", description = "Controls how many different words the model considers per token [0,1] (default: ${DEFAULT-VALUE})", defaultValue = ".9") + protected Float topp; + protected BiConsumer makeOutHandler() { PrintWriter out; Charset utf8 = Charset.forName("UTF-8"); diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java index 7d91aba..66d565e 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/QuantizeCommand.java @@ -23,24 +23,33 @@ import java.util.Optional; import picocli.CommandLine; -@CommandLine.Command(name = "quantize", description = "Quantize the specified model") +@CommandLine.Command(name = "quantize", description = "Quantize the specified model", abbreviateSynopsis = true) public class QuantizeCommand extends SimpleBaseCommand { @CommandLine.Parameters(index = "1", arity = "0..1", description = "The output location") protected Path output; - @CommandLine.Option(names = { "-q", "--quantization" }, description = "Model quantization type", arity = "1") + @CommandLine.Option(names = { "--quantization" }, paramLabel = "ARG", description = "Model quantization type", arity = "1") protected DType modelQuantization; - @CommandLine.Option(names = { "-s", "--skip-layer" }, description = "Layer name prefix to not quantize") + @CommandLine.Option(names = { "--skip-layer" }, paramLabel = "ARG", description = "Layer name prefix to not quantize") protected String[] skipLayerPrefixes; - @CommandLine.Option(names = { "-d", "--drop-layer" }, description = "Layer name prefix to drop") + @CommandLine.Option(names = { "--drop-layer" }, paramLabel = "ARG", description = "Layer name prefix to drop") protected String[] dropLayerPrefixes; @Override public void run() { + Path modelPath = SimpleBaseCommand.getModel( + modelName, + modelDirectory, + downloadSection.autoDownload, + downloadSection.branch, + downloadSection.authToken + ); + File model = modelPath.toFile(); + if (!model.exists()) { System.err.println("Model location does not exist: " + model); System.exit(1); diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/SimpleBaseCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/SimpleBaseCommand.java index deffb4b..d815f66 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/SimpleBaseCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/SimpleBaseCommand.java @@ -17,9 +17,130 @@ import com.github.tjake.jlama.cli.JlamaCli; import java.io.File; +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.file.Path; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import com.github.tjake.jlama.safetensors.SafeTensorSupport; +import com.github.tjake.jlama.util.TriConsumer; +import com.google.common.util.concurrent.Uninterruptibles; +import me.tongfei.progressbar.ProgressBar; +import me.tongfei.progressbar.ProgressBarBuilder; +import me.tongfei.progressbar.ProgressBarStyle; import picocli.CommandLine; public class SimpleBaseCommand extends JlamaCli { - @CommandLine.Parameters(index = "0", arity = "1", description = "The model location") - protected File model; + static AtomicReference progressRef = new AtomicReference<>(); + + @CommandLine.ArgGroup(exclusive = false, heading = "Download Options:%n") + protected DownloadSection downloadSection = new DownloadSection(); + + @CommandLine.Option(names = { + "--model-cache" }, paramLabel = "ARG", description = "The local directory for downloaded models (default: ${DEFAULT-VALUE})", defaultValue = "models") + protected File modelDirectory = new File("models"); + + @CommandLine.Parameters(index = "0", arity = "1", paramLabel = "", description = "The huggingface model owner/name pair") + protected String modelName; + + static class DownloadSection { + @CommandLine.Option(names = { + "--auto-download" }, paramLabel = "ARG", description = "Download the model if missing (default: ${DEFAULT-VALUE})", defaultValue = "false") + Boolean autoDownload = false; + + @CommandLine.Option(names = { + "--branch" }, paramLabel = "ARG", description = "The model branch to download from (default: ${DEFAULT-VALUE})", defaultValue = "main") + String branch = "main"; + + @CommandLine.Option(names = { "--auth-token" }, paramLabel = "ARG", description = "HuggingFace auth token (for restricted models)") + String authToken = null; + } + + static String getOwner(String modelName) { + String[] parts = modelName.split("/"); + if (parts.length == 0 || parts.length > 2) { + System.err.println("Model name must be in the form owner/name"); + System.exit(1); + } + return parts[0]; + } + + static String getName(String modelName) { + String[] parts = modelName.split("/"); + if (parts.length == 0 || parts.length > 2) { + System.err.println("Model name must be in the form owner/name"); + System.exit(1); + } + return parts[1]; + } + + static Optional> getProgressConsumer() { + if (System.console() == null) return Optional.empty(); + + return Optional.of((n, c, t) -> { + if (progressRef.get() == null || !progressRef.get().getTaskName().equals(n)) { + ProgressBarBuilder builder = new ProgressBarBuilder().setTaskName(n).setInitialMax(t).setStyle(ProgressBarStyle.ASCII); + + if (t > 1000000) { + builder.setUnit("MB", 1000000); + } else if (t > 1000) { + builder.setUnit("KB", 1000); + } else { + builder.setUnit("B", 1); + } + + progressRef.set(builder.build()); + } + + progressRef.get().stepTo(c); + Uninterruptibles.sleepUninterruptibly(150, TimeUnit.MILLISECONDS); + }); + } + + static void downloadModel(String owner, String name, File modelDirectory, String branch, String authToken, boolean downloadWeights) { + try { + SafeTensorSupport.maybeDownloadModel( + modelDirectory.getAbsolutePath(), + Optional.ofNullable(owner), + name, + downloadWeights, + Optional.ofNullable(URLEncoder.encode(branch)), + Optional.ofNullable(authToken), + getProgressConsumer() + ); + } catch (IOException e) { + e.printStackTrace(); + System.exit(1); + } + } + + static Path getModel(String modelName, File modelDirectory, boolean autoDownload, String branch, String authToken) { + return getModel(modelName, modelDirectory, autoDownload, branch, authToken, true); + } + + static Path getModel( + String modelName, + File modelDirectory, + boolean autoDownload, + String branch, + String authToken, + boolean downloadWeights + ) { + String owner = getOwner(modelName); + String name = getName(modelName); + + Path modelPath = SafeTensorSupport.constructLocalModelPath(modelDirectory.getAbsolutePath(), owner, name); + + if (autoDownload) { + downloadModel(owner, name, modelDirectory, branch, authToken, downloadWeights); + } else if (!modelPath.toFile().exists()) { + System.err.println("Model not found: " + modelPath); + System.err.println("Use --auto-download to download the model"); + System.exit(1); + } + + return modelPath; + } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java b/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java index 914ab2c..d9f9776 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/math/VectorMath.java @@ -103,7 +103,7 @@ public static void l2normalize(AbstractTensor x) { } double magnitude = Math.sqrt(sum); for (int i = 0; i < x.shape().last(); i++) - x.set((float)(x.get(0, i) / magnitude), 0, i); + x.set((float) (x.get(0, i) / magnitude), 0, i); } public static void l2normalize(float[] x) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index de99563..f68b69b 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -51,13 +51,13 @@ public abstract class AbstractModel implements Generator { private static final Logger logger = LoggerFactory.getLogger(AbstractModel.class); public enum InferenceType { - //Used for distributed inference + // Used for distributed inference INPUT_TO_EMBEDDING(true, false, false, false, false), OUTPUT_TO_TOKEN(false, false, true, false, false), FORWARD_PASS(true, true, false, false, false), - //Used for different types of inference - FULL_GENERATION(true, true, true, false,false), + // Used for different types of inference + FULL_GENERATION(true, true, true, false, false), FULL_CLASSIFICATION(true, true, false, true, true), FULL_EMBEDDING(true, true, false, false, true); @@ -104,6 +104,7 @@ protected AbstractModel( this.c = c; this.weights = w; this.tokenizer = t; + this.modelDType = w.getModelDType(); this.workingDType = workingMemoryDType; this.modelQType = modelQType; @@ -135,7 +136,12 @@ protected AbstractModel( this.workingQType = workingMemoryQType; } - logger.debug("Working memory type = {}, Quantized memory type = {}", this.workingDType, this.workingQType); + logger.info( + "Model type = {}, Working memory type = {}, Quantized memory type = {}", + this.modelDType, + this.workingDType, + this.workingQType + ); this.embedInput = inferenceType.isInput ? loadInputWeights() : null; this.transformerBlocks = inferenceType.isFwdPass ? loadTransformerBlockWeights() : null; @@ -284,20 +290,14 @@ public float[] embed(String input, PoolingType poolingType) { // Pooling TensorOperationsProvider.get() - .batchDotProduct( - pooled, - output, - poolingLayer.get().getPoolingWeights(), - 0, - 0, - c.embeddingLength); - - poolingLayer.get().getPoolingBias().ifPresent(bias -> { - TensorOperationsProvider.get().accumulate(pooled, bias, 0, c.embeddingLength); - }); + .batchDotProduct(pooled, output, poolingLayer.get().getPoolingWeights(), 0, 0, c.embeddingLength); + + poolingLayer.get() + .getPoolingBias() + .ifPresent(bias -> { TensorOperationsProvider.get().accumulate(pooled, bias, 0, c.embeddingLength); }); VectorMath.pfor(0, c.embeddingLength, i -> { - //BERT seems to use tanh for pooling rather than gelu + // BERT seems to use tanh for pooling rather than gelu outputEmbedding[i] = ActivationFunction.eval(ActivationFunction.Type.TANH, pooled.get(0, i)); }); @@ -345,9 +345,7 @@ public Map classify(String input, PoolingType poolingType) { TensorOperationsProvider.get().batchDotProduct(scores, b, classifyOutput.getClassificationWeights(), 0, 0, c.embeddingLength); - classifyOutput.getClassificationBias().ifPresent(bias -> { - TensorOperationsProvider.get().accumulate(scores, bias, 0, classes); - }); + classifyOutput.getClassificationBias().ifPresent(bias -> { TensorOperationsProvider.get().accumulate(scores, bias, 0, classes); }); VectorMath.softMax(scores, 0, classes); Map result = new HashMap<>(); @@ -401,8 +399,6 @@ public int sample(AbstractTensor output, float temperature, float uniformSample, } } - - @Override public Response generate( UUID sessionId, diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/ModelSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/ModelSupport.java index a1dba25..c3f0efa 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/ModelSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/ModelSupport.java @@ -82,12 +82,32 @@ public static AbstractModel loadModel(File model, DType workingMemoryType, DType /** Shortcut for loading a model for embeddings */ public static AbstractModel loadEmbeddingModel(File model, DType workingMemoryType, DType workingQuantizationType) { - return loadModel(AbstractModel.InferenceType.FULL_EMBEDDING, model, null, workingMemoryType, workingQuantizationType, Optional.empty(), Optional.empty(), Optional.empty()); + return loadModel( + AbstractModel.InferenceType.FULL_EMBEDDING, + model, + null, + workingMemoryType, + workingQuantizationType, + Optional.empty(), + Optional.empty(), + Optional.empty(), + SafeTensorSupport::loadWeights + ); } /** Shortcut for loading a model for embeddings */ public static AbstractModel loadClassifierModel(File model, DType workingMemoryType, DType workingQuantizationType) { - return loadModel(AbstractModel.InferenceType.FULL_CLASSIFICATION, model, null, workingMemoryType, workingQuantizationType, Optional.empty(), Optional.empty(), Optional.empty()); + return loadModel( + AbstractModel.InferenceType.FULL_CLASSIFICATION, + model, + null, + workingMemoryType, + workingQuantizationType, + Optional.empty(), + Optional.empty(), + Optional.empty(), + SafeTensorSupport::loadWeights + ); } public static AbstractModel loadModel( @@ -106,7 +126,8 @@ public static AbstractModel loadModel( workingQuantizationType, modelQuantization, threadCount, - Optional.empty() + Optional.empty(), + SafeTensorSupport::loadWeights ); } @@ -118,7 +139,8 @@ public static AbstractModel loadModel( DType workingQuantizationType, Optional modelQuantization, Optional threadCount, - Optional> distributedContextLoader + Optional> distributedContextLoader, + Function weightLoaderSupplier ) { if (!model.exists()) { @@ -154,7 +176,7 @@ public static AbstractModel loadModel( c.setWorkingDirectory(workingDirectory); Tokenizer t = modelType.tokenizerClass.getConstructor(Path.class).newInstance(baseDir.toPath()); - WeightLoader wl = SafeTensorSupport.loadWeights(baseDir); + WeightLoader wl = weightLoaderSupplier.apply(baseDir); return modelType.modelClass.getConstructor( AbstractModel.InferenceType.class, diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java index c384cb2..3ca4345 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertModel.java @@ -15,8 +15,6 @@ */ package com.github.tjake.jlama.model.bert; -import com.github.tjake.jlama.math.ActivationFunction; -import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.model.*; import com.github.tjake.jlama.model.functions.ClassifyOutput; import com.github.tjake.jlama.model.functions.EmbedInput; @@ -27,17 +25,13 @@ import com.github.tjake.jlama.safetensors.WeightLoader; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import com.github.tjake.jlama.tensor.AbstractTensor; -import com.github.tjake.jlama.tensor.KvBufferCache; -import com.google.common.base.Preconditions; -import com.google.common.primitives.Ints; import java.util.Arrays; import java.util.NoSuchElementException; import java.util.Optional; -import java.util.UUID; public class BertModel extends AbstractModel { - private static final String[] prefixes = new String[] {"", "bert."}; + private static final String[] prefixes = new String[] { "", "bert." }; public BertModel(Config c, WeightLoader w, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional modelQType) { super(InferenceType.FORWARD_PASS, c, w, tokenizer, workingDType, workingQType, modelQType); @@ -78,11 +72,7 @@ protected EmbedInput loadInputWeights() { AbstractTensor wte = loadWeight("embeddings.token_type_embeddings.weight"); AbstractTensor wpe = loadWeight("embeddings.position_embeddings.weight"); - LayerNorm inputLayerNorm = new LayerNorm( - this, - loadWeight("embeddings.LayerNorm.bias"), - loadWeight("embeddings.LayerNorm.weight") - ); + LayerNorm inputLayerNorm = new LayerNorm(this, loadWeight("embeddings.LayerNorm.bias"), loadWeight("embeddings.LayerNorm.weight")); return (inputToken, position) -> { AbstractTensor embedding = makeDenseTensor(c.embeddingLength); @@ -133,23 +123,19 @@ protected TransformerBlock[] loadTransformerBlockWeights() { prefix = b; MLPBlock mlpBlock = new MLPBlock( this, - c.activationFunction, - loadWeight(prefix + "intermediate.dense.bias"), - loadWeight(prefix + "intermediate.dense.weight"), - loadWeight(prefix + "output.dense.bias"), - loadWeight(prefix + "output.dense.weight") + c.activationFunction, + loadWeight(prefix + "intermediate.dense.bias"), + loadWeight(prefix + "intermediate.dense.weight"), + loadWeight(prefix + "output.dense.bias"), + loadWeight(prefix + "output.dense.weight") ); LayerNorm postAttentionNorm = new LayerNorm( this, - loadWeight(b + "attention.output.LayerNorm.bias"), - loadWeight(b + "attention.output.LayerNorm.weight") - ); - LayerNorm postMlpNorm = new LayerNorm( - this, - loadWeight(b + "output.LayerNorm.bias"), - loadWeight(b + "output.LayerNorm.weight") + loadWeight(b + "attention.output.LayerNorm.bias"), + loadWeight(b + "attention.output.LayerNorm.weight") ); + LayerNorm postMlpNorm = new LayerNorm(this, loadWeight(b + "output.LayerNorm.bias"), loadWeight(b + "output.LayerNorm.weight")); transformerBlocks[i] = new TransformerBlock(this, i, attention, postAttentionNorm, mlpBlock, postMlpNorm); } @@ -184,7 +170,7 @@ protected ClassifyOutput loadClassifierWeights() { if (c.isClassifier()) { final AbstractTensor classifierWeight = loadWeight("classifier.weight"); final AbstractTensor classifierBias = loadWeight("classifier.bias"); - + return new ClassifyOutput() { @Override public AbstractTensor getClassificationWeights() { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/ClassifyOutput.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/ClassifyOutput.java index 9ad4d86..854829a 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/ClassifyOutput.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/ClassifyOutput.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.model.functions; import com.github.tjake.jlama.tensor.AbstractTensor; @@ -6,5 +21,6 @@ public interface ClassifyOutput { public AbstractTensor getClassificationWeights(); + public Optional getClassificationBias(); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java index f625d16..f785015 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java @@ -138,7 +138,6 @@ Response generate( BiConsumer onTokenWithTimings ); - enum PoolingType { MODEL, // Use the model's pooling layers AVG, diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/PoolingLayer.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/PoolingLayer.java index 5575452..792fa70 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/PoolingLayer.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/PoolingLayer.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.model.functions; import com.github.tjake.jlama.tensor.AbstractTensor; @@ -6,5 +21,6 @@ public interface PoolingLayer { AbstractTensor getPoolingWeights(); + Optional getPoolingBias(); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java index 51def4a..d06e500 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java @@ -21,7 +21,6 @@ import com.github.tjake.jlama.tensor.TensorCache; import com.google.common.base.Preconditions; import com.google.common.collect.BiMap; -import com.google.common.collect.HashBiMap; import com.google.common.collect.ImmutableBiMap; import com.google.common.io.Files; import java.io.File; @@ -55,36 +54,36 @@ public class Config { public final TensorCache tensorCache; public Config( - int contextLength, - int embeddingLength, - int hiddenLength, - int numberOfHeads, - int numberOfKeyValueHeads, - int numberOfLayers, - float layerNormEps, - int vocabularySize, - int bosToken, - List eosToken, - ActivationFunction.Type activationFunction, - Double ropeFreqsTheta, - Double ropeScalingFactor + int contextLength, + int embeddingLength, + int hiddenLength, + int numberOfHeads, + int numberOfKeyValueHeads, + int numberOfLayers, + float layerNormEps, + int vocabularySize, + int bosToken, + List eosToken, + ActivationFunction.Type activationFunction, + Double ropeFreqsTheta, + Double ropeScalingFactor ) { this( - contextLength, - embeddingLength, - hiddenLength, - numberOfHeads, - numberOfKeyValueHeads, - numberOfLayers, - layerNormEps, - vocabularySize, - bosToken, - eosToken, - activationFunction, - ropeFreqsTheta, - ropeScalingFactor, - null, - embeddingLength / numberOfHeads + contextLength, + embeddingLength, + hiddenLength, + numberOfHeads, + numberOfKeyValueHeads, + numberOfLayers, + layerNormEps, + vocabularySize, + bosToken, + eosToken, + activationFunction, + ropeFreqsTheta, + ropeScalingFactor, + null, + embeddingLength / numberOfHeads ); } @@ -163,8 +162,7 @@ public Config( VectorMath.precomputeFreqsCis(headSize, contextLength, ropeFreqsTheta, ropeScalingFactor == null ? 1.0 : ropeScalingFactor) ); - this.classifcationLabels = classifcationLabels == null ? Optional.empty() : - Optional.of(ImmutableBiMap.copyOf(classifcationLabels)); + this.classifcationLabels = classifcationLabels == null ? Optional.empty() : Optional.of(ImmutableBiMap.copyOf(classifcationLabels)); // Set default values this.dctx = DistributedContext.builder(this).build(); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/HTTPSafeTensorLoader.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/HTTPSafeTensorLoader.java new file mode 100644 index 0000000..8aae1f0 --- /dev/null +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/HTTPSafeTensorLoader.java @@ -0,0 +1,243 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.github.tjake.jlama.safetensors; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.github.tjake.jlama.model.DistributedContext; +import com.github.tjake.jlama.tensor.AbstractTensor; +import com.github.tjake.jlama.tensor.TensorShape; +import com.github.tjake.jlama.util.HttpSupport; +import com.github.tjake.jlama.util.JsonSupport; +import com.github.tjake.jlama.util.Pair; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Ints; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +public class HTTPSafeTensorLoader implements WeightLoader { + private static final Logger logger = LoggerFactory.getLogger(HTTPSafeTensorLoader.class); + + private final Path modelRoot; + private final String indexFile; + private final String modelName; + private final Optional branch; + private final Optional authToken; + private final SafeTensorIndex index; + private final Map> layerFiles; + private final Map dynamicTensorInfoMap; + private final Map tensorFileOffsets; + private final DType modelDType; + + /** + * Used for distributed inference + * + * Dynamically fetches weights from a remote server based on the distributed context + * + * @param modelRoot + * @param owner + * @param modelName + * @param branch + * @param authToken + * @throws JsonProcessingException + */ + public HTTPSafeTensorLoader( + Path modelRoot, + String owner, + String modelName, + DType modelDType, + Optional branch, + Optional authToken + ) { + this.modelRoot = modelRoot; + this.modelName = owner + "/" + modelName; + this.branch = branch; + this.indexFile = String.format("%s/%s", modelRoot, SafeTensorIndex.MODEL_INDEX_JSON); + this.authToken = authToken; + + // Check we have the index file + if (!new File(indexFile).exists()) { + this.index = new SafeTensorIndex(Collections.emptyMap(), Map.of("model-file", SafeTensorIndex.SINGLE_MODEL_NAME)); + } else { + try { + this.index = JsonSupport.om.readValue(new File(indexFile), SafeTensorIndex.class); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + this.layerFiles = new HashMap<>(); + this.dynamicTensorInfoMap = new HashMap<>(); + this.tensorFileOffsets = new HashMap<>(); + this.modelDType = modelDType; + } + + @Override + public Map metadata() { + return index.metadata(); + } + + @Override + public Map tensorInfoMap() { + return dynamicTensorInfoMap; + } + + @Override + public AbstractTensor load(String name, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) { + Preconditions.checkArgument(!sparseColumns || !sparseRows, "Cannot have both sparse rows and columns"); + Preconditions.checkArgument(index.weightFileMap.containsKey(name) || index.weightFileMap.size() == 1, "Unknown weight: " + name); + + // Check if we already have the layer loaded + if (layerFiles.containsKey(name)) { + return layerFiles.get(name).right(); + } + + try { + TensorInfo info = maybeLoadTensorInfo(name); + + Pair> offsets = Weights.getLoadOffsets(info, dctx, sparseRows); + + Integer headerOffset = tensorFileOffsets.get(name); + + assert headerOffset != null && headerOffset > 0 : "Failed to find header offset for: " + name; + + TensorShape shape = offsets.left; + long positionOffset = offsets.right.left + headerOffset; + long positionLimit = offsets.right.right + headerOffset; + + String weightFile = index.weightFileMap.getOrDefault(name, SafeTensorIndex.SINGLE_MODEL_NAME); + + Path weightPath = modelRoot.resolve(weightFile + ".part." + positionOffset + "_" + positionLimit); + + if (!weightPath.toFile().exists()) { + logger.info("Downloading file: {} for {} {}MB", weightPath, name, (positionLimit - positionOffset) / 1024 / 1024); + HttpSupport.downloadFile( + modelName, + weightFile, + branch, + authToken, + Optional.of(Pair.of(positionOffset, positionLimit)), + weightPath, + Optional.empty() + ); + } + + int length = Ints.checkedCast(positionLimit - positionOffset); + + RandomAccessFile raf = new RandomAccessFile(weightPath.toFile(), "r"); + ByteBuffer buf = raf.getChannel() + .map(FileChannel.MapMode.READ_ONLY, 0, raf.length()) + .duplicate() + .order(ByteOrder.LITTLE_ENDIAN) + .position(0) + .limit(length); + + if (raf.length() < length) { + throw new RuntimeException( + "Failed to download the correct number of bytes: " + raf.length() + " != " + length + " for " + weightPath + ); + } + + logger.debug("Loading tensor: {} from {} with offsets: {} {}", name, weightPath, positionOffset, positionLimit); + + AbstractTensor tensor = Weights.loadTensorFromBuffer( + name, + info.dType, + modelDType, + shape, + buf, + sparseRows, + sparseColumns, + dctx, + this + ); + + layerFiles.put(name, Pair.of(raf, tensor)); + + return tensor; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private TensorInfo maybeLoadTensorInfo(String name) throws IOException { + if (dynamicTensorInfoMap.containsKey(name)) { + return dynamicTensorInfoMap.get(name); + } + + String weightFile = index.weightFileMap.getOrDefault(name, SafeTensorIndex.SINGLE_MODEL_NAME); + + Path headerFile = modelRoot.resolve(weightFile + ".header"); + + if (!Files.exists(headerFile)) { + // Download the first 1MB of the file to get the tensor info + HttpSupport.downloadFile( + modelName, + weightFile, + branch, + authToken, + Optional.of(Pair.of(0L, (long) 1 << 20)), + headerFile, + Optional.empty() + ); + } + + try (RandomAccessFile raf = new RandomAccessFile(headerFile.toFile(), "r")) { + ByteBuffer header = raf.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, Math.min(1 << 20, raf.length())); + Map info = SafeTensorSupport.readTensorInfoMap(header, Optional.empty()); + int endOfHeaderPosition = header.position(); + for (Map.Entry e : info.entrySet()) { + dynamicTensorInfoMap.put(e.getKey(), e.getValue()); + tensorFileOffsets.put(e.getKey(), endOfHeaderPosition); + } + } + + assert dynamicTensorInfoMap.containsKey(name) : "Failed to load tensor info for: " + name; + return dynamicTensorInfoMap.get(name); + } + + @Override + public DType getModelDType() { + return modelDType; + } + + @Override + public void close() { + for (Pair pair : layerFiles.values()) { + try { + pair.left().close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + layerFiles.clear(); + dynamicTensorInfoMap.clear(); + } + +} diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java index ef81e06..876593d 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorIndex.java @@ -42,7 +42,7 @@ public class SafeTensorIndex implements WeightLoader, AutoCloseable { private final Map metadata; // Map from weight name to file name (this is what's in the JSON file) - private final Map weightFileMap; + final Map weightFileMap; // Map from weight name to Weights data private final Map weightMap = new HashMap<>(); @@ -50,20 +50,28 @@ public class SafeTensorIndex implements WeightLoader, AutoCloseable { // Map from file name to RandomAccessFile private final Map fileMap = new HashMap<>(); - public static SafeTensorIndex loadWithWeights(Path modelRoot) throws IOException { - File indexFile = Paths.get(modelRoot.toString(), MODEL_INDEX_JSON).toFile(); + public static SafeTensorIndex loadWithWeights(Path modelRoot) { + try { + File indexFile = Paths.get(modelRoot.toString(), MODEL_INDEX_JSON).toFile(); - SafeTensorIndex index = om.readValue(indexFile, SafeTensorIndex.class); - loadWeights(index, modelRoot); + SafeTensorIndex index = om.readValue(indexFile, SafeTensorIndex.class); + loadWeights(index, modelRoot); - return index; + return index; + } catch (IOException e) { + throw new RuntimeException(e); + } } - public static SafeTensorIndex loadSingleFile(Path modelRoot, String modelFile) throws IOException { - SafeTensorIndex index = new SafeTensorIndex(Collections.emptyMap(), Map.of("model-file", modelFile)); - loadWeights(index, modelRoot); + public static SafeTensorIndex loadSingleFile(Path modelRoot, String modelFile) { + try { + SafeTensorIndex index = new SafeTensorIndex(Collections.emptyMap(), Map.of("model-file", modelFile)); + loadWeights(index, modelRoot); - return index; + return index; + } catch (IOException e) { + throw new RuntimeException(e); + } } static void loadWeights(SafeTensorIndex index, Path modelRoot) throws IOException { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java index 34acf3b..759fe94 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java @@ -91,7 +91,7 @@ public static ModelType detectModel(File configFile) throws IOException { return ModelType.valueOf(rootNode.get("model_type").textValue().toUpperCase()); } - public static WeightLoader loadWeights(File baseDir) throws IOException { + public static WeightLoader loadWeights(File baseDir) { if (Files.exists(Paths.get(baseDir.getAbsolutePath(), SafeTensorIndex.MODEL_INDEX_JSON))) return SafeTensorIndex.loadWithWeights( baseDir.toPath() ); @@ -104,6 +104,29 @@ public static WeightLoader loadWeights(File baseDir) throws IOException { throw new IllegalArgumentException("No safetensor model found in: " + baseDir); } + public static boolean isModelLocal(Path modelRoot) { + + if (Files.exists(modelRoot.resolve(SafeTensorIndex.SINGLE_MODEL_NAME))) return true; + try { + if (Files.exists(modelRoot.resolve(SafeTensorIndex.MODEL_INDEX_JSON))) { + SafeTensorIndex index = om.readValue(modelRoot.resolve(SafeTensorIndex.MODEL_INDEX_JSON).toFile(), SafeTensorIndex.class); + + for (String file : index.weightFileMap.values()) { + if (!Files.exists(modelRoot.resolve(file))) { + return false; + } + } + + return true; + } + } catch (IOException e) { + logger.error("Error reading model index", e); + return false; + } + + return false; + } + public static TokenizerModel loadTokenizer(Path modelRoot) throws IOException { File tokenizerJson = modelRoot.resolve("tokenizer.json").toFile(); Preconditions.checkArgument(tokenizerJson.exists(), "No tokenizer.json found in " + modelRoot); @@ -302,15 +325,22 @@ public static File maybeDownloadModel(String modelDir, String fullModelName) thr name = parts[1]; } - return maybeDownloadModel(modelDir, Optional.ofNullable(owner), name, Optional.empty(), Optional.empty(), Optional.empty()); + return maybeDownloadModel(modelDir, Optional.ofNullable(owner), name, true, Optional.empty(), Optional.empty(), Optional.empty()); } + public static Path constructLocalModelPath(String modelDir, String owner, String modelName) { + return Paths.get(modelDir, owner + "_" + modelName); + } + + static String FINISHED_MARKER = ".finished"; + /** * Download a model from HuggingFace and return the path to the model directory * * @param modelDir The directory to save the model to * @param modelOwner The owner of the HF model (if any) * @param modelName The name of the HF model + * @param downloadWeights Include the weights or leave them out * @param optionalBranch The branch of the model to download * @param optionalAuthHeader The authorization header to use for the request * @param optionalProgressReporter A consumer to report download progress @@ -321,14 +351,23 @@ public static File maybeDownloadModel( String modelDir, Optional modelOwner, String modelName, + boolean downloadWeights, Optional optionalBranch, Optional optionalAuthHeader, Optional> optionalProgressReporter ) throws IOException { + + Path localModelDir = constructLocalModelPath(modelDir, modelOwner.orElse("na"), modelName); + // Check if the model is already downloaded + if (Files.exists(localModelDir.resolve(FINISHED_MARKER))) { + return localModelDir.toFile(); + } + String hfModel = modelOwner.map(mo -> mo + "/" + modelName).orElse(modelName); InputStream modelInfoStream = HttpSupport.getResponse( "https://huggingface.co/api/models/" + hfModel + "/tree/" + optionalBranch.orElse("main"), - optionalAuthHeader + optionalAuthHeader, + Optional.empty() ).left; String modelInfo = HttpSupport.readInputStream(modelInfoStream); @@ -349,10 +388,16 @@ public static File maybeDownloadModel( || f.contains("readme") || f.equals("config.json") || f.contains("tokenizer")) { - tensorFiles.add(currFile); + if (f.contains("safetensor")) { hasSafetensor = true; } + + if (!downloadWeights && f.contains("safetensor")) { + continue; + } + + tensorFiles.add(currFile); } } @@ -360,7 +405,6 @@ public static File maybeDownloadModel( throw new IOException("Model is not available in safetensor format"); } - Path localModelDir = Paths.get(modelDir, modelName); Files.createDirectories(localModelDir); for (String currFile : tensorFiles) { @@ -369,11 +413,15 @@ public static File maybeDownloadModel( currFile, optionalBranch, optionalAuthHeader, + Optional.empty(), localModelDir.resolve(currFile), optionalProgressReporter ); } + // When fully downloaded, create a .finished file + Files.createFile(localModelDir.resolve(FINISHED_MARKER)); + return localModelDir.toFile(); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/WeightLoader.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/WeightLoader.java index 69c452d..f68ab6b 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/WeightLoader.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/WeightLoader.java @@ -17,6 +17,7 @@ import com.github.tjake.jlama.model.DistributedContext; import com.github.tjake.jlama.tensor.AbstractTensor; + import java.util.Map; public interface WeightLoader extends AutoCloseable { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java index 8973a9d..22ef8f5 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Weights.java @@ -21,11 +21,13 @@ import com.github.tjake.jlama.util.Pair; import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; + import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.nio.ShortBuffer; import java.util.*; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,11 +43,11 @@ public class Weights implements WeightLoader { this.metadata = ImmutableMap.copyOf(metadata); this.tensorInfoMap = ImmutableMap.copyOf(tensorInfoMap); this.bytes = bytes.duplicate(); - this.majorityDType = findDType(); + this.majorityDType = findDType(tensorInfoMap); this.parent = parent; } - private DType findDType() { + public static DType findDType(Map tensorInfoMap) { EnumMap counts = new EnumMap<>(DType.class); for (Map.Entry e : tensorInfoMap.entrySet()) { if (!e.getKey().endsWith(".qb")) counts.put(e.getValue().dType, counts.getOrDefault(e.getValue().dType, 0) + 1); @@ -75,8 +77,7 @@ public Map tensorInfoMap() { } @Override - public AbstractTensor load(String name, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) - throws NoSuchElementException { + public AbstractTensor load(String name, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) { TensorInfo info = tensorInfoMap.get(name); if (info == null) throw new NoSuchElementException(name + " not found in weights"); @@ -86,8 +87,19 @@ public AbstractTensor load(String name, DistributedContext dctx, boolean sparseR throw new RuntimeException("Invalid shape dimensions " + info.shape.length + " encountered for " + name + " with offset"); } - int positionOffset = Ints.checkedCast(info.dataOffsets[0]); - int positionLimit = Ints.checkedCast(info.dataOffsets[1]); + Pair> offsets = getLoadOffsets(info, dctx, sparseRows); + + ByteBuffer b = bytes.duplicate() + .order(ByteOrder.LITTLE_ENDIAN) + .position(Ints.checkedCast(offsets.right.left)) + .limit(Ints.checkedCast(offsets.right.right)); + + return loadTensorFromBuffer(name, info.dType, majorityDType, offsets.left, b, sparseRows, sparseColumns, dctx, parent.orElse(this)); + } + + static Pair> getLoadOffsets(TensorInfo info, DistributedContext dctx, boolean sparseRows) { + long positionOffset = info.dataOffsets[0]; + long positionLimit = info.dataOffsets[1]; TensorShape shape = TensorShape.of(info.shape); // If this is a sparse tensor, we need to fetch only the section of the tensor that is needed @@ -98,21 +110,29 @@ public AbstractTensor load(String name, DistributedContext dctx, boolean sparseR // Hack for Q4 if (info.dType == DType.Q4) columnLength /= 2; - positionOffset = Ints.checkedCast(info.dataOffsets[0]) + (dctx.getShardOffsetForLength(rows) * columnLength); + positionOffset = info.dataOffsets[0] + (dctx.getShardOffsetForLength(rows) * columnLength); positionLimit = positionOffset + (dctx.getShardLength(rows) * columnLength); shape = TensorShape.sparseRow(info.shape, Pair.of(dctx.getShardOffsetForLength(rows), dctx.getShardLength(rows))); } + return Pair.of(shape, Pair.of(positionOffset, positionLimit)); + } - ByteBuffer b = bytes.duplicate() - .order(ByteOrder.LITTLE_ENDIAN) - .position(Ints.checkedCast(positionOffset)) - .limit(Ints.checkedCast(positionLimit)); - + static AbstractTensor loadTensorFromBuffer( + String name, + DType dType, + DType majorityDType, + TensorShape shape, + ByteBuffer b, + boolean sparseRows, + boolean sparseColumns, + DistributedContext dctx, + WeightLoader loader + ) { int len; FloatBuffer fb; ShortBuffer sb; AbstractTensor t; - switch (info.dType) { + switch (dType) { case F32: fb = b.asFloatBuffer().slice(); t = new FloatBufferTensor(name, fb, shape, true); @@ -149,18 +169,20 @@ public AbstractTensor load(String name, DistributedContext dctx, boolean sparseR } break; case Q4: - FloatBufferTensor qb = (FloatBufferTensor) parent.orElse(this).load(name + ".qb", dctx, sparseRows, false); // only need to - // sparsify once + FloatBufferTensor qb = (FloatBufferTensor) loader.load(name + ".qb", dctx, sparseRows, false /*only need sparsify once*/); t = new Q4ByteBufferTensor(name, b.slice(), qb, shape, true); break; case I8: - FloatBufferTensor qb1 = (FloatBufferTensor) parent.orElse(this).load(name + ".qb", dctx, sparseRows, false); // only need to - // sparsify - // once + FloatBufferTensor qb1 = (FloatBufferTensor) loader.load( + name + ".qb", + dctx, + sparseRows, + false /*only need to sparsify once*/ + ); t = new Q8ByteBufferTensor(name, b.slice(), qb1, shape, true); break; default: - throw new IllegalArgumentException("Unsupported Tensor type: " + info.dType.name() + " for " + name); + throw new IllegalArgumentException("Unsupported Tensor type: " + dType.name() + " for " + name); } return dctx != null && sparseColumns && dctx.hasModelShard() diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java index b0b9127..61b2cce 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/AbstractTensor.java @@ -25,8 +25,6 @@ import java.nio.ByteOrder; import java.nio.channels.FileChannel; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import jdk.incubator.vector.Vector; import jdk.incubator.vector.VectorMask; import jdk.incubator.vector.VectorSpecies; @@ -48,7 +46,6 @@ public abstract class AbstractTensor, T extends Number> impl protected final TensorShape shape; protected final DType dType; protected final AbstractTensor[] sliceCache; - protected final Map metadata; private final int stride; private volatile TensorCache originCache = null; @@ -56,7 +53,6 @@ protected AbstractTensor(DType dType, TensorShape shape, boolean cacheSlices) { Preconditions.checkArgument(shape != null && shape.dims() > 0); this.dType = dType; this.shape = shape; - this.metadata = new HashMap<>(); this.sliceCache = cacheSlices ? new AbstractTensor[shape.first()] : null; this.stride = shape.first() > 1 && dims() == 2 ? getOffset(shape.sparseRowOffset() + 1, shape.sparseColumnOffset()) : 0; } @@ -70,14 +66,6 @@ public static AbstractTensor make(DType dType, TensorShape shape) { }; } - public void setMetadata(String key, Object value) { - metadata.put(key, value); - } - - public Object getMetadata(String key) { - return metadata.get(key); - } - /** Create a new tensor with the given shape of the same Tensor implementation */ protected abstract AbstractTensor make(TensorShape shape); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/HttpSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/util/HttpSupport.java index 2985edb..8cad012 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/util/HttpSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/HttpSupport.java @@ -30,28 +30,44 @@ import java.nio.file.StandardCopyOption; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.zip.GZIPInputStream; public class HttpSupport { public static final Logger logger = LoggerFactory.getLogger(HttpSupport.class); - public static Pair getResponse(String urlString, Optional optionalAuthHeader) throws IOException { + public static Pair getResponse( + String urlString, + Optional optionalAuthHeader, + Optional> optionalByteRange + ) throws IOException { URL url = new URL(urlString); HttpURLConnection connection = (HttpURLConnection) url.openConnection(); // Set the request method connection.setRequestMethod("GET"); + connection.setRequestProperty("Accept-Encoding", "gzip"); // Set the request header optionalAuthHeader.ifPresent(authHeader -> connection.setRequestProperty("Authorization", "Bearer " + authHeader)); + optionalByteRange.ifPresent(byteRange -> connection.setRequestProperty("Range", "bytes=" + byteRange.left + "-" + byteRange.right)); // Get the response code int responseCode = connection.getResponseCode(); - if (responseCode == HttpURLConnection.HTTP_OK) { + if (responseCode == HttpURLConnection.HTTP_OK || responseCode == HttpURLConnection.HTTP_PARTIAL) { // If the response code is 200 (HTTP_OK), return the input stream - return Pair.of(connection.getInputStream(), connection.getContentLengthLong()); + + String encoding = connection.getContentEncoding(); + InputStream inputStream; + if (encoding != null && encoding.equals("gzip")) { + inputStream = new GZIPInputStream(connection.getInputStream()); + } else { + inputStream = connection.getInputStream(); + } + + return Pair.of(inputStream, connection.getContentLengthLong()); } else { - // If the response code is not 200, throw an IOException + // If the response code is not 200/206, throw an IOException throw new IOException("HTTP response code: " + responseCode + " for URL: " + urlString); } } @@ -76,13 +92,15 @@ public static void downloadFile( String currFile, Optional optionalBranch, Optional optionalAuthHeader, + Optional> optionalByteRange, Path outputPath, Optional> optionalProgressConsumer ) throws IOException { Pair stream = getResponse( "https://huggingface.co/" + hfModel + "/resolve/" + optionalBranch.orElse("main") + "/" + currFile, - optionalAuthHeader + optionalAuthHeader, + optionalByteRange ); CountingInputStream inStream = new CountingInputStream(stream.left); diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java index bcd5e3c..90bc5d8 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java @@ -21,10 +21,14 @@ import com.github.tjake.jlama.model.functions.Generator; import com.github.tjake.jlama.net.grpc.JlamaService; import com.github.tjake.jlama.safetensors.DType; +import com.github.tjake.jlama.safetensors.HTTPSafeTensorLoader; +import com.github.tjake.jlama.safetensors.SafeTensorSupport; +import com.github.tjake.jlama.safetensors.WeightLoader; import com.github.tjake.jlama.safetensors.prompt.PromptContext; import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import com.github.tjake.jlama.tensor.AbstractTensor; +import com.google.common.base.Function; import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; import io.grpc.Server; @@ -50,8 +54,23 @@ public class Coordinator implements Generator { private final AbstractModel model; private final JlamaService service; - public Coordinator(File modelPath, File workingDirectory, int port, int workerCount) { + public Coordinator( + File modelPath, + String modelOwner, + String modelName, + DType modelDType, + File workingDirectory, + int port, + int workerCount, + Optional authToken, + Optional branch + ) { Preconditions.checkArgument(workerCount != 0 && ((workerCount & (workerCount - 1)) == 0), "worker count must be a power of 2"); + + Function weightLoaderFunction = SafeTensorSupport.isModelLocal(modelPath.toPath()) + ? b -> SafeTensorSupport.loadWeights(modelPath) + : b -> new HTTPSafeTensorLoader(modelPath.toPath(), modelOwner, modelName, modelDType, authToken, branch); + this.model = loadModel( AbstractModel.InferenceType.OUTPUT_TO_TOKEN, modelPath, @@ -60,7 +79,8 @@ public Coordinator(File modelPath, File workingDirectory, int port, int workerCo DType.I8, Optional.empty(), Optional.empty(), - Optional.empty() + Optional.empty(), + weightLoaderFunction ); this.port = port; this.workerCount = workerCount; @@ -101,7 +121,7 @@ public void stop() throws InterruptedException { } public float[] embed(String input, Generator.PoolingType poolingType) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException(); } public Generator.Response generate( diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java index 2e69fd8..7ea141e 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java @@ -20,8 +20,12 @@ import com.github.tjake.jlama.model.AbstractModel; import com.github.tjake.jlama.model.DistributedContext; import com.github.tjake.jlama.safetensors.DType; +import com.github.tjake.jlama.safetensors.HTTPSafeTensorLoader; +import com.github.tjake.jlama.safetensors.SafeTensorSupport; +import com.github.tjake.jlama.safetensors.WeightLoader; import com.github.tjake.jlama.tensor.AbstractTensor; import com.github.tjake.jlama.tensor.KvBufferCache; +import com.google.common.base.Function; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.Uninterruptibles; import com.google.protobuf.ByteString; @@ -54,14 +58,19 @@ public class Worker implements Closeable { private final RegisterResponse registerResponse; public Worker( - File modelPrefix, + File modelPath, + String modelOwner, + String modelName, + DType modelDType, String host, int port, File workingDirectory, DType workingMemoryType, DType workingQuantizationType, Optional modelQuantization, - Optional optionalWorkerId + Optional optionalWorkerId, + Optional authToken, + Optional branch ) { Channel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); @@ -79,9 +88,13 @@ public Worker( registerResponse.getNumModelShards() ); + Function weightLoaderFunction = SafeTensorSupport.isModelLocal(modelPath.toPath()) + ? b -> SafeTensorSupport.loadWeights(modelPath) + : b -> new HTTPSafeTensorLoader(modelPath.toPath(), modelOwner, modelName, modelDType, authToken, branch); + this.model = loadModel( AbstractModel.InferenceType.FORWARD_PASS, - modelPrefix, + modelPath, workingDirectory, workingMemoryType, workingQuantizationType, @@ -94,7 +107,8 @@ public Worker( .setLayerShard(registerResponse.getLayerShard()) .setNumLayerShards(registerResponse.getNumLayerShards()) .build() - ) + ), + weightLoaderFunction ); } diff --git a/jlama-net/src/test/java/com/github/tjake/jlama/net/DistributedServiceTest.java b/jlama-net/src/test/java/com/github/tjake/jlama/net/DistributedServiceTest.java index 37d3ae0..a6a0fce 100644 --- a/jlama-net/src/test/java/com/github/tjake/jlama/net/DistributedServiceTest.java +++ b/jlama-net/src/test/java/com/github/tjake/jlama/net/DistributedServiceTest.java @@ -40,43 +40,25 @@ public class DistributedServiceTest { rootLogger.setLevel(Level.toLevel("info")); } - @Test - void oneWorkerTestLLama() throws Exception { - Path modelPath = Paths.get("../models/Llama-2-7b-chat-hf-jlama-Q4"); - Assume.assumeTrue(Files.exists(modelPath)); - - Coordinator coordinator = new Coordinator(modelPath.toFile(), com.google.common.io.Files.createTempDir(), 8888, 1); - try { - new Thread(() -> { - try { - coordinator.start(); - } catch (Exception e) { - e.printStackTrace(); - } - }).start(); - - startWorker(modelPath); - - coordinator.generate( - UUID.randomUUID(), - PromptContext.of("Simply put, the theory of relativity states that"), - 0.7f, - 256, - makeOutHandler() - ); - - } finally { - coordinator.stop(); - } - } - @Test void manyWorkerTestLLama() throws Exception { - // Path modelRoot = Paths.get("../models/Mixtral-8x7B-Instruct-v0.1-jlama-Q4"); Path modelRoot = Paths.get("../models/Meta-Llama-3.1-8B-Instruct-jlama-Q4"); + String modelName = "Meta-Llama-3.1-8B-Instruct-jlama-Q4"; + String modelOwner = "tjake"; + Assume.assumeTrue(Files.exists(modelRoot)); - Coordinator coordinator = new Coordinator(modelRoot.toFile(), null, 8888, 4); + Coordinator coordinator = new Coordinator( + modelRoot.toFile(), + modelOwner, + modelName, + DType.Q4, + null, + 8888, + 4, + Optional.empty(), + Optional.empty() + ); try { new Thread(() -> { try { @@ -86,10 +68,10 @@ void manyWorkerTestLLama() throws Exception { } }).start(); - startWorker(modelRoot); - startWorker(modelRoot); - startWorker(modelRoot); - startWorker(modelRoot); + startWorker(modelRoot, modelOwner, modelName); + startWorker(modelRoot, modelOwner, modelName); + startWorker(modelRoot, modelOwner, modelName); + startWorker(modelRoot, modelOwner, modelName); coordinator.generate( UUID.randomUUID(), @@ -103,8 +85,22 @@ void manyWorkerTestLLama() throws Exception { } } - private void startWorker(Path modelRoot) throws Exception { - Worker worker = new Worker(modelRoot.toFile(), "localhost", 8888, null, DType.F32, DType.I8, Optional.empty(), Optional.empty()); + private void startWorker(Path modelRoot, String modelOwner, String modelName) throws Exception { + Worker worker = new Worker( + modelRoot.toFile(), + modelOwner, + modelName, + DType.Q4, + "localhost", + 8888, + null, + DType.F32, + DType.I8, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty() + ); new Thread(() -> { try { worker.run(); diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java index 62667d7..ddccce2 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java @@ -18,13 +18,7 @@ import static com.github.tjake.jlama.util.JsonSupport.om; import com.github.tjake.jlama.math.VectorMath; -import com.github.tjake.jlama.model.bert.BertConfig; -import com.github.tjake.jlama.model.bert.BertModel; -import com.github.tjake.jlama.model.bert.BertTokenizer; import com.github.tjake.jlama.model.functions.Generator; -import com.github.tjake.jlama.model.gpt2.GPT2Config; -import com.github.tjake.jlama.model.gpt2.GPT2Model; -import com.github.tjake.jlama.model.gpt2.GPT2Tokenizer; import com.github.tjake.jlama.model.llama.LlamaConfig; import com.github.tjake.jlama.model.llama.LlamaModel; import com.github.tjake.jlama.model.llama.LlamaTokenizer; @@ -35,10 +29,7 @@ import com.github.tjake.jlama.safetensors.*; import com.github.tjake.jlama.safetensors.prompt.*; import com.github.tjake.jlama.safetensors.tokenizer.BPETokenizer; -import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import java.io.*; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -67,23 +58,15 @@ public class TestModels { public void GPT2Run() throws IOException { String modelPrefix = "../models/gpt2-medium"; Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); - try (RandomAccessFile sc = new RandomAccessFile(modelPrefix + "/model.safetensors", "r")) { - ByteBuffer bb = sc.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, sc.length()); - - Weights v = SafeTensorSupport.readWeights(bb); - Tokenizer tokenizer = new GPT2Tokenizer(Paths.get(modelPrefix)); - Config c = om.readValue(new File(modelPrefix + "/config.json"), GPT2Config.class); - GPT2Model gpt2 = new GPT2Model(c, v, tokenizer, DType.F32, DType.F32, Optional.of(DType.F32)); - - PromptContext prompt = PromptContext.of( - "In a shocking finding, scientist discovered a herd of unicorns living in a remote, " - + "previously unexplored valley, in the Andes Mountains. " - + "Even more surprising to the researchers was the fact that the unicorns spoke perfect English." - ); - gpt2.generate(UUID.randomUUID(), prompt, 0.8f, 256, makeOutHandler()); - gpt2.generate(UUID.randomUUID(), prompt, 0.8f, 256, makeOutHandler()); - } + AbstractModel gpt2 = ModelSupport.loadModel(new File(modelPrefix), DType.F32, DType.F32); + PromptContext prompt = PromptContext.of( + "In a shocking finding, scientist discovered a herd of unicorns living in a remote, " + + "previously unexplored valley, in the Andes Mountains. " + + "Even more surprising to the researchers was the fact that the unicorns spoke perfect English." + ); + + gpt2.generate(UUID.randomUUID(), prompt, 0.8f, 256, makeOutHandler()); } @Test @@ -288,40 +271,32 @@ public void BertRun() throws Exception { String modelPrefix = "../models/e5-small-v2"; Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); - try (RandomAccessFile sc = new RandomAccessFile(modelPrefix + "/model.safetensors", "r")) { - ByteBuffer bb = sc.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, sc.length()); - - Weights weights = SafeTensorSupport.readWeights(bb); - Tokenizer tokenizer = new BertTokenizer(Paths.get(modelPrefix)); - Config c = om.readValue(new File(modelPrefix + "/config.json"), BertConfig.class); - BertModel model = new BertModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.of(DType.F32)); - - String base = "A man is eating food."; - String[] examples = new String[] { "A man is eating a piece of bread.", "The girl is carrying a baby.", - "A man is riding a horse.", "A woman is playing violin.", "Two men pushed carts through the woods.", - "A man is riding a white horse on an enclosed ground.", "A monkey is playing drums.", - "Someone in a gorilla costume is playing a set of drums." }; - - float[] be = model.embed(base, Generator.PoolingType.AVG); - logger.info("base is {}", base); - float maxc = 0.0f; - String bestMatch = ""; - for (int i = 0; i < examples.length; i++) { - float vs = VectorMath.cosineSimilarity(be, model.embed(examples[i], Generator.PoolingType.AVG)); - logger.info("vs {} => {}", examples[i], vs); - if (vs > maxc) { - maxc = vs; - bestMatch = examples[i]; - } + AbstractModel model = ModelSupport.loadModel(new File(modelPrefix), DType.F32, DType.F32); + + String base = "A man is eating food."; + String[] examples = new String[] { "A man is eating a piece of bread.", "The girl is carrying a baby.", "A man is riding a horse.", + "A woman is playing violin.", "Two men pushed carts through the woods.", "A man is riding a white horse on an enclosed ground.", + "A monkey is playing drums.", "Someone in a gorilla costume is playing a set of drums." }; + + float[] be = model.embed(base, Generator.PoolingType.AVG); + logger.info("base is {}", base); + float maxc = 0.0f; + String bestMatch = ""; + for (int i = 0; i < examples.length; i++) { + float vs = VectorMath.cosineSimilarity(be, model.embed(examples[i], Generator.PoolingType.AVG)); + logger.info("vs {} => {}", examples[i], vs); + if (vs > maxc) { + maxc = vs; + bestMatch = examples[i]; } + } - logger.info("Best match for: '{}' is '{}'", base, bestMatch); + logger.info("Best match for: '{}' is '{}'", base, bestMatch); - long start = System.currentTimeMillis(); - VectorMath.pfor(0, 1000, i -> model.embed(base, Generator.PoolingType.AVG)); - long elapsed = System.currentTimeMillis() - start; - logger.info("took {} seconds, {}ms per emb", elapsed / 1000f, elapsed / 1000f); - } + long start = System.currentTimeMillis(); + VectorMath.pfor(0, 1000, i -> model.embed(base, Generator.PoolingType.AVG)); + long elapsed = System.currentTimeMillis() - start; + logger.info("took {} seconds, {}ms per emb", elapsed / 1000f, elapsed / 1000f); } private BiConsumer makeOutHandler() { diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java index c8e4bf6..edb3d25 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java @@ -72,16 +72,9 @@ public void sampleEmbed() throws IOException { AbstractModel m = ModelSupport.loadEmbeddingModel(localModelPath, DType.F32, DType.I8); String base = "A man is eating food."; - String[] examples = new String[] { - "A man is eating a piece of bread.", - "The girl is carrying a baby.", - "A man is riding a horse.", - "A woman is playing violin.", - "Two men pushed carts through the woods.", - "A man is riding a white horse on an enclosed ground.", - "A monkey is playing drums.", - "Someone in a gorilla costume is playing a set of drums." - }; + String[] examples = new String[] { "A man is eating a piece of bread.", "The girl is carrying a baby.", "A man is riding a horse.", + "A woman is playing violin.", "Two men pushed carts through the woods.", "A man is riding a white horse on an enclosed ground.", + "A monkey is playing drums.", "Someone in a gorilla costume is playing a set of drums." }; float[] be = m.embed(base, Generator.PoolingType.AVG); float maxc = 0.0f; diff --git a/jlama.java b/jlama.java deleted file mode 100755 index 3963134..0000000 --- a/jlama.java +++ /dev/null @@ -1,21 +0,0 @@ -///usr/bin/env jbang "$0" "$@" ; exit $? -//COMPILE_OPTIONS -source 20 -//RUNTIME_OPTIONS -server -Dstdout.encoding=UTF-8 -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 -//RUNTIME_OPTIONS --add-modules=jdk.incubator.vector --add-exports java.base/sun.nio.ch=ALL-UNNAMED -//RUNTIME_OPTIONS --enable-preview --enable-native-access=ALL-UNNAMED -XX:+UnlockDiagnosticVMOptions -//RUNTIME_OPTIONS -XX:+AlignVector -XX:+UseStringDeduplication -XX:+UseCompressedOops -XX:+UseCompressedClassPointers - -//DEPS com.github.tjake:jlama-cli:0.4.0 -//DEPS com.github.tjake:jlama-native:0.4.0:${os.detected.name}-${os.detected.arch} - -import static java.lang.System.*; -import com.github.tjake.jlama.cli.JlamaCli; - -/** - * JBANG! script for running jlama-cli - */ -public class jlama { - public static void main(String... args) { - JlamaCli.main(args); - } -} diff --git a/pom.xml b/pom.xml index d451f80..451d217 100644 --- a/pom.xml +++ b/pom.xml @@ -42,7 +42,7 @@ UTF-8 - 0.4.1 + 0.5.0 2.0.7 1.5.6 diff --git a/run-cli.sh b/run-cli.sh index ad88940..f1c160f 100755 --- a/run-cli.sh +++ b/run-cli.sh @@ -18,8 +18,8 @@ get_java_major_version() { # Verify Java version is JDK 20/21/22 JAVA=$(get_java_exec) JAVA_MAJOR_VERSION=$(get_java_major_version $JAVA) -if [[ "$JAVA_MAJOR_VERSION" != "20" ]] && [[ "$JAVA_MAJOR_VERSION" != "21" ]] && [[ "$JAVA_MAJOR_VERSION" != "22" ]]; then - echo "Error: JDK 20/21/22 is required to run this application." +if [[ "$JAVA_MAJOR_VERSION" != "20" ]] && [[ "$JAVA_MAJOR_VERSION" != "21" ]] && [[ "$JAVA_MAJOR_VERSION" != "22" ]] && [[ "$JAVA_MAJOR_VERSION" != "23" ]]; then + echo "Error: JDK 20/21/22/23 is required to run this application." exit 1 fi @@ -28,9 +28,9 @@ JLAMA_RELATIVE_JAR="./jlama-cli/target/jlama-cli.jar" # Path to the logback.xml LOGBACK_CONFIG="./conf/logback.xml" -JLAMA_JVM_ARGS="$JLAMA_JVM_ARGS -server -Dstdout.encoding=UTF-8 -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 --add-modules=jdk.incubator.vector --add-exports java.base/sun.nio.ch=ALL-UNNAMED --enable-preview --enable-native-access=ALL-UNNAMED \ +JLAMA_JVM_ARGS="$JLAMA_JVM_ARGS -server -Dstdout.encoding=UTF-8 -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 --add-opens=jdk.incubator.vector/jdk.incubator.vector=ALL-UNNAMED --add-modules=jdk.incubator.vector --add-exports java.base/sun.nio.ch=ALL-UNNAMED --enable-preview --enable-native-access=ALL-UNNAMED \ -XX:+UnlockDiagnosticVMOptions -XX:CompilerDirectivesFile=./inlinerules.json -XX:+AlignVector -XX:+UseStringDeduplication \ - -XX:+UseCompressedOops -XX:+UseCompressedClassPointers " + -XX:+UseCompressedOops -XX:+UseCompressedClassPointers" # Check if PREINSTALLED_JAR environment variable is set if [[ -z "$JLAMA_PREINSTALLED_JAR" ]]; then @@ -44,7 +44,7 @@ if [[ -z "$JLAMA_PREINSTALLED_JAR" ]]; then fi fi # Run the JAR in a relative directory - $JAVA $JLAMA_JVM_ARGS $JLAMA_JVM_ARGS_EXTRA -Dlogback.configurationFile=$LOGBACK_CONFIG -jar $JLAMA_RELATIVE_JAR "$@" + $JAVA $JLAMA_JVM_ARGS $JLAMA_JVM_ARGS_EXTRA -jar $JLAMA_RELATIVE_JAR "$@" else # If PREINSTALLED_JAR is set, run the JAR specified by the variable $JAVA $JLAMA_JVM_ARGS $JLAMA_JVM_ARGS_EXTRA -jar $JLAMA_PREINSTALLED_JAR "$@"