From 4ee18f5fc955115494c96b09b7abd43875ca37dd Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Tue, 22 Oct 2024 15:37:28 +0300 Subject: [PATCH] Introduce llama3-java module --- .../llama3-java/deployment/pom.xml | 56 + .../deployment/ChatModelBuildConfig.java | 16 + .../LangChain4jJlamaBuildTimeConfig.java | 25 + .../llama3/deployment/Llama3Processor.java | 238 +++ model-providers/llama3-java/pom.xml | 24 + model-providers/llama3-java/runtime/pom.xml | 115 ++ .../langchain4j/llama3/Llama3ChatModel.java | 184 +++ .../llama3/Llama3ModelRegistry.java | 245 +++ .../langchain4j/llama3/MessageMapper.java | 22 + .../langchain4j/llama3/ProgressReporter.java | 6 + .../langchain4j/llama3/copy/ChatFormat.java | 88 + .../langchain4j/llama3/copy/Llama.java | 336 ++++ .../langchain4j/llama3/copy/Llama3.java | 1464 +++++++++++++++++ .../langchain4j/llama3/copy/ModelLoader.java | 159 ++ .../langchain4j/llama3/copy/Sampler.java | 8 + .../langchain4j/llama3/copy/Tokenizer.java | 268 +++ .../llama3/runtime/Llama3Recorder.java | 80 + .../runtime/config/ChatModelConfig.java | 21 + .../config/ChatModelFixedRuntimeConfig.java | 20 + .../LangChain4jLlama3FixedRuntimeConfig.java | 52 + .../LangChain4jLlama3RuntimeConfig.java | 50 + .../config/ModelsPathConfigSource.java | 77 + .../llama3/runtime/graal/Llama3Feature.java | 19 + .../src/main/resources/META-INF/beans.xml | 0 .../resources/META-INF/quarkus-extension.yaml | 12 + 25 files changed, 3585 insertions(+) create mode 100644 model-providers/llama3-java/deployment/pom.xml create mode 100644 model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/ChatModelBuildConfig.java create mode 100644 model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/LangChain4jJlamaBuildTimeConfig.java create mode 100644 model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/Llama3Processor.java create mode 100644 model-providers/llama3-java/pom.xml create mode 100644 model-providers/llama3-java/runtime/pom.xml create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ChatModel.java create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ModelRegistry.java create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/MessageMapper.java create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/ProgressReporter.java create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/ChatFormat.java create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama.java create mode 100755 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama3.java create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/ModelLoader.java create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Sampler.java create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Tokenizer.java create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/Llama3Recorder.java create mode 100644 
model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ChatModelConfig.java
 create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ChatModelFixedRuntimeConfig.java
 create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/LangChain4jLlama3FixedRuntimeConfig.java
 create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/LangChain4jLlama3RuntimeConfig.java
 create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ModelsPathConfigSource.java
 create mode 100644 model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/graal/Llama3Feature.java
 create mode 100644 model-providers/llama3-java/runtime/src/main/resources/META-INF/beans.xml
 create mode 100644 model-providers/llama3-java/runtime/src/main/resources/META-INF/quarkus-extension.yaml
diff --git a/model-providers/llama3-java/deployment/pom.xml b/model-providers/llama3-java/deployment/pom.xml
new file mode 100644
index 000000000..d7cea0ac4
--- /dev/null
+++ b/model-providers/llama3-java/deployment/pom.xml
@@ -0,0 +1,56 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <parent>
+        <groupId>io.quarkiverse.langchain4j</groupId>
+        <artifactId>quarkus-langchain4j-llama3-java-parent</artifactId>
+        <version>999-SNAPSHOT</version>
+    </parent>
+    <artifactId>quarkus-langchain4j-llama3-java-deployment</artifactId>
+    <name>Quarkus LangChain4j - Llama3 - Java - Deployment</name>
+    <dependencies>
+        <dependency>
+            <groupId>io.quarkiverse.langchain4j</groupId>
+            <artifactId>quarkus-langchain4j-llama3-java</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>io.quarkiverse.langchain4j</groupId>
+            <artifactId>quarkus-langchain4j-core-deployment</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>io.quarkus</groupId>
+            <artifactId>quarkus-junit5-internal</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.assertj</groupId>
+            <artifactId>assertj-core</artifactId>
+            <version>${assertj.version}</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>io.quarkiverse.langchain4j</groupId>
+            <artifactId>quarkus-langchain4j-testing-internal</artifactId>
+            <version>${project.version}</version>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+    <build>
+        <plugins>
+            <plugin>
+                <artifactId>maven-compiler-plugin</artifactId>
+                <configuration>
+                    <annotationProcessorPaths>
+                        <path>
+                            <groupId>io.quarkus</groupId>
+                            <artifactId>quarkus-extension-processor</artifactId>
+                            <version>${quarkus.version}</version>
+                        </path>
+                    </annotationProcessorPaths>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
+</project>
diff --git a/model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/ChatModelBuildConfig.java b/model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/ChatModelBuildConfig.java
new file mode 100644
index 000000000..2a22b6f2e
--- /dev/null
+++ b/model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/ChatModelBuildConfig.java
@@ -0,0 +1,16 @@
+package io.quarkiverse.langchain4j.llama3.deployment;
+
+import java.util.Optional;
+
+import io.quarkus.runtime.annotations.ConfigDocDefault;
+import io.quarkus.runtime.annotations.ConfigGroup;
+
+@ConfigGroup
+public interface ChatModelBuildConfig {
+
+    /**
+     * Whether the model should be enabled
+     */
+    @ConfigDocDefault("true")
+    Optional<Boolean> enabled();
+}
diff --git a/model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/LangChain4jJlamaBuildTimeConfig.java b/model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/LangChain4jJlamaBuildTimeConfig.java
new file mode 100644
index 000000000..685642470
--- /dev/null
+++ b/model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/LangChain4jJlamaBuildTimeConfig.java
@@ -0,0 +1,25 @@
+package io.quarkiverse.langchain4j.llama3.deployment;
+
+import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;
+
+import io.quarkus.runtime.annotations.ConfigRoot;
+import io.smallrye.config.ConfigMapping;
+import io.smallrye.config.WithDefault;
+
+@ConfigRoot(phase = BUILD_TIME)
+@ConfigMapping(prefix = "quarkus.langchain4j.llama3")
+public interface LangChain4jJlamaBuildTimeConfig {
+
+    /**
+     * Determines whether the necessary Llama3 models are downloaded and included in the jar at build time.
+     * Currently, this option is only valid for {@code fast-jar} deployments.
+     */
+    @WithDefault("true")
+    boolean includeModelsInArtifact();
+
+    /**
+     * Chat model related settings
+     */
+    ChatModelBuildConfig chatModel();
+
+}
diff --git a/model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/Llama3Processor.java b/model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/Llama3Processor.java
new file mode 100644
index 000000000..893d2694f
--- /dev/null
+++ b/model-providers/llama3-java/deployment/src/main/java/io/quarkiverse/langchain4j/llama3/deployment/Llama3Processor.java
@@ -0,0 +1,238 @@
+package io.quarkiverse.langchain4j.llama3.deployment;
+
+import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL;
+
+import java.nio.file.Path;
+import java.util.List;
+
+import jakarta.enterprise.context.ApplicationScoped;
+
+import org.jboss.jandex.AnnotationInstance;
+import org.jboss.logging.Logger;
+
+import io.quarkiverse.langchain4j.ModelName;
+import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
+import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
+import io.quarkiverse.langchain4j.llama3.runtime.Llama3Recorder;
+import io.quarkiverse.langchain4j.llama3.runtime.config.LangChain4jLlama3FixedRuntimeConfig;
+import io.quarkiverse.langchain4j.llama3.runtime.config.LangChain4jLlama3RuntimeConfig;
+import io.quarkiverse.langchain4j.llama3.runtime.graal.Llama3Feature;
+import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
+import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
+import io.quarkus.builder.item.MultiBuildItem;
+import io.quarkus.deployment.annotations.BuildProducer;
+import io.quarkus.deployment.annotations.BuildStep;
+import io.quarkus.deployment.annotations.ExecutionTime;
+import io.quarkus.deployment.annotations.Record;
+import io.quarkus.deployment.builditem.FeatureBuildItem;
+import io.quarkus.deployment.builditem.NativeImageFeatureBuildItem;
+import io.quarkus.deployment.builditem.nativeimage.NativeImageEnableModule;
+import io.quarkus.deployment.builditem.nativeimage.RuntimeInitializedPackageBuildItem;
+
+public class Llama3Processor {
+
+    private static final Logger LOGGER = Logger.getLogger(Llama3Processor.class);
+
+    private static final String FEATURE = "langchain4j-llama3-java";
+    private static final String PROVIDER = "llama3-java";
+
+    @BuildStep
+    FeatureBuildItem feature() {
+        return new FeatureBuildItem(FEATURE);
+    }
+
+    @BuildStep
+    public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
+            LangChain4jJlamaBuildTimeConfig config) {
+        if (config.chatModel().enabled().orElse(true)) {
+            chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
+        }
+    }
+
+    @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
+    @BuildStep
+    @Record(ExecutionTime.RUNTIME_INIT)
+    void generateBeans(Llama3Recorder recorder,
+            List<SelectedChatModelProviderBuildItem> selectedChatItem,
+            LangChain4jLlama3RuntimeConfig runtimeConfig,
LangChain4jLlama3FixedRuntimeConfig fixedRuntimeConfig, + BuildProducer beanProducer) { + + for (var selected : selectedChatItem) { + if (PROVIDER.equals(selected.getProvider())) { + String configName = selected.getConfigName(); + var builder = SyntheticBeanBuildItem.configure(CHAT_MODEL).setRuntimeInit().defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.chatModel(runtimeConfig, fixedRuntimeConfig, configName)); + addQualifierIfNecessary(builder, configName); + beanProducer.produce(builder.done()); + } + } + } + + private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String configName) { + if (!NamedConfigUtil.isDefault(configName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", configName).build()); + } + } + + @BuildStep + public void nativeSupport(BuildProducer runtimeInitializedPackageProducer, + BuildProducer moduleProducer, + BuildProducer nativeFeatureProducer) { + runtimeInitializedPackageProducer + .produce(new RuntimeInitializedPackageBuildItem("io.quarkiverse.langchain4j.llama3.copy")); + moduleProducer.produce(new NativeImageEnableModule("jdk.incubator.vector")); + nativeFeatureProducer.produce(new NativeImageFeatureBuildItem(Llama3Feature.class)); + } + + // @Produce(ServiceStartBuildItem.class) + // @BuildStep + // void downloadModels(List selectedChatModels, + // List selectedEmbeddingModels, + // LoggingSetupBuildItem loggingSetupBuildItem, + // Optional consoleInstalledBuildItem, + // LaunchModeBuildItem launchMode, + // LangChain4jJlamaBuildTimeConfig buildTimeConfig, + // LangChain4jJlamaFixedRuntimeConfig fixedRuntimeConfig, + // BuildProducer modelDownloadedProducer) { + // if (!buildTimeConfig.includeModelsInArtifact()) { + // return; + // } + // JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(fixedRuntimeConfig.modelsPath()); + // + // BigDecimal ONE_HUNDRED = new BigDecimal("100"); + // + // if (buildTimeConfig.chatModel().enabled().orElse(true) || buildTimeConfig.embeddingModel().enabled().orElse(true)) { + // List modelsNeeded = new ArrayList<>(); + // for (var selected : selectedChatModels) { + // if (PROVIDER.equals(selected.getProvider())) { + // String configName = selected.getConfigName(); + // + // String modelName = NamedConfigUtil.isDefault(configName) + // ? fixedRuntimeConfig.defaultConfig().chatModel().modelName() + // : fixedRuntimeConfig.namedConfig().get(configName).chatModel().modelName(); + // modelsNeeded.add(modelName); + // } + // } + // + // for (var selected : selectedEmbeddingModels) { + // if (PROVIDER.equals(selected.getProvider())) { + // String configName = selected.getConfigName(); + // + // String modelName = NamedConfigUtil.isDefault(configName) + // ? fixedRuntimeConfig.defaultConfig().embeddingModel().modelName() + // : fixedRuntimeConfig.namedConfig().get(configName).embeddingModel().modelName(); + // modelsNeeded.add(modelName); + // } + // } + // + // if (!modelsNeeded.isEmpty()) { + // StartupLogCompressor compressor = new StartupLogCompressor( + // (launchMode.isTest() ? 
"(test) " : "") + "Jlama model pull:", + // consoleInstalledBuildItem, + // loggingSetupBuildItem); + // + // for (String modelName : modelsNeeded) { + // JlamaModelRegistry.ModelInfo modelInfo = JlamaModelRegistry.ModelInfo.from(modelName); + // Path pathOfModelDirOnDisk = SafeTensorSupport.constructLocalModelPath( + // registry.getModelCachePath().toAbsolutePath().toString(), modelInfo.owner(), + // modelInfo.name()); + // // Check if the model is already downloaded + // // this is done automatically by download model, but we want to provide a good progress experience, so we do it again here + // if (Files.exists(pathOfModelDirOnDisk.resolve(".finished"))) { + // LOGGER.debug("Model " + modelName + "already exists in " + pathOfModelDirOnDisk); + // } else { + // // we pull one model at a time and provide progress updates to the user via logging + // LOGGER.info("Pulling model " + modelName); + // + // try { + // registry.downloadModel(modelName, Optional.empty(), Optional.of(new ProgressReporter() { + // @Override + // public void update(String filename, long sizeDownloaded, long totalSize) { + // // Jlama downloads a bunch of files for each mode of which only the weights file is large + // // and makes sense to report progress on + // if (totalSize < 100_000) { + // return; + // } + // + // BigDecimal percentage = new BigDecimal(sizeDownloaded).divide(new BigDecimal(totalSize), 4, + // RoundingMode.HALF_DOWN).multiply(ONE_HUNDRED); + // BigDecimal progress = percentage.setScale(2, RoundingMode.HALF_DOWN); + // if (progress.compareTo(ONE_HUNDRED) >= 0) { + // // avoid showing 100% for too long + // LOGGER.infof("Verifying and cleaning up\n", progress); + // } else { + // LOGGER.infof("Progress: %s%%\n", progress); + // } + // } + // })); + // } catch (IOException e) { + // compressor.closeAndDumpCaptured(); + // throw new UncheckedIOException(e); + // } + // } + // + // modelDownloadedProducer.produce(new ModelDownloadedBuildItem(modelName, pathOfModelDirOnDisk)); + // } + // + // compressor.close(); + // } + // } + // + // } + + /** + * When building a fast jar, we can copy the model files into the directory + * + */ + // @BuildStep(onlyIf = IsNormal.class) + // @Produce(ArtifactResultBuildItem.class) + // public void copyToFastJar(List models, + // Optional jarBuildItem) { + // if (!jarBuildItem.isPresent()) { + // return; + // } + // + // Path jarPath = jarBuildItem.get().getPath(); + // if (!JarResultBuildStep.QUARKUS_RUN_JAR.equals(jarPath.getFileName().toString())) { + // return; + // } + // + // Path quarkusAppDir = jarPath.getParent(); + // Path jlamaInQuarkusAppDir = quarkusAppDir.resolve("jlama"); + // + // for (ModelDownloadedBuildItem bi : models) { + // try { + // JlamaModelRegistry.ModelInfo modelInfo = JlamaModelRegistry.ModelInfo.from(bi.getModelName()); + // Path targetDir = jlamaInQuarkusAppDir.resolve(modelInfo.toFileName()); + // Files.createDirectories(targetDir); + // PathUtils.copyDirectory(bi.getDirectory(), targetDir); + // } catch (IOException e) { + // throw new UncheckedIOException(e); + // } + // } + // + // } + + public static final class ModelDownloadedBuildItem extends MultiBuildItem { + + private final String modelName; + private final Path directory; + + public ModelDownloadedBuildItem(String modelName, Path directory) { + this.modelName = modelName; + this.directory = directory; + } + + public String getModelName() { + return modelName; + } + + public Path getDirectory() { + return directory; + } + } +} diff --git a/model-providers/llama3-java/pom.xml 
b/model-providers/llama3-java/pom.xml
new file mode 100644
index 000000000..45f4f1867
--- /dev/null
+++ b/model-providers/llama3-java/pom.xml
@@ -0,0 +1,24 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <parent>
+        <groupId>io.quarkiverse.langchain4j</groupId>
+        <artifactId>quarkus-langchain4j-parent</artifactId>
+        <version>999-SNAPSHOT</version>
+        <relativePath>../../pom.xml</relativePath>
+    </parent>
+    <artifactId>quarkus-langchain4j-llama3-java-parent</artifactId>
+    <name>Quarkus LangChain4j - Llama 3 - Java - Parent</name>
+    <packaging>pom</packaging>
+
+    <properties>
+        <maven.compiler.release>21</maven.compiler.release>
+    </properties>
+
+    <modules>
+        <module>deployment</module>
+        <module>runtime</module>
+    </modules>
+</project>
diff --git a/model-providers/llama3-java/runtime/pom.xml b/model-providers/llama3-java/runtime/pom.xml
new file mode 100644
index 000000000..54a3cba3d
--- /dev/null
+++ b/model-providers/llama3-java/runtime/pom.xml
@@ -0,0 +1,115 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <parent>
+        <groupId>io.quarkiverse.langchain4j</groupId>
+        <artifactId>quarkus-langchain4j-llama3-java-parent</artifactId>
+        <version>999-SNAPSHOT</version>
+    </parent>
+    <artifactId>quarkus-langchain4j-llama3-java</artifactId>
+    <name>Quarkus LangChain4j - Llama3 - Java - Runtime</name>
+    <dependencies>
+        <dependency>
+            <groupId>io.quarkus</groupId>
+            <artifactId>quarkus-arc</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>io.quarkiverse.langchain4j</groupId>
+            <artifactId>quarkus-langchain4j-core</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>io.smallrye.common</groupId>
+            <artifactId>smallrye-common-resource</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>org.graalvm.sdk</groupId>
+            <artifactId>graal-sdk</artifactId>
+            <scope>provided</scope>
+        </dependency>
+
+        <!-- test dependencies -->
+        <dependency>
+            <groupId>io.quarkus</groupId>
+            <artifactId>quarkus-junit5-internal</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.assertj</groupId>
+            <artifactId>assertj-core</artifactId>
+            <version>${assertj.version}</version>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>io.quarkus</groupId>
+                <artifactId>quarkus-extension-maven-plugin</artifactId>
+                <version>${quarkus.version}</version>
+                <executions>
+                    <execution>
+                        <phase>compile</phase>
+                        <goals>
+                            <goal>extension-descriptor</goal>
+                        </goals>
+                        <configuration>
+                            <deployment>${project.groupId}:${project.artifactId}-deployment:${project.version}</deployment>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <artifactId>maven-compiler-plugin</artifactId>
+                <configuration>
+                    <annotationProcessorPaths>
+                        <path>
+                            <groupId>io.quarkus</groupId>
+                            <artifactId>quarkus-extension-processor</artifactId>
+                            <version>${quarkus.version}</version>
+                        </path>
+                    </annotationProcessorPaths>
+                    <source>21</source>
+                    <target>21</target>
+                    <compilerArgs>
+                        <arg>--add-modules=jdk.incubator.vector</arg>
+                    </compilerArgs>
+                </configuration>
+            </plugin>
+            <plugin>
+                <artifactId>maven-jar-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <id>generate-codestart-jar</id>
+                        <phase>generate-resources</phase>
+                        <goals>
+                            <goal>jar</goal>
+                        </goals>
+                        <configuration>
+                            <classesDirectory>${project.basedir}/src/main</classesDirectory>
+                            <includes>
+                                <include>codestarts/**</include>
+                            </includes>
+                            <classifier>codestarts</classifier>
+                            <skipIfEmpty>true</skipIfEmpty>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+</project>
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ChatModel.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ChatModel.java
new file mode 100644
index 000000000..24be10811
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ChatModel.java
@@ -0,0 +1,184 @@
+package io.quarkiverse.langchain4j.llama3;
+
+import static dev.langchain4j.data.message.AiMessage.aiMessage;
+import static io.quarkiverse.langchain4j.llama3.MessageMapper.toLlama3Message;
+import static io.quarkiverse.langchain4j.llama3.copy.Llama3.selectSampler;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+import io.quarkiverse.langchain4j.llama3.copy.ChatFormat;
+import io.quarkiverse.langchain4j.llama3.copy.Llama;
+import io.quarkiverse.langchain4j.llama3.copy.Llama3;
+import io.quarkiverse.langchain4j.llama3.copy.Sampler;
+
+public class Llama3ChatModel implements ChatLanguageModel {
+
+    private final Path modelPath;
+    private final Llama model;
+    private final Float temperature;
+    private final Integer maxTokens;
+    private final Float topP;
+    private final Integer seed;
+
+    public Llama3ChatModel(Llama3ChatModelBuilder builder) {
+        Llama3ModelRegistry llama3ModelRegistry = Llama3ModelRegistry.getOrCreate(builder.modelCachePath);
+        try {
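+            // Resolve the GGUF file in the local cache, downloading it on first use;
+            // note that maxTokens also serves as the context length when loading the model.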
+            modelPath = llama3ModelRegistry.downloadModel(builder.modelName, builder.quantization,
+                    Optional.ofNullable(builder.authToken), Optional.empty());
+            model = llama3ModelRegistry.loadModel(builder.modelName, builder.quantization, builder.maxTokens, true);
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+        temperature = builder.temperature;
+        maxTokens = builder.maxTokens;
+        topP = builder.topP;
+        seed = builder.seed;
+    }
+
+    @Override
+    public Response<AiMessage> generate(List<ChatMessage> messages) {
+
+        List<ChatFormat.Message> llamaMessages = new ArrayList<>();
+        for (ChatMessage message : messages) {
+            llamaMessages.add(toLlama3Message(message));
+        }
+
+        String systemPrompt = llamaMessages.stream().filter(m -> m.role().equals(ChatFormat.Role.SYSTEM)).findFirst().map(
+                ChatFormat.Message::content).orElse(null);
+        String prompt = llamaMessages.stream().filter(m -> m.role().equals(ChatFormat.Role.USER)).findFirst().map(
+                ChatFormat.Message::content).orElse(null);
+        if (prompt == null) {
+            throw new IllegalArgumentException("No UserMessage found");
+        }
+
+        Llama3.Options options = new Llama3.Options(
+                modelPath,
+                prompt,
+                systemPrompt,
+                false,
+                temperature,
+                topP,
+                seed,
+                maxTokens,
+                false, // stream
+                false // echo
+        );
+        Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(),
+                options.seed());
+        InferenceResponse inferenceResponse = runInstructOnce(model, sampler, options);
+
+        return Response.from(aiMessage(inferenceResponse.text()),
+                new TokenUsage(inferenceResponse.promptTokens(), inferenceResponse.responseTokens()));
+    }
+
+    private InferenceResponse runInstructOnce(Llama model, Sampler sampler, Llama3.Options options) {
+        if (options.stream()) {
+            throw new IllegalStateException("stream is not supported");
+        }
+
+        Llama.State state = model.createNewState();
+        ChatFormat chatFormat = new ChatFormat(model.tokenizer());
+
+        List<Integer> promptTokens = new ArrayList<>();
+        promptTokens.add(chatFormat.getBeginOfText());
+        if (options.systemPrompt() != null) {
+            promptTokens
+                    .addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
+        }
+        promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
+        promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+        Set<Integer> stopTokens = chatFormat.getStopTokens();
+        List<Integer> responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(),
+                sampler, options.echo(), token -> {
+                    if (options.stream()) {
+                        if (!model.tokenizer().isSpecialToken(token)) {
+                            System.out.print(model.tokenizer().decode(List.of(token)));
+                        }
+                    }
+                });
+        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+            responseTokens.removeLast();
+        }
+
+        return new InferenceResponse(model.tokenizer().decode(responseTokens), promptTokens.size(), responseTokens.size());
+    }
+
+    record InferenceResponse(String text, int promptTokens, int responseTokens) {
+    }
+
+    public static Llama3ChatModelBuilder builder() {
+        return new Llama3ChatModelBuilder();
+    }
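+    /**
+     * Fluent builder for {@link Llama3ChatModel}. A minimal usage sketch (the values
+     * shown are just the built-in defaults made explicit):
+     *
+     * <pre>{@code
+     * Llama3ChatModel model = Llama3ChatModel.builder()
+     *         .modelName("mukel/Llama-3.2-3B-Instruct-GGUF")
+     *         .quantization("Q4_0")
+     *         .temperature(0.7f)
+     *         .maxTokens(4_000)
+     *         .build();
+     * }</pre>
+     */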
+    @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
+    public static class Llama3ChatModelBuilder {
+
+        private Optional<Path> modelCachePath = Optional.empty(); // empty means the default cache location
+        private String modelName = "mukel/Llama-3.2-3B-Instruct-GGUF";
+        private String quantization = "Q4_0";
+        private String authToken;
+        private Integer maxTokens = 4_000;
+        private Float temperature = 0.7f;
+        private Float topP = 0.95f;
+        private Integer seed = 17;
+
+        public Llama3ChatModelBuilder modelCachePath(Optional<Path> modelCachePath) {
+            this.modelCachePath = modelCachePath;
+            return this;
+        }
+
+        public Llama3ChatModelBuilder modelName(String modelName) {
+            this.modelName = modelName;
+            return this;
+        }
+
+        public Llama3ChatModelBuilder quantization(String quantization) {
+            this.quantization = quantization;
+            return this;
+        }
+
+        public Llama3ChatModelBuilder authToken(String authToken) {
+            this.authToken = authToken;
+            return this;
+        }
+
+        public Llama3ChatModelBuilder temperature(Float temperature) {
+            this.temperature = temperature;
+            return this;
+        }
+
+        public Llama3ChatModelBuilder maxTokens(Integer maxTokens) {
+            this.maxTokens = maxTokens;
+            return this;
+        }
+
+        public Llama3ChatModelBuilder topP(Float topP) {
+            this.topP = topP;
+            return this;
+        }
+
+        public Llama3ChatModelBuilder seed(Integer seed) {
+            this.seed = seed;
+            return this;
+        }
+
+        public Llama3ChatModel build() {
+            return new Llama3ChatModel(this);
+        }
+    }
+}
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ModelRegistry.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ModelRegistry.java
new file mode 100644
index 000000000..344ff72eb
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ModelRegistry.java
@@ -0,0 +1,245 @@
+package io.quarkiverse.langchain4j.llama3;
+
+import java.io.File;
+import java.io.FilterInputStream;
+import java.io.IOError;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.StandardCopyOption;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+
+import org.jboss.logging.Logger;
+
+import io.quarkiverse.langchain4j.llama3.copy.Llama;
+import io.quarkiverse.langchain4j.llama3.copy.ModelLoader;
+
+/**
+ * A registry for managing Llama 3 models on local disk.
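+ * <p>
+ * Models are cached under {@code ~/.llama3java/models} unless a custom cache path
+ * is supplied. A minimal usage sketch (the model coordinates are illustrative):
+ *
+ * <pre>{@code
+ * Llama3ModelRegistry registry = Llama3ModelRegistry.getOrCreate(Optional.empty());
+ * Path gguf = registry.downloadModel("mukel/Llama-3.2-3B-Instruct-GGUF", "Q4_0",
+ *         Optional.empty(), Optional.empty());
+ * Llama model = registry.loadModel("mukel/Llama-3.2-3B-Instruct-GGUF", "Q4_0", 4_000, true);
+ * }</pre>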
+ */ +@SuppressWarnings("OptionalUsedAsFieldOrParameterType") +public class Llama3ModelRegistry { + + private static final Logger log = Logger.getLogger(Llama3ModelRegistry.class); + + private static final String DEFAULT_MODEL_CACHE_PATH = System.getProperty("user.home", "") + File.separator + ".llama3java" + + File.separator + "models"; + private final Path modelCachePath; + + private Llama3ModelRegistry(Path modelCachePath) { + this.modelCachePath = modelCachePath; + if (!Files.exists(modelCachePath)) { + try { + Files.createDirectories(modelCachePath); + } catch (IOException e) { + throw new IOError(e); + } + } + } + + public static Llama3ModelRegistry getOrCreate(Optional modelCachePath) { + return new Llama3ModelRegistry(modelCachePath.orElse(Path.of(DEFAULT_MODEL_CACHE_PATH))); + } + + public static Path constructLocalModelPath(String modelDir, String owner, String modelName) { + return Paths.get(modelDir, owner + "_" + modelName); + } + + public Path getModelCachePath() { + return modelCachePath; + } + + public Path downloadModel(String modelName, String quantization, Optional authToken, + Optional maybeProgressReporter) + throws IOException, InterruptedException { + ModelInfo modelInfo = ModelInfo.from(modelName); + + String effectiveFileName = getEffectiveFileName(modelInfo, quantization); + Path modelDirectory = constructLocalModelPath(modelCachePath.toAbsolutePath().toString(), modelInfo.owner(), + modelInfo.name()); + Path result = modelDirectory.resolve(effectiveFileName); + if (Files.exists(result)) { + return result; + } + + HttpClient client = HttpClient.newBuilder().followRedirects(HttpClient.Redirect.ALWAYS).build(); + URI uri = URI.create( + String.format("https://huggingface.co/%s/%s/resolve/main/%s", modelInfo.owner(), modelInfo.name(), + effectiveFileName)); + HttpRequest request = HttpRequest.newBuilder().uri(uri).build(); + HttpResponse httpResponse = client.send(request, HttpResponse.BodyHandlers.ofInputStream()); + if (httpResponse.statusCode() != 200) { + throw new RuntimeException( + "Unable to download model " + modelName + ". 
Response code from " + uri + " is : " + + httpResponse.statusCode()); + } + Files.createDirectories(result.getParent()); + long totalBytes = httpResponse.headers().firstValueAsLong("content-length").orElse(-1); + ProgressReporter progressReporter = maybeProgressReporter.orElse((filename, sizeDownloaded, totalSize) -> { + }); + + if (maybeProgressReporter.isEmpty()) { + log.info("Downloading file " + result.toAbsolutePath()); + } + String resultFileName = result.getFileName().toString(); + progressReporter.update(resultFileName, 0L, totalBytes); + + try (CountingInputStream inStream = new CountingInputStream(httpResponse.body())) { + CompletableFuture cf = CompletableFuture.supplyAsync(() -> { + try { + return Files.copy(inStream, result, StandardCopyOption.REPLACE_EXISTING); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + while (!cf.isDone()) { + progressReporter.update(resultFileName, inStream.count, totalBytes); + } + if (cf.isCompletedExceptionally()) { + progressReporter.update(resultFileName, inStream.count, totalBytes); + } else { + progressReporter.update(resultFileName, totalBytes, totalBytes); + } + + try { + cf.get(); + } catch (Throwable e) { + throw new IOException("Failed to download file: " + resultFileName, e); + } + if (maybeProgressReporter.isEmpty()) { + log.info("Downloaded file " + result.toAbsolutePath()); + } + } + + return result; + } + + private String getEffectiveFileName(ModelInfo modelInfo, String quantization) { + String effectiveFileName = modelInfo.name(); + if (effectiveFileName.endsWith("-GGUF")) { + effectiveFileName = effectiveFileName.substring(0, effectiveFileName.length() - 5); + } + effectiveFileName = effectiveFileName + "-" + quantization + ".gguf"; + return effectiveFileName; + } + + public Llama loadModel(String modelName, String quantization, int contextLength, boolean loadWeights) throws IOException { + ModelInfo modelInfo = ModelInfo.from(modelName); + + String effectiveFileName = getEffectiveFileName(modelInfo, quantization); + Path modelDirectory = constructLocalModelPath(modelCachePath.toAbsolutePath().toString(), modelInfo.owner(), + modelInfo.name()); + Path result = modelDirectory.resolve(effectiveFileName); + if (Files.exists(result)) { + return ModelLoader.loadModel(result, contextLength, loadWeights); + } + throw new IllegalStateException("No gguf file found for model name " + modelName + " and quantization " + quantization); + } + + public record ModelInfo(String owner, String name) { + + public static ModelInfo from(String modelName) { + String[] parts = modelName.split("/"); + if (parts.length == 0 || parts.length > 2) { + throw new IllegalArgumentException("Model must be in the form owner/name"); + } + + String owner; + String name; + + if (parts.length == 1) { + owner = null; + name = modelName; + } else { + owner = parts[0]; + name = parts[1]; + } + + return new ModelInfo(owner, name); + } + + public String toFileName() { + return owner + "_" + name; + } + } + + /** + * An {@link InputStream} that counts the number of bytes read. + * + * @author Chris Nokleberg + * + * Copied from Guava + */ + public static final class CountingInputStream extends FilterInputStream { + + private long count; + private long mark = -1; + + /** + * Wraps another input stream, counting the number of bytes read. + * + * @param in the input stream to be wrapped + */ + public CountingInputStream(InputStream in) { + super(Objects.requireNonNull(in)); + } + + /** Returns the number of bytes read. 
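+ * <p>
+ * The download logic above relies on this: {@code Files.copy} drains the wrapped
+ * stream on a worker thread while the caller polls the counter to emit progress
+ * updates, so the copy loop itself needs no instrumentation.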
*/ + public long getCount() { + return count; + } + + @Override + public int read() throws IOException { + int result = in.read(); + if (result != -1) { + count++; + } + return result; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + int result = in.read(b, off, len); + if (result != -1) { + count += result; + } + return result; + } + + @Override + public long skip(long n) throws IOException { + long result = in.skip(n); + count += result; + return result; + } + + @Override + public synchronized void mark(int readlimit) { + in.mark(readlimit); + mark = count; + // it's okay to mark even if mark isn't supported, as reset won't work + } + + @Override + public synchronized void reset() throws IOException { + if (!in.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1) { + throw new IOException("Mark not set"); + } + + in.reset(); + count = mark; + } + } +} diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/MessageMapper.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/MessageMapper.java new file mode 100644 index 000000000..f968bf2b4 --- /dev/null +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/MessageMapper.java @@ -0,0 +1,22 @@ +package io.quarkiverse.langchain4j.llama3; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import io.quarkiverse.langchain4j.llama3.copy.ChatFormat; + +final class MessageMapper { + + static ChatFormat.Message toLlama3Message(ChatMessage langchainMessage) { + ChatFormat.Role role = toJllamaRole(langchainMessage.type()); + return new ChatFormat.Message(role, langchainMessage.text()); + } + + private static ChatFormat.Role toJllamaRole(ChatMessageType chatMessageType) { + return switch (chatMessageType) { + case SYSTEM -> ChatFormat.Role.SYSTEM; + case USER -> ChatFormat.Role.USER; + case AI -> ChatFormat.Role.ASSISTANT; + default -> throw new IllegalArgumentException("Unsupported chat message type: " + chatMessageType); + }; + } +} diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/ProgressReporter.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/ProgressReporter.java new file mode 100644 index 000000000..bdd8bd7b8 --- /dev/null +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/ProgressReporter.java @@ -0,0 +1,6 @@ +package io.quarkiverse.langchain4j.llama3; + +public interface ProgressReporter { + + void update(String filename, long sizeDownloaded, long totalSize); +} diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/ChatFormat.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/ChatFormat.java new file mode 100644 index 000000000..b6f5799c8 --- /dev/null +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/ChatFormat.java @@ -0,0 +1,88 @@ +package io.quarkiverse.langchain4j.llama3.copy; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Utility tailored for Llama 3 instruct prompt format. 
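+ * <p>
+ * As a rough sketch, {@link #encodeDialogPrompt} renders a single user turn as
+ * (special tokens shown literally; the trailing assistant header is what the
+ * model completes):
+ *
+ * <pre>{@code
+ * <|begin_of_text|><|start_header_id|>user<|end_header_id|>
+ * Why is the sky blue?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+ * }</pre>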
+ */ +public class ChatFormat { + + final Tokenizer tokenizer; + final int beginOfText; + final int endHeader; + final int startHeader; + final int endOfTurn; + final int endOfText; + final int endOfMessage; + final Set stopTokens; + + public ChatFormat(Tokenizer tokenizer) { + this.tokenizer = tokenizer; + Map specialTokens = this.tokenizer.getSpecialTokens(); + this.beginOfText = specialTokens.get("<|begin_of_text|>"); + this.startHeader = specialTokens.get("<|start_header_id|>"); + this.endHeader = specialTokens.get("<|end_header_id|>"); + this.endOfTurn = specialTokens.get("<|eot_id|>"); + this.endOfText = specialTokens.get("<|end_of_text|>"); + this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1 + this.stopTokens = Set.of(endOfText, endOfTurn); + } + + public Tokenizer getTokenizer() { + return tokenizer; + } + + public Set getStopTokens() { + return stopTokens; + } + + public int getBeginOfText() { + return beginOfText; + } + + public List encodeHeader(Message message) { + List tokens = new ArrayList<>(); + tokens.add(startHeader); + tokens.addAll(this.tokenizer.encodeAsList(message.role().name())); + tokens.add(endHeader); + tokens.addAll(this.tokenizer.encodeAsList("\n")); + return tokens; + } + + public List encodeMessage(Message message) { + List tokens = this.encodeHeader(message); + tokens.addAll(this.tokenizer.encodeAsList(message.content().strip())); + tokens.add(endOfTurn); + return tokens; + } + + public List encodeDialogPrompt(boolean appendAssistantTurn, List dialog) { + List tokens = new ArrayList<>(); + tokens.add(beginOfText); + for (Message message : dialog) { + tokens.addAll(this.encodeMessage(message)); + } + if (appendAssistantTurn) { + // Add the start of an assistant message for the model to complete. 
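+            // The assistant header is deliberately left without content: generation
+            // then runs until one of the stop tokens (<|eot_id|> / <|end_of_text|>) appears.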
+ tokens.addAll(this.encodeHeader(new Message(Role.ASSISTANT, ""))); + } + return tokens; + } + + public record Message(Role role, String content) { + } + + public record Role(String name) { + public static Role SYSTEM = new Role("system"); + public static Role USER = new Role("user"); + public static Role ASSISTANT = new Role("assistant"); + + @Override + public String toString() { + return name; + } + } +} diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama.java new file mode 100644 index 000000000..a9bc2b6fd --- /dev/null +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama.java @@ -0,0 +1,336 @@ +package io.quarkiverse.langchain4j.llama3.copy; + +import java.nio.FloatBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; +import java.util.stream.Stream; + +public record Llama(Configuration configuration, Tokenizer tokenizer, Weights weights) { + + public State createNewState() { + State state = new State(configuration()); + state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>"); + return state; + } + + public static final class Configuration { + public final int dim; // transformer dimension + public final int hiddenDim; // for ffn layers + public final int numberOfLayers; // number of layers + public final int numberOfHeads; // number of query heads + public final int numberOfKeyValueHeads; // number of key/value heads (can be < query heads because of multiquery) + public final int vocabularySize; // vocabulary size, usually 256 (byte-level) + public final int contextLength; // max sequence length + public final float rmsNormEps; + public final float ropeTheta; + public final int headSize; + + Configuration withContextLength(int newContextLength) { + if (newContextLength < 0) { + return this; // no change + } + return new Configuration(this.dim, this.hiddenDim, this.numberOfLayers, this.numberOfHeads, + this.numberOfKeyValueHeads, this.vocabularySize, newContextLength, this.rmsNormEps, this.ropeTheta); + } + + public Configuration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, + int vocabularySize, int contextLength, float rmsNormEps, float ropeTheta) { + this.dim = dim; + this.hiddenDim = hiddenDim; + this.numberOfLayers = numberOfLayers; + this.numberOfHeads = numberOfHeads; + this.numberOfKeyValueHeads = numberOfKeyValueHeads; + this.vocabularySize = vocabularySize; + this.contextLength = contextLength; + this.rmsNormEps = rmsNormEps; + this.ropeTheta = ropeTheta; + this.headSize = dim / numberOfHeads; + } + } + + public static final class Weights { + // token embedding table + public final FloatTensor token_embedding_table; // (vocab_size, dim) + // weights for rmsnorms + public final FloatBuffer[] rms_att_weight; // (layer, dim) rmsnorm weights + // weights for matmuls + public final FloatTensor[] wq; // (layer, n_heads * head_size) + public final FloatTensor[] wk; // (layer, n_kv_heads, head_size) + public final FloatTensor[] wv; // (layer, n_kv_heads * head_size) + public final FloatTensor[] wo; // (layer, n_heads * head_size, dim) + public final FloatBuffer[] rms_ffn_weight; // (layer, dim) + // weights for ffn + public final FloatTensor[] w1; // (layer, hidden_dim, dim) + public final FloatTensor[] w2; // (layer, dim, hidden_dim) + public final 
FloatTensor[] w3; // (layer, hidden_dim, dim) + // public final rmsnorm + public final FloatBuffer rms_final_weight; // (dim,) + // freq_cis for RoPE relatively positional embeddings + public final FloatBuffer freq_cis_real; // (seq_len, head_size/2) + public final FloatBuffer freq_cis_imag; // (seq_len, head_size/2) + // (optional) classifier weights for the logits, on the last layer + public final FloatTensor wcls; // (vocab_size, dim) + + public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, + FloatTensor[] wv, FloatTensor[] wo, FloatBuffer[] rms_ffn_weight, FloatTensor[] w1, FloatTensor[] w2, + FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, + FloatTensor wcls) { + this.token_embedding_table = token_embedding_table; + this.rms_att_weight = rms_att_weight; + this.wq = wq; + this.wk = wk; + this.wv = wv; + this.wo = wo; + this.rms_ffn_weight = rms_ffn_weight; + this.w1 = w1; + this.w2 = w2; + this.w3 = w3; + this.rms_final_weight = rms_final_weight; + this.freq_cis_real = freq_cis_real; + this.freq_cis_imag = freq_cis_imag; + this.wcls = wcls; + } + } + + public static final class State { + + // current wave of activations + public final FloatTensor x; // activation at current time stamp (dim,) + public final FloatTensor xb; // same, but inside a residual branch (dim,) + public final FloatTensor xb2; // an additional buffer just for convenience (dim,) + public final FloatTensor hb; // buffer for hidden dimension in the ffn (hidden_dim,) + public final FloatTensor hb2; // buffer for hidden dimension in the ffn (hidden_dim,) + public final FloatTensor q; // query (dim,) + public final FloatTensor k; // key (dim,) + public final FloatTensor v; // value (dim,) + public final FloatTensor att; // buffer for scores/attention values (n_heads, seq_len) + public final FloatTensor logits; // output logits + // kv cache + public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim) + public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim) + + public int latestToken; + + State(Configuration config) { + this.x = ArrayFloatTensor.allocate(config.dim); + this.xb = ArrayFloatTensor.allocate(config.dim); + this.xb2 = ArrayFloatTensor.allocate(config.dim); + this.hb = ArrayFloatTensor.allocate(config.hiddenDim); + this.hb2 = ArrayFloatTensor.allocate(config.hiddenDim); + this.q = ArrayFloatTensor.allocate(config.dim); + this.k = ArrayFloatTensor.allocate(config.dim); + this.v = ArrayFloatTensor.allocate(config.dim); + this.att = ArrayFloatTensor.allocate(config.numberOfHeads, config.contextLength); + this.logits = ArrayFloatTensor.allocate(config.vocabularySize); + int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; + this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)) + .limit(config.numberOfLayers).toArray(FloatTensor[]::new); + this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)) + .limit(config.numberOfLayers).toArray(FloatTensor[]::new); + } + } + + static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) { + // calculate sum of squares + float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi); + ss /= size; + ss += rmsNormEps; + ss = (float) (1.0 / Math.sqrt(ss)); + // normalize and scale + final float finalss = ss; // for the lambda + out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * 
(finalss * x.getFloat(index))); + } + + static FloatTensor forward(Llama model, State state, int token, int position) { + // a few convenience variables + Configuration config = model.configuration(); + Weights weights = model.weights(); + int dim = config.dim; + int headSize = config.headSize; + int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; + int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads; // integer multiplier of the kv sharing in multiquery + float sqrtHeadSize = (float) Math.sqrt(headSize); + + // copy the token embedding into x + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + + // forward all the layers + for (int l = 0; l < config.numberOfLayers; l++) { + // attention rmsnorm + rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps); + + // qkv matmuls for this position + weights.wq[l].matmul(state.xb, state.q, dim, dim); + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + + // RoPE relative positional encoding: complex-valued rotate q and k in each head + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.get(position * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.get(position * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only + for (int v = 0; v < rotn; v++) { + FloatTensor vec = v == 0 ? state.q : state.k; // the vector to rotate (query or key) + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + + // save key,value at this time step (position) to our kv cache + //int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience + state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + + int curLayer = l; + + // multihead attention. 
iterate over all heads + Parallel.parallelFor(0, config.numberOfHeads, h -> { + // get the query vector for this head + // float* q = s.q + h * headSize; + int qOffset = h * headSize; + + // attention scores for this head + // float* att = s.att + h * config.seq_len; + int attOffset = h * config.contextLength; + + // iterate over all timesteps, including the current one + for (int t = 0; t <= position; t++) { + // get the key vector for this head and at this timestep + // float* k = s.key_cache + loff + t * dim + h * headSize; + int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; + // calculate the attention score as the dot product of q and k + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + score /= sqrtHeadSize; + // save the score to the attention buffer + state.att.setFloat(attOffset + t, score); + } + + // softmax the scores to get attention weights, from 0..position inclusively + state.att.softmaxInPlace(attOffset, position + 1); + + // weighted sum of the values, store back into xb + // float* xb = s.xb + h * headSize; + int xbOffset = h * headSize; + // memset(xb, 0, headSize * sizeof(float)); + state.xb.fillInPlace(xbOffset, headSize, 0f); + + for (int t = 0; t <= position; t++) { + // get the value vector for this head and at this timestep + // float* v = s.value_cache + loff + t * dim + h * headSize; + int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; + // get the attention weight for this timestep + float a = state.att.getFloat(attOffset + t); + // accumulate the weighted value into xb + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + } + }); + + // final matmul to get the output of the attention + weights.wo[l].matmul(state.xb, state.xb2, dim, dim); + + // residual connection back into x + state.x.addInPlace(state.xb2); + + // ffn rmsnorm + rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], dim, config.rmsNormEps); + + // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) + // first calculate self.w1(x) and self.w3(x) + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim, dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim, dim); + + // SwiGLU non-linearity + // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid + state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + + // elementwise multiply with w3(x) + state.hb.multiplyInPlace(state.hb2); + + // final matmul to get the output of the ffn + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim); + + // residual connection + state.x.addInPlace(state.xb); + } + + // final rmsnorm + rmsnorm(state.x, state.x, weights.rms_final_weight, dim, config.rmsNormEps); + + // classifier into logits + weights.wcls.matmul(state.x, state.logits, config.vocabularySize, dim); + + return state.logits; + } + + /** + * LLM generation entry point, ingest prompt tokens and generates new tokens. + * + *
+ * <p>
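+ * A minimal call sketch (assumes a loaded {@code model}, encoded {@code promptTokens}
+ * and a {@code sampler}; the names are illustrative):
+ *
+ * <pre>{@code
+ * Llama.State state = model.createNewState();
+ * List<Integer> generated = Llama.generateTokens(model, state, 0, promptTokens,
+ *         new ChatFormat(model.tokenizer()).getStopTokens(), 512, sampler, false, null);
+ * }</pre>
+ * <p>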
+ * All prompt tokens are ingested first, then inference starts, until a stop token is found. + * The returned tokens only include generated/inferred tokens. + * + * @param model model to run inference (including weights, configuration, tokenizer ...) + * @param state state of the model e.g. key/value caches ... this is mutated by this call + * @param startPosition start prompt ingestion + inference at this position in the context e.g. useful if state was kept + * across calls (chained generation). 0 implies run with no previous context. + * @param promptTokens prompt tokens to ingest, all the prompt tokens will be ingested, given there's enough capacity left + * in the context + * @param stopTokens set of tokens that abort generation during inference, stop tokens do not affect prompt ingestion + * @param maxTokens maximum number of tokens (can go up to {@link Configuration#contextLength context length} + * if this value is negative or greater than {@link Configuration#contextLength context length} + * @param sampler {@link Sampler strategy} used to select tokens + * @param echo debugging flag, prints ALL, prompt and inferred tokens, to {@link System#err stderr} + * @param onTokenGenerated callback, if non-null, it's called every time a token is inferred e.g. it's not called when + * ingesting prompt tokens + * @return list of generated/inferred tokens, including the stop token, if any e.g. does not include any token from the + * prompt + */ + public static List generateTokens(Llama model, State state, int startPosition, List promptTokens, + Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + long startNanos = System.nanoTime(); + if (maxTokens < 0 || model.configuration().contextLength < maxTokens) { + maxTokens = model.configuration().contextLength; + } + List generatedTokens = new ArrayList<>(maxTokens); + int token = state.latestToken; // BOS? + int nextToken; + int promptIndex = 0; + for (int position = startPosition; position < maxTokens; ++position) { + forward(model, state, token, position); + if (promptIndex < promptTokens.size()) { + // Force-pick token from prompt. + nextToken = promptTokens.get(promptIndex++); + if (echo) { + // log prompt token (different color?) + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + } else { + nextToken = sampler.sampleToken(state.logits); + if (echo) { + // log inferred token + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + generatedTokens.add(nextToken); + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + if (stopTokens.contains(nextToken)) { + break; + } + } + state.latestToken = token = nextToken; + } + + long elapsedNanos = System.nanoTime() - startNanos; + int totalTokens = promptIndex + generatedTokens.size(); + System.err.printf("%n%.2f tokens/s (%d)%n", totalTokens / (elapsedNanos / 1_000_000_000.0), totalTokens); + + return generatedTokens; + } +} diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama3.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama3.java new file mode 100755 index 000000000..4eb0b9034 --- /dev/null +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Llama3.java @@ -0,0 +1,1464 @@ +///usr/bin/env jbang "$0" "$@" ; exit $? 
+//JAVA 21+ +//PREVIEW +//COMPILE_OPTIONS --add-modules=jdk.incubator.vector +//RUNTIME_OPTIONS --add-modules=jdk.incubator.vector +//MAIN com.llama4j.Llama3 + +// Practical Llama 3 (and 3.1) inference in a single Java file +// Author: Alfonso² Peterssen +// Based on Andrej Karpathy's llama2.c and minbpe projects +// +// Supports llama.cpp's GGUF format, restricted to Q4_0 and Q8_0 quantized models +// Multi-threaded matrix vector multiplication routines implemented using Java's Vector API +// Simple CLI with --chat and --instruct mode +// +// To run just: +// jbang Llama3.java --help +// +// Enjoy! +package io.quarkiverse.langchain4j.llama3.copy; + +import java.io.IOException; +import java.io.PrintStream; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.OptionalInt; +import java.util.Scanner; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.function.IntConsumer; +import java.util.random.RandomGenerator; +import java.util.random.RandomGeneratorFactory; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorSpecies; +import sun.misc.Unsafe; + +public class Llama3 { + + public static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) { + Sampler sampler; + if (temperature == 0.0f) { + // greedy argmax sampling: take the token with the highest probability + sampler = Sampler.ARGMAX; + } else { + // we sample from this distribution to get the next token + RandomGenerator rng = RandomGeneratorFactory.getDefault().create(rngSeed); + Sampler innerSampler; + if (topp <= 0 || topp >= 1) { + // simply sample from the predicted probability distribution + innerSampler = new CategoricalSampler(rng); + } else { + // top-p (nucleus) sampling, clamping the least likely tokens to zero + innerSampler = new ToppSampler(vocabularySize, topp, rng); + } + sampler = logits -> { + // apply the temperature to the logits + logits.divideInPlace(0, logits.size(), temperature); + // apply softmax to the logits to get the probabilities for next token + logits.softmaxInPlace(0, logits.size()); + return innerSampler.sampleToken(logits); + }; + } + return sampler; + } + + static void runInteractive(Llama model, Sampler sampler, Options options) { + Llama.State state = null; + List conversationTokens = new ArrayList<>(); + ChatFormat chatFormat = new ChatFormat(model.tokenizer()); + conversationTokens.add(chatFormat.beginOfText); + if (options.systemPrompt() != null) { + conversationTokens + .addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); + } + int startPosition = 0; + Scanner in = new Scanner(System.in); + while (true) { + System.out.print("> "); + System.out.flush(); + String userText = in.nextLine(); + if (List.of("quit", "exit").contains(userText)) { + 
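+                // user asked to end the interactive session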
break; + } + if (state == null) { + state = model.createNewState(); + } + conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText))); + conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + Set stopTokens = chatFormat.getStopTokens(); + List responseTokens = Llama.generateTokens(model, state, startPosition, + conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), + sampler, options.echo(), token -> { + if (options.stream()) { + if (!model.tokenizer().isSpecialToken(token)) { + System.out.print(model.tokenizer().decode(List.of(token))); + } + } + }); + // Include stop token in the prompt history, but not in the response displayed to the user. + conversationTokens.addAll(responseTokens); + startPosition = conversationTokens.size(); + Integer stopToken = null; + if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { + stopToken = responseTokens.getLast(); + responseTokens.removeLast(); + } + if (!options.stream()) { + String responseText = model.tokenizer().decode(responseTokens); + System.out.println(responseText); + } + if (stopToken == null) { + System.err.println("Ran out of context length..."); + break; + } + } + } + + static void runInstructOnce(Llama model, Sampler sampler, Options options) { + Llama.State state = model.createNewState(); + ChatFormat chatFormat = new ChatFormat(model.tokenizer()); + + List promptTokens = new ArrayList<>(); + promptTokens.add(chatFormat.beginOfText); + if (options.systemPrompt() != null) { + promptTokens + .addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); + } + promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt()))); + promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + Set stopTokens = chatFormat.getStopTokens(); + List responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), + sampler, options.echo(), token -> { + if (options.stream()) { + if (!model.tokenizer().isSpecialToken(token)) { + System.out.print(model.tokenizer().decode(List.of(token))); + } + } + }); + if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { + responseTokens.removeLast(); + } + if (!options.stream()) { + String responseText = model.tokenizer().decode(responseTokens); + System.out.println(responseText); + } + } + + public record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive, + float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) { + + static final int DEFAULT_MAX_TOKENS = 512; + + public Options { + require(modelPath != null, "Missing argument: --model is required"); + require(interactive || prompt != null, + "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\""); + require(0 <= temperature, "Invalid argument: --temperature must be non-negative"); + require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); + } + + static void require(boolean condition, String messageFormat, Object... 
args) {
+            if (!condition) {
+                System.out.println("ERROR " + messageFormat.formatted(args));
+                System.out.println();
+                printUsage(System.out);
+                System.exit(-1);
+            }
+        }
+
+        static void printUsage(PrintStream out) {
+            out.println("Usage: jbang Llama3.java [options]");
+            out.println();
+            out.println("Options:");
+            out.println("  --model, -m <path>            required, path to .gguf file");
+            out.println("  --interactive, --chat, -i     run in chat mode");
+            out.println("  --instruct                    run in instruct (once) mode, default mode");
+            out.println("  --prompt, -p <string>         input prompt");
+            out.println("  --system-prompt, -sp <string> (optional) system prompt");
+            out.println("  --temperature, --temp <float> temperature in [0,inf], default 0.1");
+            out.println("  --top-p <float>               p value in top-p (nucleus) sampling in [0,1], default 0.95");
+            out.println("  --seed, -s <long>             random seed, default System.nanoTime()");
+            out.println("  --max-tokens, -n <int>        number of steps to run for; < 0 = limited by context length, default "
+                    + DEFAULT_MAX_TOKENS);
+            out.println(
+                    "  --stream <boolean>            print tokens during generation; may cause encoding artifacts for non-ASCII text, default true");
+            out.println(
+                    "  --echo <boolean>              print ALL tokens to stderr; if true, recommended to set --stream=false, default false");
+            out.println();
+            out.println("Examples:");
+            out.println("  jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Tell me a joke\"");
+            out.println(
+                    "  jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Reply concisely, in French\" --prompt \"Who was Marie Curie?\"");
+            out.println("  jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Answer concisely\" --chat");
+            out.println("  jbang Llama3.java --model llama3.2-1b-q4_0.gguf --chat");
+            out.println("  jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Print 5 emojis\" --stream=false");
+        }
+
+        static Options parseOptions(String[] args) {
+            String prompt = null;
+            String systemPrompt = null;
+            float temperature = 0.1f;
+            float topp = 0.95f;
+            Path modelPath = null;
+            long seed = System.nanoTime();
+            // Keep max context length small for low-memory devices.
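+            // A rough, hedged estimate (the actual dimensions come from the GGUF metadata): the KV cache
+            // needs about 2 * layers * contextLength * kvDim * 4 bytes in F32. Assuming a Llama 3.2
+            // 1B-style layout (16 layers, kvDim 512), 512 tokens of context cost roughly
+            //   2 * 16 * 512 * 512 * 4 bytes ~= 64 MiB.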
+ int maxTokens = DEFAULT_MAX_TOKENS; + boolean interactive = false; + boolean stream = true; + boolean echo = false; + + for (int i = 0; i < args.length; i++) { + String optionName = args[i]; + require(optionName.startsWith("-"), "Invalid option %s", optionName); + switch (optionName) { + case "--interactive", "--chat", "-i" -> interactive = true; + case "--instruct" -> interactive = false; + case "--help", "-h" -> { + printUsage(System.out); + System.exit(0); + } + default -> { + String nextArg; + if (optionName.contains("=")) { + String[] parts = optionName.split("=", 2); + optionName = parts[0]; + nextArg = parts[1]; + } else { + require(i + 1 < args.length, "Missing argument for option %s", optionName); + nextArg = args[i + 1]; + i += 1; // skip arg + } + switch (optionName) { + case "--prompt", "-p" -> prompt = nextArg; + case "--system-prompt", "-sp" -> systemPrompt = nextArg; + case "--temperature", "--temp" -> temperature = Float.parseFloat(nextArg); + case "--top-p" -> topp = Float.parseFloat(nextArg); + case "--model", "-m" -> modelPath = Paths.get(nextArg); + case "--seed", "-s" -> seed = Long.parseLong(nextArg); + case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg); + case "--stream" -> stream = Boolean.parseBoolean(nextArg); + case "--echo" -> echo = Boolean.parseBoolean(nextArg); + default -> require(false, "Unknown option: %s", optionName); + } + } + } + } + return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo); + } + } + + public static void main(String[] args) throws IOException { + Options options = Options.parseOptions(args); + Llama model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens()); + if (model == null) { + // No compatible preloaded model found, fallback to fully parse and load the specified file. 
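+            // The call below performs the full parse and load. As an illustrative aside, the same
+            // two-step lookup can be driven programmatically (model path hypothetical):
+            //   Llama m = AOT.tryUsePreLoaded(Path.of("llama3.2-1b-q4_0.gguf"), Options.DEFAULT_MAX_TOKENS);
+            //   if (m == null) m = ModelLoader.loadModel(Path.of("llama3.2-1b-q4_0.gguf"), Options.DEFAULT_MAX_TOKENS, true);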
+ model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true); + } + Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), + options.seed()); + if (options.interactive()) { + runInteractive(model, sampler, options); + } else { + runInstructOnce(model, sampler, options); + } + } +} + +final class GGUF { + private static final int GGUF_MAGIC = 0x46554747; + private static final int DEFAULT_ALIGNMENT = 32; // must be a power of 2 + private static final List SUPPORTED_GGUF_VERSIONS = List.of(2, 3); + private int magic; + private int version; + private int tensorCount; // uint64_t + private int alignment; + private int metadata_kv_count; // uint64_t + private Map metadata; + + public Map getTensorInfos() { + return tensorInfos; + } + + private Map tensorInfos; + + private long tensorDataOffset; + + public long getTensorDataOffset() { + return tensorDataOffset; + } + + public Map getMetadata() { + return metadata; + } + + private final ByteBuffer BB_1 = ByteBuffer.allocate(Byte.BYTES).order(ByteOrder.LITTLE_ENDIAN); + private final ByteBuffer BB_2 = ByteBuffer.allocate(Short.BYTES).order(ByteOrder.LITTLE_ENDIAN); + private final ByteBuffer BB_4 = ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); + private final ByteBuffer BB_8 = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN); + + public static GGUF loadModel(Path modelPath) throws IOException { + try (FileChannel fileChannel = FileChannel.open(modelPath); + var ignored = Timer.log("Parse " + modelPath)) { + GGUF gguf = new GGUF(); + gguf.loadModelImpl(fileChannel); + return gguf; + } + } + + enum MetadataValueType { + // The value is a 8-bit unsigned integer. + UINT8(1), + // The value is a 8-bit signed integer. + INT8(1), + // The value is a 16-bit unsigned little-endian integer. + UINT16(2), + // The value is a 16-bit signed little-endian integer. + INT16(2), + // The value is a 32-bit unsigned little-endian integer. + UINT32(4), + // The value is a 32-bit signed little-endian integer. + INT32(4), + // The value is a 32-bit IEEE754 floating point number. + FLOAT32(4), + // The value is a boolean. + // 1-byte value where 0 is false and 1 is true. + // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. + BOOL(1), + // The value is a UTF-8 non-null-terminated string, with length prepended. + STRING(-8), + // The value is an array of other values, with the length and type prepended. + // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. + ARRAY(-8), + // The value is a 64-bit unsigned little-endian integer. + UINT64(8), + // The value is a 64-bit signed little-endian integer. + INT64(8), + // The value is a 64-bit IEEE754 floating point number. + FLOAT64(8); + + private final int byteSize; + + MetadataValueType(int byteSize) { + this.byteSize = byteSize; + } + + private static final MetadataValueType[] VALUES = values(); + + public static MetadataValueType fromIndex(int index) { + return VALUES[index]; + } + + public int byteSize() { + return byteSize; + } + } + + private void loadModelImpl(FileChannel fileChannel) throws IOException { + // The header of the file. + readHeader(fileChannel); // gguf_header_t header; + // Tensor infos, which can be used to locate the tensor data. 
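+        // Overall GGUF layout traversed by this method:
+        //   header (magic, version, counts, metadata kv pairs) | tensor infos | padding to ALIGNMENT | tensor data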
+        // gguf_tensor_info_t tensor_infos[header.tensor_count];
+        this.tensorInfos = HashMap.newHashMap(tensorCount);
+        for (int i = 0; i < tensorCount; ++i) {
+            GGUFTensorInfo ti = readTensorInfo(fileChannel);
+            assert !tensorInfos.containsKey(ti.name());
+            tensorInfos.put(ti.name(), ti);
+        }
+        // Padding to the nearest multiple of `ALIGNMENT`.
+        // uint8_t _padding[ALIGNMENT - (sizeof(header + tensor_infos) % ALIGNMENT)];
+        // Computed branch-free: this is 0 when the position is already aligned, whereas
+        // `alignment - (position % alignment)` would skip a whole extra block in that case.
+        long _padding = -fileChannel.position() & (getAlignment() - 1);
+        fileChannel.position(fileChannel.position() + _padding);
+        // Tensor data.
+        //
+        // This is arbitrary binary data corresponding to the weights of the model. This data should be close
+        // or identical to the data in the original model file, but may be different due to quantization or
+        // other optimizations for inference. Any such deviations should be recorded in the metadata or as
+        // part of the architecture definition.
+        //
+        // Each tensor's data must be stored within this array, and located through its `tensor_infos` entry.
+        // The offset of each tensor's data must be a multiple of `ALIGNMENT`, and the space between tensors
+        // should be padded to `ALIGNMENT` bytes.
+        // uint8_t tensor_data[];
+        this.tensorDataOffset = fileChannel.position();
+    }
+
+    public static Map<String, GGMLTensorEntry> loadTensors(FileChannel fileChannel, long tensorDataOffset,
+            Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
+        Arena arena = Arena.ofAuto();
+        MemorySegment tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, tensorDataOffset,
+                fileChannel.size() - tensorDataOffset, arena);
+        Map<String, GGMLTensorEntry> tensorEntries = HashMap.newHashMap(tensorInfos.size());
+        for (Map.Entry<String, GGUFTensorInfo> entry : tensorInfos.entrySet()) {
+            GGUFTensorInfo ti = entry.getValue();
+            int numberOfElements = FloatTensor.numberOfElements(ti.dimensions());
+            int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements));
+            MemorySegment memorySegment = tensorData.asSlice(ti.offset(), sizeInBytes);
+            tensorEntries.put(ti.name(),
+                    new GGMLTensorEntry(tensorData, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment));
+        }
+        return tensorEntries;
+    }
+
+    public record GGUFTensorInfo(String name, int[] dimensions, GGMLType ggmlType, long offset) {
+    }
+
+    private GGMLType readGGMLType(FileChannel fileChannel) throws IOException {
+        int ggmlTypeId = readInt(fileChannel); // ggml_type type;
+        return GGMLType.fromId(ggmlTypeId);
+    }
+
+    private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
+        // The name of the tensor. It is a standard GGUF string, with the caveat that
+        // it must be at most 64 bytes long.
+        String name = readString(fileChannel); // gguf_string_t name;
+        assert name.length() <= 64;
+        // The number of dimensions in the tensor.
+        // Currently at most 4, but this may change in the future.
+        int n_dimensions = readInt(fileChannel); // uint32_t n_dimensions;
+        assert n_dimensions <= 4;
+        // The dimensions of the tensor.
+        int[] dimensions = new int[n_dimensions]; // uint64_t dimensions[n_dimensions];
+        for (int i = 0; i < n_dimensions; ++i) {
+            dimensions[i] = Math.toIntExact(readLong(fileChannel));
+        }
+        // The type of the tensor.
+        GGMLType ggmlType = readGGMLType(fileChannel); // ggml_type type;
+        // The offset of the tensor's data in this file in bytes.
+        // This offset is relative to `tensor_data`, not to the start
+        // of the file, to make it easier for writers to write the file.
+ // Readers should consider exposing this offset relative to the + // file to make it easier to read the data. + // Must be a multiple of `ALIGNMENT`. + long offset = readLong(fileChannel); // uint64_t offset; + assert offset % getAlignment() == 0; + return new GGUFTensorInfo(name, dimensions, ggmlType, offset); + } + + private String readString(FileChannel fileChannel) throws IOException { + // A string in GGUF. + // The length of the string, in bytes. + int len = Math.toIntExact(readLong(fileChannel)); // uint64_t len; + // The string as a UTF-8 non-null-terminated string. + byte[] bytes = new byte[len]; // char string[len]; + int bytesRead = fileChannel.read(ByteBuffer.wrap(bytes)); + assert len == bytesRead; + return new String(bytes, StandardCharsets.UTF_8); + } + + private Pair readKeyValuePair(FileChannel fileChannel) throws IOException { + // The key of the metadata. It is a standard GGUF string, with the following caveats: + // - It must be a valid ASCII string. + // - It must be a hierarchical key, where each segment is `lower_snake_case` and separated by a `.`. + // - It must be at most 2^16-1/65535 bytes long. + // Any keys that do not follow these rules are invalid. + String key = readString(fileChannel); // gguf_string_t key; + assert key.length() < (1 << 16); + assert key.codePoints().allMatch(cp -> ('a' <= cp && cp <= 'z') || ('0' <= cp && cp <= '9') || cp == '_' || cp == '.'); + Object value = readMetadataValue(fileChannel); + return new Pair<>(key, value); + } + + private Object readMetadataValue(FileChannel fileChannel) throws IOException { + // The type of the value. + // Must be one of the `gguf_metadata_value_type` values. + MetadataValueType value_type = readMetadataValueType(fileChannel); // gguf_metadata_value_type value_type; + // The value. + return readMetadataValueOfType(value_type, fileChannel); // gguf_metadata_value_t value; + } + + void readHeader(FileChannel fileChannel) throws IOException { + // Magic number to announce that this is a GGUF file. + // Must be `GGUF` at the byte level: `0x47` `0x47` `0x55` `0x46`. + // Your executor might do little-endian byte order, so it might be + // check for 0x46554747 and letting the endianness cancel out. + // Consider being *very* explicit about the byte order here. + this.magic = readInt(fileChannel); // uint32_t magic; + if (magic != GGUF_MAGIC) { + throw new IllegalArgumentException("unsupported header.magic " + magic); + } + // The version of the format implemented. + // Must be `3` for version described in this spec. + // + // This version should only be increased for structural changes to the format. + // Changes that do not affect the structure of the file should instead update the metadata + // to signify the change. + this.version = readInt(fileChannel); // uint32_t version; + if (!SUPPORTED_GGUF_VERSIONS.contains(version)) { + throw new IllegalArgumentException("unsupported header.version " + version); + } + // The number of tensors in the file. + // This is explicit, instead of being included in the metadata, to ensure it is always present + // for loading the tensors. + this.tensorCount = Math.toIntExact(readLong(fileChannel)); // uint64_t tensor_count; + // The number of metadata key-value pairs. + this.metadata_kv_count = Math.toIntExact(readLong(fileChannel)); // uint64_t metadata_kv_count; + // The metadata key-value pairs. 
+ // gguf_metadata_kv_t metadata_kv[metadata_kv_count]; + this.metadata = HashMap.newHashMap(metadata_kv_count); + for (int i = 0; i < metadata_kv_count; ++i) { + Pair keyValue = readKeyValuePair(fileChannel); + assert !metadata.containsKey(keyValue.first()); + metadata.put(keyValue.first(), keyValue.second()); + } + } + + private Object readArray(FileChannel fileChannel) throws IOException { + // Any value type is valid, including arrays. + MetadataValueType value_type = readMetadataValueType(fileChannel); // gguf_metadata_value_type type; + // Number of elements, not bytes + int len = Math.toIntExact(readLong(fileChannel)); // uint64_t len; + // The array of values. + // gguf_metadata_value_t array[len]; + switch (value_type) { + case UINT8, INT8 -> { + byte[] bytes = new byte[len]; + for (int i = 0; i < len; ++i) { + bytes[i] = readByte(fileChannel); + } + return bytes; + } + case UINT16, INT16 -> { + short[] shorts = new short[len]; + for (int i = 0; i < len; ++i) { + shorts[i] = readShort(fileChannel); + } + return shorts; + } + case UINT32, INT32 -> { + int[] ints = new int[len]; + for (int i = 0; i < len; ++i) { + ints[i] = readInt(fileChannel); + } + return ints; + } + case FLOAT32 -> { + float[] floats = new float[len]; + for (int i = 0; i < len; ++i) { + floats[i] = readFloat(fileChannel); + } + return floats; + } + case BOOL -> { + boolean[] booleans = new boolean[len]; + for (int i = 0; i < len; ++i) { + booleans[i] = readBoolean(fileChannel); + } + return booleans; + } + case STRING -> { + String[] strings = new String[len]; + for (int i = 0; i < len; ++i) { + strings[i] = readString(fileChannel); + } + return strings; + } + case ARRAY -> { + Object[] arrays = new Object[len]; + for (int i = 0; i < len; ++i) { + arrays[i] = readArray(fileChannel); + } + return arrays; + } + default -> throw new UnsupportedOperationException("read array of " + value_type); + } + } + + private Object readMetadataValueOfType(MetadataValueType valueType, FileChannel fileChannel) throws IOException { + return switch (valueType) { + case UINT8, INT8 -> readByte(fileChannel); + case UINT16, INT16 -> readShort(fileChannel); + case UINT32, INT32 -> readInt(fileChannel); + case FLOAT32 -> readFloat(fileChannel); + case UINT64, INT64 -> readLong(fileChannel); + case FLOAT64 -> readDouble(fileChannel); + case BOOL -> readBoolean(fileChannel); + case STRING -> readString(fileChannel); + case ARRAY -> readArray(fileChannel); + }; + } + + private byte readByte(FileChannel fileChannel) throws IOException { + int bytesRead = fileChannel.read(BB_1); + assert bytesRead == 1; + return BB_1.clear().get(0); + } + + private boolean readBoolean(FileChannel fileChannel) throws IOException { + return readByte(fileChannel) != 0; + } + + private short readShort(FileChannel fileChannel) throws IOException { + int bytesRead = fileChannel.read(BB_2); + assert bytesRead == 2; + return BB_2.clear().getShort(0); + } + + private int readInt(FileChannel fileChannel) throws IOException { + int bytesRead = fileChannel.read(BB_4); + assert bytesRead == 4; + return BB_4.clear().getInt(0); + } + + private long readLong(FileChannel fileChannel) throws IOException { + int bytesRead = fileChannel.read(BB_8); + assert bytesRead == 8; + return BB_8.clear().getLong(0); + } + + private float readFloat(FileChannel fileChannel) throws IOException { + return Float.intBitsToFloat(readInt(fileChannel)); + } + + private double readDouble(FileChannel fileChannel) throws IOException { + return Double.longBitsToDouble(readLong(fileChannel)); + } + + 
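+    // Worked example of the on-disk metadata encoding consumed above: an entry for the key
+    // "general.alignment" with UINT32 value 32 is stored as
+    //   uint64 key_len (17) | "general.alignment" | uint32 value_type (4 = UINT32) | uint32 value (32)
+    // i.e. exactly the order readKeyValuePair -> readMetadataValue -> readMetadataValueOfType reads.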
private MetadataValueType readMetadataValueType(FileChannel fileChannel) throws IOException {
+        int index = readInt(fileChannel);
+        return MetadataValueType.fromIndex(index);
+    }
+
+    public int getAlignment() {
+        if (alignment != 0) {
+            return alignment;
+        }
+        alignment = (int) metadata.getOrDefault("general.alignment", DEFAULT_ALIGNMENT);
+        assert Integer.bitCount(alignment) == 1 : "alignment must be a power of two";
+        return alignment;
+    }
+}
+
+interface Timer extends AutoCloseable {
+    @Override
+    void close(); // no Exception
+
+    static Timer log(String label) {
+        return log(label, TimeUnit.MILLISECONDS);
+    }
+
+    static Timer log(String label, TimeUnit timeUnit) {
+        return new Timer() {
+            final long startNanos = System.nanoTime();
+
+            @Override
+            public void close() {
+                long elapsedNanos = System.nanoTime() - startNanos;
+                System.err.println(label + ": "
+                        + timeUnit.convert(elapsedNanos, TimeUnit.NANOSECONDS) + " "
+                        + timeUnit.toChronoUnit().name().toLowerCase());
+            }
+        };
+    }
+}
+
+final class Parallel {
+    public static void parallelFor(int startInclusive, int endExclusive, IntConsumer action) {
+        IntStream.range(startInclusive, endExclusive).parallel().forEach(action);
+    }
+}
+
+record Pair<First, Second>(First first, Second second) {
+}
+
+record GGMLTensorEntry(MemorySegment mappedFile, String name, GGMLType ggmlType, int[] shape,
+        MemorySegment memorySegment) {
+}
+
+final class Float16 {
+    public static final int BYTES = 2;
+}
+
+enum GGMLType {
+    F32(Float.BYTES),
+    F16(Float16.BYTES),
+    Q4_0(Float16.BYTES + 16 * Byte.BYTES, 32),
+    Q4_1(2 * Float16.BYTES + 16 * Byte.BYTES, 32),
+    UNSUPPORTED_Q4_2(Integer.MAX_VALUE), // support has been removed
+    UNSUPPORTED_Q4_3(Integer.MAX_VALUE), // support has been removed
+    Q5_0(Integer.MAX_VALUE),
+    Q5_1(Integer.MAX_VALUE),
+    Q8_0(Float16.BYTES + 32 * Byte.BYTES, 32),
+    Q8_1(32 * Byte.BYTES + 2 * Float.BYTES, 32),
+    // k-quantizations
+    Q2_K(Integer.MAX_VALUE),
+    Q3_K(Integer.MAX_VALUE),
+    Q4_K(2 * Float16.BYTES + ((GGMLType.QK_K / 16) / 8 * 6) + GGMLType.QK_K / 2, GGMLType.QK_K),
+    Q5_K(2 * Float16.BYTES + ((GGMLType.QK_K / 16) / 8 * 6) + GGMLType.QK_K / 8 + GGMLType.QK_K / 2, GGMLType.QK_K),
+    Q6_K(GGMLType.QK_K / 2 + GGMLType.QK_K / 4 + GGMLType.QK_K / 16 + Float16.BYTES, GGMLType.QK_K),
+    Q8_K(Integer.MAX_VALUE),
+    I8(Byte.BYTES),
+    I16(Short.BYTES),
+    I32(Integer.BYTES);
+
+    private static final GGMLType[] VALUES = values();
+
+    private final int typeSize;
+
+    private final int blockSize;
+
+    public int getTypeSize() {
+        return typeSize;
+    }
+
+    public int getBlockSize() {
+        return blockSize;
+    }
+
+    public static GGMLType fromId(int id) {
+        return VALUES[id];
+    }
+
+    GGMLType(int typeSize) {
+        this(typeSize, 1);
+    }
+
+    public long byteSizeFor(int numberOfElements) {
+        long t = numberOfElements * (long) getTypeSize();
+        assert t % getBlockSize() == 0;
+        return Math.toIntExact(t / getBlockSize());
+    }
+
+    public static final int QK_K = 256; // or 64?
+
+    GGMLType(int typeSize, int blockSize) {
+        assert blockSize > 0;
+        assert typeSize > 0;
+        assert isPowerOf2(blockSize);
+        this.typeSize = typeSize;
+        this.blockSize = blockSize;
+    }
+
+    private static boolean isPowerOf2(int n) {
+        return n > 0 && (n & (n - 1)) == 0;
+    }
+}
+
+/**
+ * Over-simplified, shapeless, float tensor.
+ * <p>
+ * Not a strict tensor, but rather just a sequence of floats, not required to be backed by memory + * e.g. can represent a sequence of quantized floats. + */ +abstract class FloatTensor { + static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); + + // static final ValueLayout.OfFloat JAVA_FLOAT_LE = ValueLayout.JAVA_FLOAT.withOrder(ByteOrder.LITTLE_ENDIAN); + // static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN); + + // The use of Unsafe in this file is a temporary workaround to support native-image. + static final Unsafe UNSAFE; + + static { + try { + Field f = Unsafe.class.getDeclaredField("theUnsafe"); + f.setAccessible(true); + UNSAFE = (Unsafe) f.get(null); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + static short readShort(MemorySegment memorySegment, long offset) { + // The MemorySegment.get* methods should be used instead. + return UNSAFE.getShort(memorySegment.address() + offset); + } + + static byte readByte(MemorySegment memorySegment, long offset) { + // The MemorySegment.get* methods should be used instead. + return UNSAFE.getByte(memorySegment.address() + offset); + } + + // Preferred vector size for the fast multiplication routines. + // (Apple Silicon) NEON only supports up-to 128bit vectors. + static final VectorSpecies F_SPECIES = FloatVector.SPECIES_PREFERRED.vectorBitSize() == 128 ? FloatVector.SPECIES_128 + : FloatVector.SPECIES_256; + + abstract int size(); + + abstract float getFloat(int index); + + abstract void setFloat(int index, float value); + + abstract FloatVector getFloatVector(VectorSpecies species, int offset); + + abstract GGMLType type(); + + public static int numberOfElements(int... 
dimensions) { + assert Arrays.stream(dimensions).allMatch(i -> i > 0); + return Arrays.stream(dimensions).reduce(Math::multiplyExact).orElseThrow(); + } + + static float scalarDot(FloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) { + float result = 0f; + for (int j = 0; j < size; j++) { + result += thiz.getFloat(thisOffset + j) * that.getFloat(thatOffset + j); + } + return result; + } + + float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { + return scalarDot(this, thisOffset, that, thatOffset, size); + } + + void matmul(FloatTensor that, FloatTensor out, int dim0, int dim1) { + Parallel.parallelFor(0, dim0, i -> out.setFloat(i, dot(i * dim1, that, 0, dim1))); + } + + @FunctionalInterface + interface AggregateFunction { + float apply(float acc, float value); + } + + float reduce(int thisOffset, int size, float seed, AggregateFunction reduce) { + float result = seed; + for (int i = 0; i < size; ++i) { + result = reduce.apply(result, getFloat(thisOffset + i)); + } + return result; + } + + float sum(int thisOffset, int size) { + return reduce(thisOffset, size, 0f, Float::sum); + } + + float max(int thisOffset, int size) { + return reduce(thisOffset, size, Float.NEGATIVE_INFINITY, Float::max); + } + + void copyTo(int thisOffset, FloatTensor that, int thatOffset, int size) { + that.mapWithIndexInPlace(thatOffset, size, (value, index) -> this.getFloat(index - thatOffset + thisOffset)); + } + + int argmax(int thisOffset, int size) { + assert size > 0; + int maxIndex = thisOffset; + float maxValue = this.getFloat(maxIndex); + int endIndex = thisOffset + size; + for (int i = thisOffset; i < endIndex; ++i) { + float f = this.getFloat(i); + if (f > maxValue) { + maxValue = f; + maxIndex = i; + } + } + return maxIndex; + } + + int argmax() { + return argmax(0, size()); + } + + @FunctionalInterface + interface MapFunction { + float apply(float value); + } + + @FunctionalInterface + interface MapWithIndexFunction { + float apply(float value, int index); + } + + FloatTensor mapInPlace(int thisOffset, int size, MapFunction mapFunction) { + int endIndex = thisOffset + size; + for (int i = thisOffset; i < endIndex; ++i) { + setFloat(i, mapFunction.apply(getFloat(i))); + } + return this; + } + + FloatTensor mapInPlace(MapFunction mapFunction) { + return mapInPlace(0, size(), mapFunction); + } + + FloatTensor mapWithIndexInPlace(int thisOffset, int size, MapWithIndexFunction mapWithIndexFunction) { + int endOffset = thisOffset + size; + for (int i = thisOffset; i < endOffset; ++i) { + setFloat(i, mapWithIndexFunction.apply(getFloat(i), i)); + } + return this; + } + + FloatTensor addInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) { + return mapWithIndexInPlace(thisOffset, size, (value, index) -> value + that.getFloat(index - thisOffset + thatOffset)); + } + + FloatTensor addInPlace(FloatTensor that) { + return addInPlace(0, that, 0, size()); + } + + FloatTensor multiplyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) { + return mapWithIndexInPlace(thisOffset, size, (value, index) -> value * that.getFloat(index - thisOffset + thatOffset)); + } + + FloatTensor multiplyInPlace(FloatTensor that) { + return multiplyInPlace(0, that, 0, size()); + } + + FloatTensor divideInPlace(int thisOffset, int size, float value) { + return mapInPlace(thisOffset, size, f -> f / value); + } + + FloatTensor fillInPlace(int thisOffset, int size, float value) { + return mapInPlace(thisOffset, size, unused -> value); + } + + FloatTensor 
softmaxInPlace(int thisOffset, int size) {
+        // find max value (for numerical stability)
+        float maxVal = max(thisOffset, size);
+        // exp and sum
+        mapInPlace(thisOffset, size, f -> (float) Math.exp(f - maxVal));
+        float sum = sum(thisOffset, size);
+        // normalize
+        return divideInPlace(thisOffset, size, sum);
+    }
+
+    FloatTensor saxpyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size, float a) {
+        // this[thisOffset ... thisOffset + size) = a * that[thatOffset ... thatOffset + size) + this[thisOffset ... thisOffset + size)
+        for (int i = 0; i < size; ++i) {
+            setFloat(thisOffset + i, a * that.getFloat(thatOffset + i) + this.getFloat(thisOffset + i));
+        }
+        return this;
+    }
+}
+
+/**
+ * {@link FloatTensor} quantized in the {@link GGMLType#Q4_0} format.
+ * <p>
+ * This tensor implementation is not compatible with {@link FloatTensor}, but + * {@link #dot(int, FloatTensor, int, int)} has a vectorized implementation that is used when + * the second argument implements {@link FloatTensor}. + */ +final class Q4_0FloatTensor extends FloatTensor { + + final int size; + final MemorySegment memorySegment; + + public Q4_0FloatTensor(int size, MemorySegment memorySegment) { + this.size = size; + this.memorySegment = memorySegment; + } + + @Override + int size() { + return size; + } + + @Override + public void setFloat(int index, float value) { + throw new UnsupportedOperationException("setFloat"); + } + + @Override + FloatVector getFloatVector(VectorSpecies species, int index) { + throw new UnsupportedOperationException("getFloatVector"); + } + + @Override + public GGMLType type() { + return GGMLType.Q4_0; + } + + @Override + public float getFloat(int index) { + assert 0 <= index && index < size; + int blockIndex = index / GGMLType.Q4_0.getBlockSize(); + int blockOffset = blockIndex * GGMLType.Q4_0.getTypeSize(); + float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset)); + byte quant; + int modIndex = index % GGMLType.Q4_0.getBlockSize(); + if (modIndex < GGMLType.Q4_0.getBlockSize() / 2) { + quant = (byte) (readByte(memorySegment, blockOffset + Float16.BYTES + modIndex) & 0x0F); + } else { + quant = (byte) ((readByte(memorySegment, + blockOffset + Float16.BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F); + } + quant -= 8; + return quant * scale; + } + + @Override + public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { + if (FloatTensor.USE_VECTOR_API) { + return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size); + } else { + return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); + } + } + + private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { + float result = 0f; + int j = 0; + + // Align thisOffset + j to type().getBlockSize(). 
+ assert Integer.bitCount(GGMLType.Q4_0.getBlockSize()) == 1 : "power of 2"; + int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q4_0.getBlockSize() - 1)); + if (alignmentBound > 0) { + result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound); + j += alignmentBound; + } + assert (thisOffset + j) % GGMLType.Q4_0.getBlockSize() == 0; + + FloatVector val = FloatVector.zero(F_SPECIES); + int blockOffset = (thisOffset + j) / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getTypeSize(); + int upperBound = size / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getBlockSize(); + for (; j < upperBound; j += GGMLType.Q4_0.getBlockSize(), blockOffset += GGMLType.Q4_0.getTypeSize()) { + float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset)); + var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); + var B_SPECIES = ByteVector.SPECIES_128; + var wBytes = ByteVector.fromMemorySegment(B_SPECIES, thiz.memorySegment, blockOffset + Float16.BYTES, + ByteOrder.LITTLE_ENDIAN); + var loBytes = wBytes.and((byte) 0xF).sub((byte) 8); + var hiBytes = wBytes.lanewise(VectorOperators.LSHR, 4).sub((byte) 8); + if (F_SPECIES.vectorBitSize() == 256) { + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) + .mul(loBytes.castShape(F_SPECIES, 0)); + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) + .mul(loBytes.castShape(F_SPECIES, 1)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()) + .mul(hiBytes.castShape(F_SPECIES, 0)); + var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()) + .mul(hiBytes.castShape(F_SPECIES, 1)); + val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); + } else if (F_SPECIES.vectorBitSize() == 128) { + // This loop cannot be unrolled, why? + for (int i = 0; i < 2; ++i) { + var tmp = i == 0 ? loBytes : hiBytes; + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 0) * F_SPECIES.length()) + .mul(tmp.castShape(F_SPECIES, 0)); + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 1) * F_SPECIES.length()) + .mul(tmp.castShape(F_SPECIES, 1)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 2) * F_SPECIES.length()) + .mul(tmp.castShape(F_SPECIES, 2)); + var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 3) * F_SPECIES.length()) + .mul(tmp.castShape(F_SPECIES, 3)); + val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); + } + } else { + throw new UnsupportedOperationException(F_SPECIES.toString()); + } + } + result += val.reduceLanes(VectorOperators.ADD); + + // Remaining entries. 
+ if (j < size) { + result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j); + } + + return result; + } +} + +final class Q8_0FloatTensor extends FloatTensor { + + final int size; + final MemorySegment memorySegment; + + public Q8_0FloatTensor(int size, MemorySegment memorySegment) { + this.size = size; + this.memorySegment = memorySegment; + } + + @Override + int size() { + return size; + } + + @Override + public void setFloat(int index, float value) { + throw new UnsupportedOperationException("setFloat"); + } + + @Override + FloatVector getFloatVector(VectorSpecies species, int index) { + throw new UnsupportedOperationException("getFloatVector"); + } + + @Override + public GGMLType type() { + return GGMLType.Q8_0; + } + + @Override + public float getFloat(int index) { + assert 0 <= index && index < size; + int blockIndex = index / GGMLType.Q8_0.getBlockSize(); + int withinBlockIndex = index % GGMLType.Q8_0.getBlockSize(); + int blockOffset = blockIndex * GGMLType.Q8_0.getTypeSize(); + byte quant = readByte(memorySegment, blockOffset + Float16.BYTES + withinBlockIndex); + float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset)); + return quant * scale; + } + + public static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN); + + @Override + public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { + if (FloatTensor.USE_VECTOR_API) { + return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size); + } else { + return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); + } + } + + private static float vectorDot(Q8_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { + float result = 0f; + int j = 0; + + // Align thisOffset + startIndex to type().getBlockSize(). 
+ assert Integer.bitCount(GGMLType.Q8_0.getBlockSize()) == 1 : "power of 2"; + int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q8_0.getBlockSize() - 1)); + if (alignmentBound > 0) { + result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound); + j += alignmentBound; + } + assert (thisOffset + j) % GGMLType.Q8_0.getBlockSize() == 0; + + FloatVector val = FloatVector.zero(F_SPECIES); + int blockOffset = (thisOffset + j) / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getTypeSize(); + int upperBound = size / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getBlockSize(); + for (; j < upperBound; j += GGMLType.Q8_0.getBlockSize(), blockOffset += GGMLType.Q8_0.getTypeSize()) { + float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset)); + var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); + if (F_SPECIES.vectorBitSize() == 256) { + var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, + blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 0)); + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 1)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 2)); + var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 3)); + val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); + } else if (F_SPECIES.vectorBitSize() == 128) { + VectorSpecies B_128 = ByteVector.SPECIES_128; + // This loop cannot be unrolled, why? + for (int i = 0; i < 2; ++i) { + var wBytes = ByteVector.fromMemorySegment(B_128, thiz.memorySegment, + blockOffset + Float16.BYTES + i * B_128.vectorByteSize(), ByteOrder.LITTLE_ENDIAN); + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 0 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 0)); + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 1 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 1)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 2 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 2)); + var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 3 * F_SPECIES.length()) + .mul(wBytes.castShape(F_SPECIES, 3)); + val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); + } + } else { + throw new UnsupportedOperationException(F_SPECIES.toString()); + } + } + result += val.reduceLanes(VectorOperators.ADD); + + // Remaining entries. + if (j < size) { + result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j); + } + + return result; + } +} + +final class ArrayFloatTensor extends FloatTensor { + + final float[] values; + + ArrayFloatTensor(float[] values) { + this.values = values; + } + + public static FloatTensor allocate(int... 
dims) {
+        int numberOfElements = FloatTensor.numberOfElements(dims);
+        return new ArrayFloatTensor(new float[numberOfElements]);
+    }
+
+    @Override
+    public int size() {
+        return values.length;
+    }
+
+    @Override
+    public float getFloat(int index) {
+        return values[index];
+    }
+
+    @Override
+    public void setFloat(int index, float value) {
+        values[index] = value;
+    }
+
+    @Override
+    public GGMLType type() {
+        return GGMLType.F32;
+    }
+
+    @Override
+    public FloatTensor fillInPlace(int thisOffset, int size, float value) {
+        Arrays.fill(values, thisOffset, thisOffset + size, value);
+        return this;
+    }
+
+    @Override
+    public FloatVector getFloatVector(VectorSpecies species, int index) {
+        if (!USE_VECTOR_API) {
+            throw new UnsupportedOperationException();
+        }
+        return FloatVector.fromArray(species, values, index);
+    }
+}
+
+final class RoPE {
+    public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int headSize, double theta,
+            boolean ropeScaling, float scaleFactor, float loFreqFactor, float hiFreqFactor, float oldContextLength) {
+        assert headSize % 2 == 0;
+        float[] cr = new float[contextLength * (headSize / 2)];
+        float[] ci = new float[contextLength * (headSize / 2)];
+        int n = 0;
+        for (int pos = 0; pos < contextLength; ++pos) {
+            for (int i = 0; i < headSize; i += 2) {
+                float freq = (float) (1.0 / Math.pow(theta, i / (double) headSize));
+                if (ropeScaling) {
+                    // Llama 3.1 scaling
+                    float loFreqWavelen = oldContextLength / loFreqFactor;
+                    float hiFreqWavelen = oldContextLength / hiFreqFactor;
+                    float wavelen = (float) (2.0 * Math.PI / freq);
+                    if (wavelen < hiFreqWavelen) {
+                        // high-frequency dimensions are left unscaled
+                    } else if (wavelen > loFreqWavelen) {
+                        freq = freq / scaleFactor;
+                    } else {
+                        float smooth = (oldContextLength / wavelen - loFreqFactor) / (hiFreqFactor - loFreqFactor);
+                        freq = (1.0f - smooth) * freq / scaleFactor + smooth * freq;
+                    }
+                }
+                float val = pos * freq;
+                cr[n] = (float) Math.cos(val);
+                ci[n] = (float) Math.sin(val);
+                n++;
+            }
+        }
+        assert contextLength * (headSize / 2) == n;
+        return new Pair<>(cr, ci);
+    }
+}
+
+record Vocabulary(String[] tokens, float[] scores, Map<String, Integer> tokenToIndex) {
+    public Vocabulary(String[] vocabulary, float[] scores) {
+        this(vocabulary, scores,
+                IntStream.range(0, vocabulary.length)
+                        .boxed()
+                        .collect(Collectors.toMap(i -> vocabulary[i], i -> i)));
+    }
+
+    public String get(int tokenIndex) {
+        return tokens[tokenIndex];
+    }
+
+    public OptionalInt getIndex(String token) {
+        Integer value = tokenToIndex.get(token);
+        return value != null ? OptionalInt.of(value) : OptionalInt.empty();
+    }
+
+    public int size() {
+        return tokens.length;
+    }
+}
+
+record CategoricalSampler(RandomGenerator rng) implements Sampler {
+
+    @Override
+    public int sampleToken(FloatTensor logits) {
+        // sample index from probabilities (they must sum to 1!)
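+        // e.g. for probabilities [0.1, 0.6, 0.3] and a draw r = 0.65: the running cdf passes 0.1,
+        // then reaches 0.7 > r, so token 1 is returned; tokens are drawn in proportion to their mass.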
+ float random0to1 = rng.nextFloat(1f); + float cdf = 0.0f; + for (int i = 0; i < logits.size(); i++) { + cdf += logits.getFloat(i); + if (random0to1 < cdf) { + return i; + } + } + return logits.size() - 1; // in case of rounding errors + } +} + +final class ToppSampler implements Sampler { + + final int[] indices; + final float topp; + final RandomGenerator rng; + + public ToppSampler(int maxNumberOfElements, float topp, RandomGenerator rng) { + this.indices = new int[maxNumberOfElements]; + this.topp = topp; + this.rng = rng; + } + + static void swap(int[] array, int from, int to) { + int tmp = array[from]; + array[from] = array[to]; + array[to] = tmp; + } + + static void siftDown(int[] array, int from, int n, Comparator comparator) { + int prev = from, next; + while ((next = 2 * prev + 1) < n) { + int r = 2 * prev + 2; + if (r < n && comparator.compare(array[r], array[next]) < 0) { + next = r; + } + if (comparator.compare(array[next], array[prev]) < 0) { + swap(array, prev, next); + prev = next; + } else { + break; + } + } + } + + @Override + public int sampleToken(FloatTensor logits) { + // top-p sampling (or "nucleus sampling") samples from the smallest set of + // tokens that exceed probability topp. This way we never sample tokens that + // have very low probabilities and are less likely to go "off the rails". + Comparator comparator = Comparator.comparingDouble(logits::getFloat).reversed(); + + int n = logits.size(); + int head = 0; + int tail = n - 1; + // values smaller than (1 - topp) / (n - 1) cannot be part of the result + // so for efficiency we crop these out as candidates before sorting + float cutoff = (1.0f - topp) / (n - 1); + for (int i = 0; i < indices.length; i++) { + if (logits.getFloat(i) >= cutoff) { + indices[head++] = i; + } else { + indices[tail--] = i; + } + } + + int n0 = head; + // build heap O(n0) + for (int i = n0 / 2 - 1; i >= 0; --i) { + siftDown(indices, i, n0, comparator); + } + + // truncate the list where cumulative probability of the largest k elements exceeds topp + // O(k lg n0) + float cumulativeProb = 0.0f; + int lastIndex = 0; + for (int i = n0 - 1; i >= 0; i--) { + swap(indices, 0, i); + cumulativeProb += logits.getFloat(indices[i]); + if (cumulativeProb > topp) { + lastIndex = i; + break; // we've exceeded topp by including lastIndex + } + siftDown(indices, 0, i - 1, comparator); + } + + // sample from the truncated list + float r = rng.nextFloat(1f) * cumulativeProb; + float cdf = 0.0f; + for (int i = n0 - 1; i >= lastIndex; i--) { + cdf += logits.getFloat(indices[i]); + if (r < cdf) { + return indices[i]; + } + } + + return indices[lastIndex]; // in case of rounding errors + } +} + +/** + * Support for AOT preloading of GGUF metadata with GraalVM's Native Image. + * + *
+ * <p>
+ * To preload a model at build time, pass {@code -Dllama.PreloadGGUF=/path/to/model.gguf} + * to the native-image builder command. At runtime, the preloaded model will be used + * iff the specified and preloaded file names (base name) match. + */ +final class AOT { + record PartialModel(String modelFileName, Llama model, long tensorDataOffset, + Map tensorInfos) { + } + + private static final PartialModel PRELOADED_GGUF = preLoadGGUF(System.getProperty("llama.PreloadGGUF")); + + private static PartialModel preLoadGGUF(String modelPath) { + if (modelPath == null || modelPath.isEmpty()) { + return null; + } + try { + Path path = Path.of(modelPath); + if (!Files.exists(path) || !Files.isRegularFile(path)) { + throw new IllegalArgumentException("Cannot pre-load model: " + path); + } + GGUF gguf = GGUF.loadModel(path); + try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) { + return new PartialModel( + path.getFileName().toString(), + ModelLoader.loadModel(fileChannel, gguf, Llama3.Options.DEFAULT_MAX_TOKENS, false), + gguf.getTensorDataOffset(), + gguf.getTensorInfos()); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * Tries to reuse a compatible AOT preloaded model. + * The file name (base name) must match with the preloaded file name. + * No checksum/hash is checked for performance reasons. + */ + public static Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException { + PartialModel preLoaded = AOT.PRELOADED_GGUF; + if (preLoaded == null) { + return null; // no pre-loaded model stored + } + String optionsModel = modelPath.getFileName().toString(); + String preLoadedModel = preLoaded.modelFileName(); + if (!Objects.equals(optionsModel, preLoadedModel)) { + // Preloaded and specified model file names didn't match. + return null; + } + Llama baseModel = preLoaded.model(); + try (var timer = Timer.log("Load tensors from pre-loaded model"); + var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) { + // Load only the tensors (mmap slices). 
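+            // Illustrative build invocation enabling this fast path (model path hypothetical):
+            //   native-image ... -Dllama.PreloadGGUF=llama3.2-1b-q4_0.gguf
+            // The header/metadata parsing then happens at build time and only this mmap remains at runtime.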
+ Map tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), + preLoaded.tensorInfos()); + Llama.Weights weights = ModelLoader.loadWeights(tensorEntries, baseModel.configuration()); + return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights); + } + } +} diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/ModelLoader.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/ModelLoader.java new file mode 100644 index 000000000..36e044115 --- /dev/null +++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/ModelLoader.java @@ -0,0 +1,159 @@ +package io.quarkiverse.langchain4j.llama3.copy; + +import java.io.IOException; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.function.IntFunction; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public final class ModelLoader { + private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2"; + + private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; + + private static Vocabulary loadVocabulary(Map metadata) { + String model = (String) metadata.get("tokenizer.ggml.model"); + if (!TOKENIZER_LLAMA_3_MODEL.equals(model)) { + throw new IllegalArgumentException("expected " + TOKENIZER_LLAMA_3_MODEL + " but found " + model); + } + String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens"); + return new Vocabulary(tokens, null); + } + + public static Llama loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException { + GGUF gguf = GGUF.loadModel(ggufPath); + FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ); + return loadModel(fileChannel, gguf, contextLength, loadWeights); + } + + public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) + throws IOException { + try (var ignored = Timer.log("Load LlaMa model")) { + Map metadata = gguf.getMetadata(); + Vocabulary vocabulary = loadVocabulary(metadata); + Tokenizer tokenizer = createTokenizer(metadata, vocabulary); + + Llama.Configuration config = new Llama.Configuration( + (int) metadata.get("llama.embedding_length"), + (int) metadata.get("llama.feed_forward_length"), + (int) metadata.get("llama.block_count"), + (int) metadata.get("llama.attention.head_count"), + + metadata.containsKey("llama.attention.head_count_kv") + ? 
(int) metadata.get("llama.attention.head_count_kv")
+                            : (int) metadata.get("llama.attention.head_count"),
+
+                    vocabulary.size(),
+                    (int) metadata.get("llama.context_length"),
+                    (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
+                    (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)).withContextLength(contextLength);
+
+            Llama.Weights weights = null;
+            if (loadWeights) {
+                Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(),
+                        gguf.getTensorInfos());
+                weights = loadWeights(tensorEntries, config);
+            }
+            return new Llama(config, tokenizer, weights);
+        }
+    }
+
+    static Llama.Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Llama.Configuration config) {
+        boolean ropeScaling = tensorEntries.containsKey("rope_freqs");
+        float scaleFactor = 8;
+        float loFreqFactor = 1;
+        float hiFreqFactor = 3;
+        int oldContextLength = 8192;
+        Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, config.headSize, config.ropeTheta,
+                ropeScaling, scaleFactor, loFreqFactor, hiFreqFactor, oldContextLength);
+        float[] ropeFreqsReal = ropeFreqs.first();
+        float[] ropeFreqsImag = ropeFreqs.second();
+
+        GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
+        Llama.Weights qw = new Llama.Weights(
+                loadQuantized(tokenEmbeddings),
+                loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+                loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+                loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+                loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+                loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+                loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+                loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
+                loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
+                loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
+                toFloatBuffer(tensorEntries.get("output_norm.weight")),
+                FloatBuffer.wrap(ropeFreqsReal),
+                FloatBuffer.wrap(ropeFreqsImag),
+                // If "output.weight" is not present, then the embedding weights are tied/shared with the decoder.
+                // This is commonly referred to as "tie word embeddings".
+                loadQuantized(tensorEntries.getOrDefault("output.weight", tokenEmbeddings)));
+
+        return qw;
+    }
+
+    private static Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+        String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
+        List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines)
+                .map(line -> line.split(" "))
+                .map(parts -> new Pair<>(
+                        vocabulary.getIndex(parts[0]).orElseThrow(),
+                        vocabulary.getIndex(parts[1]).orElseThrow()))
+                .toList();
+
+        int allTokens = vocabulary.size();
+        int baseTokens = 128000; // assume all tokens after the base ones are special.
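+        // e.g. a 128256-entry Llama 3 vocabulary yields 256 specials, mapped back to their original
+        // ids: "<|begin_of_text|>" -> 128000, ..., "<|eot_id|>" -> 128009 (ids shown for illustration).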
+        int reservedSpecialTokens = allTokens - baseTokens;
+        List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
+
+        assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
+
+        Map<String, Integer> specialTokens = IntStream.range(0, specialTokensList.size())
+                .boxed()
+                .collect(Collectors.toMap(
+                        i -> specialTokensList.get(i),
+                        i -> baseTokens + i));
+
+        return new Tokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens);
+    }
+
+    public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
+        GGMLType ggmlType = entry.ggmlType();
+        return switch (ggmlType) {
+            //case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
+        };
+    }
+
+    public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+        FloatTensor[] array = new FloatTensor[size];
+        for (int i = 0; i < size; i++) {
+            array[i] = loadQuantized(getTensorEntry.apply(i));
+        }
+        return array;
+    }
+
+    public static FloatBuffer[] loadArrayOfFloatBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+        FloatBuffer[] array = new FloatBuffer[size];
+        for (int i = 0; i < size; i++) {
+            array[i] = toFloatBuffer(getTensorEntry.apply(i));
+        }
+        return array;
+    }
+
+    public static FloatBuffer toFloatBuffer(GGMLTensorEntry tensorEntry) {
+        GGMLType ggmlType = tensorEntry.ggmlType();
+        return switch (ggmlType) {
+            case F32 -> tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
+            default -> throw new UnsupportedOperationException("Conversion to " + ggmlType);
+        };
+    }
+}
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Sampler.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Sampler.java
new file mode 100644
index 000000000..a4c646d0a
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Sampler.java
@@ -0,0 +1,8 @@
+package io.quarkiverse.langchain4j.llama3.copy;
+
+@FunctionalInterface
+public interface Sampler {
+    int sampleToken(FloatTensor logits);
+
+    Sampler ARGMAX = FloatTensor::argmax;
+}
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Tokenizer.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Tokenizer.java
new file mode 100644
index 000000000..51fbd4af8
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/Tokenizer.java
@@ -0,0 +1,268 @@
+package io.quarkiverse.langchain4j.llama3.copy;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HexFormat;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Byte Pair Encoding tokenizer.
+ * <p>
+ * Based on minbpe; algorithmically it follows along the
+ * GPT 2 tokenizer.
+ */
+public class Tokenizer {
+    private final Pattern compiledPattern;
+    private final Vocabulary vocabulary;
+    private final Map<Pair<Integer, Integer>, Integer> merges;
+    private final Map<String, Integer> specialTokens;
+
+    public String regexPattern() {
+        if (compiledPattern == null) {
+            return null;
+        }
+        return compiledPattern.pattern();
+    }
+
+    public Map<String, Integer> getSpecialTokens() {
+        return specialTokens;
+    }
+
+    public boolean isSpecialToken(int tokenIndex) {
+        return specialTokens.containsValue(tokenIndex);
+    }
+
+    public Tokenizer(Vocabulary vocabulary, List<Pair<Integer, Integer>> merges, String regexPattern,
+            Map<String, Integer> specialTokens) {
+        this.vocabulary = vocabulary;
+        this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null;
+        this.specialTokens = new HashMap<>(specialTokens);
+        this.merges = new HashMap<>();
+        for (Pair<Integer, Integer> pair : merges) {
+            int firstIndex = pair.first();
+            int secondIndex = pair.second();
+            int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow();
+            this.merges.put(pair, mergeIndex);
+        }
+    }
+
+    private int[] encodeImpl(String text) {
+        return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
+    }
+
+    /**
+     * Unlike {@link #encodeOrdinary(String)}, this function handles special tokens.
+     * allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens;
+     * if "none_raise", then an error is raised if any special token is encountered in text.
+     * This is the default tiktoken behavior right now as well;
+     * any other behavior is either annoying, or a major footgun.
+     */
+    List<Integer> encode(String text, Set<String> allowedSpecial) {
+        // decode the user desire w.r.t. handling of special tokens
+        Set<String> special = allowedSpecial;
+        assert getSpecialTokens().keySet().containsAll(special);
+        if (special.isEmpty()) {
+            // shortcut: if no special tokens, just use the ordinary encoding
+            return encodeOrdinary(text);
+        }
+
+        // otherwise, we have to be careful with potential special tokens in text
+        // we handle special tokens by splitting the text
+        // based on the occurrence of any exact match with any of the special tokens
+        // we can use re.split for this. note that surrounding the pattern with ()
+        // makes it into a capturing group, so the special tokens will be included
+        String specialPattern = special
+                .stream()
+                .map(Pattern::quote)
+                .collect(Collectors.joining("|", "(", ")"));
+
+        String[] specialChunks = text.split(specialPattern);
+        // now all the special characters are separated from the rest of the text
+        // all chunks of text are encoded separately, then results are joined
+        List<Integer> ids = new ArrayList<>();
+        for (String part : specialChunks) {
+            if (special.contains(part)) {
+                // this is a special token, encode it separately as a special case
+                ids.add(getSpecialTokens().get(part));
+            } else {
+                // this is an ordinary sequence, encode it normally
+                ids.addAll(encodeOrdinary(part));
+            }
+        }
+        return ids;
+    }
+
+    private static List<String> findAll(Pattern pattern, String text) {
+        List<String> allMatches = new ArrayList<>();
+        Matcher matcher = pattern.matcher(text);
+        while (matcher.find()) {
+            allMatches.add(matcher.group());
+        }
+        return allMatches;
+    }
+
+    /**
+     * Encoding that ignores any special tokens.
+     */
+    public List<Integer> encodeOrdinary(String text) {
+        // split the text into chunks by the categories defined in the regex pattern
+        List<String> textChunks = findAll(compiledPattern, text);
+        // all chunks of text are encoded separately, then the results are joined
+        List<Integer> ids = new ArrayList<>();
+        for (String chunk : textChunks) {
+            List<Integer> chunkIds = encodeChunk(chunk);
+            ids.addAll(chunkIds);
+        }
+        return ids;
+    }
+
+    private Map<Pair<Integer, Integer>, Integer> getStats(List<Integer> ids) {
+        Map<Pair<Integer, Integer>, Integer> map = new HashMap<>();
+        for (int i = 0; i + 1 < ids.size(); i++) {
+            Pair<Integer, Integer> key = new Pair<>(ids.get(i), ids.get(i + 1));
+            map.put(key, map.getOrDefault(key, 0) + 1);
+        }
+        return map;
+    }
+
+    private List<Integer> encodeChunk(String chunk) {
+        // return the token ids
+        // let's begin. first, map each byte-encoder char to its token id in range 0..255
+        List<Integer> ids = new ArrayList<>();
+        for (int b : chunk.toCharArray()) {
+            int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow();
+            ids.add(tokenIndex);
+        }
+
+        while (ids.size() >= 2) {
+            // find the pair with the lowest merge index
+            Map<Pair<Integer, Integer>, Integer> stats = getStats(ids);
+            Pair<Integer, Integer> pair = stats.keySet().stream()
+                    .min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow();
+            // subtle: if there are no more merges available, the lookup will
+            // yield Integer.MAX_VALUE for every single pair, and the min will be
+            // just the first pair in the list, arbitrarily.
+            // we can detect this terminating case by a membership check
+            if (!this.merges.containsKey(pair)) {
+                break; // nothing else can be merged anymore
+            }
+            // otherwise let's merge the best pair (lowest merge index)
+            int idx = this.merges.get(pair);
+            ids = merge(ids, pair, idx);
+        }
+        return ids;
+    }
+
+    private static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx) {
+        List<Integer> newids = new ArrayList<>();
+        int i = 0;
+        while (i < ids.size()) {
+            // if not at the very last position AND the pair matches, replace it
+            if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
+                newids.add(idx);
+                i += 2;
+            } else {
+                newids.add(ids.get(i));
+                i += 1;
+            }
+        }
+        return newids;
+    }
+
+    public String decodeImpl(List<Integer> tokens) {
+        StringBuilder sb = new StringBuilder();
+        for (int token : tokens) {
+            String tokenString = vocabulary.get(token);
+            sb.append(tokenString);
+        }
+        return sb.toString();
+    }
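+    // Reviewer's illustration (not in the upstream llama3.java sources): the byte
+    // encoder built below maps the 188 printable byte values to themselves and shifts
+    // the remaining 68 byte values into the range 256+n. For example 'h' (0x68) stays
+    // 'h', while the space byte 0x20 becomes U+0120 ('Ġ'), which is why GPT-2-style
+    // vocabularies display 'Ġ' where a leading space was consumed.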
+
+    /**
+     * Returns a mapping between utf-8 bytes and unicode strings.
+     * The reversible bpe codes work on unicode strings.
+     * This means you need a large number of unicode characters in your vocab if you want to avoid UNKs.
+     * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+     * This is a significant percentage of your normal, say, 32K bpe vocab.
+     * To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
+     * and we avoid mapping to whitespace/control characters the bpe code barfs on.
+     */
+    private static Map<Integer, Integer> bytesToUnicode() {
+        List<Integer> bs = new ArrayList<>();
+        IntStream.rangeClosed('!', '~').forEach(bs::add);
+        IntStream.rangeClosed('¡', '¬').forEach(bs::add);
+        IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
+
+        List<Integer> cs = new ArrayList<>(bs);
+        int n = 0;
+        for (int b = 0; b < 256; ++b) {
+            if (!bs.contains(b)) {
+                bs.add(b);
+                cs.add(256 + n);
+                n += 1;
+            }
+        }
+
+        // return dict(zip(bs, cs))
+        return IntStream.range(0, bs.size())
+                .boxed()
+                .collect(Collectors.toMap(bs::get, cs::get));
+    }
+
+    static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
+    static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet()
+            .stream()
+            .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+
+    public int[] encode(String text) {
+        StringBuilder sb = new StringBuilder();
+        byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
+        for (byte b : bytes) {
+            sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
+        }
+        return encodeImpl(sb.toString());
+    }
+
+    public static String replaceControlCharacters(int[] codePoints) {
+        // we don't want to print control characters,
+        // which distort the output (e.g. \n or much worse)
+        // https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
+        // http://www.unicode.org/reports/tr44/#GC_Values_Table
+        StringBuilder chars = new StringBuilder();
+        for (int cp : codePoints) {
+            if (Character.getType(cp) == Character.CONTROL && cp != '\n') {
+                chars.append("\\u").append(HexFormat.of().toHexDigits(cp, 4)); // escape
+            } else {
+                chars.appendCodePoint(cp); // this character is ok
+            }
+        }
+        return chars.toString();
+    }
+
+    public static String replaceControlCharacters(String str) {
+        return replaceControlCharacters(str.codePoints().toArray());
+    }
+
+    public List<Integer> encodeAsList(String text) {
+        return Arrays.stream(encode(text)).boxed().toList();
+    }
+
+    public String decode(List<Integer> tokens) {
+        String decoded = decodeImpl(tokens);
+        int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray();
+        byte[] rawBytes = new byte[decodedBytesAsInts.length];
+        for (int i = 0; i < decodedBytesAsInts.length; i++) {
+            rawBytes[i] = (byte) decodedBytesAsInts[i];
+        }
+        return new String(rawBytes, StandardCharsets.UTF_8);
+    }
+}
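Taken together, encode() maps text to a byte-encoder string and then to BPE token ids, and decode() reverses the trip. For illustration, not part of this patch: the toy, self-contained program below mirrors the semantics of the merge loop above on plain strings instead of vocabulary indices; the vocabulary and merge ranks are made up for the example.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;

    // Toy walkthrough of the BPE merge loop: repeatedly merge the adjacent pair
    // with the lowest rank until no ranked pair remains.
    public class BpeMergeDemo {
        record Pair(String first, String second) {
        }

        public static void main(String[] args) {
            // lower rank = merged earlier, mirroring "lowest merge index" above
            Map<Pair, Integer> ranks = Map.of(
                    new Pair("l", "l"), 0,
                    new Pair("h", "e"), 1,
                    new Pair("he", "ll"), 2);

            List<String> tokens = new ArrayList<>(List.of("h", "e", "l", "l", "o"));
            while (tokens.size() >= 2) {
                Pair best = null;
                int bestRank = Integer.MAX_VALUE;
                for (int i = 0; i + 1 < tokens.size(); i++) {
                    Pair p = new Pair(tokens.get(i), tokens.get(i + 1));
                    int rank = ranks.getOrDefault(p, Integer.MAX_VALUE);
                    if (rank < bestRank) {
                        bestRank = rank;
                        best = p;
                    }
                }
                if (best == null || bestRank == Integer.MAX_VALUE) {
                    break; // nothing mergeable left
                }
                // merge every occurrence of the best pair, like merge(ids, pair, idx) above
                List<String> merged = new ArrayList<>();
                for (int i = 0; i < tokens.size();) {
                    if (i + 1 < tokens.size() && tokens.get(i).equals(best.first())
                            && tokens.get(i + 1).equals(best.second())) {
                        merged.add(best.first() + best.second());
                        i += 2;
                    } else {
                        merged.add(tokens.get(i));
                        i += 1;
                    }
                }
                tokens = merged;
            }
            System.out.println(tokens); // prints [hell, o]
        }
    }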
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/Llama3Recorder.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/Llama3Recorder.java
new file mode 100644
index 000000000..d831418c4
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/Llama3Recorder.java
@@ -0,0 +1,80 @@
+package io.quarkiverse.langchain4j.llama3.runtime;
+
+import java.util.function.Supplier;
+
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.DisabledChatLanguageModel;
+import io.quarkiverse.langchain4j.llama3.Llama3ChatModel;
+import io.quarkiverse.langchain4j.llama3.runtime.config.ChatModelConfig;
+import io.quarkiverse.langchain4j.llama3.runtime.config.LangChain4jLlama3FixedRuntimeConfig;
+import io.quarkiverse.langchain4j.llama3.runtime.config.LangChain4jLlama3RuntimeConfig;
+import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
+import io.quarkus.runtime.annotations.Recorder;
+
+@Recorder
+public class Llama3Recorder {
+
+    public Supplier<ChatLanguageModel> chatModel(LangChain4jLlama3RuntimeConfig runtimeConfig,
+            LangChain4jLlama3FixedRuntimeConfig fixedRuntimeConfig,
+            String configName) {
+        LangChain4jLlama3RuntimeConfig.Llama3Config llama3Config = correspondingLlama3Config(runtimeConfig, configName);
+        LangChain4jLlama3FixedRuntimeConfig.Llama3Config llama3FixedRuntimeConfig = correspondingLlama3FixedRuntimeConfig(
+                fixedRuntimeConfig, configName);
+
+        if (llama3Config.enableIntegration()) {
+            ChatModelConfig chatModelConfig = llama3Config.chatModel();
+
+            String modelName = llama3FixedRuntimeConfig.chatModel().modelName();
+            var builder = Llama3ChatModel.builder()
+                    .modelName(modelName)
+                    .quantization(llama3FixedRuntimeConfig.chatModel().quantization())
+                    .modelCachePath(fixedRuntimeConfig.modelsPath());
+
+            if (chatModelConfig.temperature().isPresent()) {
+                builder.temperature((float) chatModelConfig.temperature().getAsDouble());
+            }
+            if (chatModelConfig.maxTokens().isPresent()) {
+                builder.maxTokens(chatModelConfig.maxTokens().getAsInt());
+            }
+
+            return new Supplier<>() {
+                @Override
+                public ChatLanguageModel get() {
+                    return builder.build();
+                }
+            };
+        } else {
+            return new Supplier<>() {
+                @Override
+                public ChatLanguageModel get() {
+                    return new DisabledChatLanguageModel();
+                }
+            };
+        }
+    }
+
+    private LangChain4jLlama3RuntimeConfig.Llama3Config correspondingLlama3Config(
+            LangChain4jLlama3RuntimeConfig runtimeConfig,
+            String configName) {
+        LangChain4jLlama3RuntimeConfig.Llama3Config llama3Config;
+        if (NamedConfigUtil.isDefault(configName)) {
+            llama3Config = runtimeConfig.defaultConfig();
+        } else {
+            llama3Config = runtimeConfig.namedConfig().get(configName);
+        }
+        return llama3Config;
+    }
+
+    private LangChain4jLlama3FixedRuntimeConfig.Llama3Config correspondingLlama3FixedRuntimeConfig(
+            LangChain4jLlama3FixedRuntimeConfig runtimeConfig,
+            String configName) {
+        LangChain4jLlama3FixedRuntimeConfig.Llama3Config llama3Config;
+        if (NamedConfigUtil.isDefault(configName)) {
+            llama3Config = runtimeConfig.defaultConfig();
+        } else {
+            llama3Config = runtimeConfig.namedConfig().get(configName);
+        }
+        return llama3Config;
+    }
+
+}
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ChatModelConfig.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ChatModelConfig.java
new file mode 100644
index 000000000..6953c3dd7
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ChatModelConfig.java
@@ -0,0 +1,21 @@
+package io.quarkiverse.langchain4j.llama3.runtime.config;
+
+import java.util.OptionalDouble;
+import java.util.OptionalInt;
+
+import io.quarkus.runtime.annotations.ConfigGroup;
+
+@ConfigGroup
+public interface ChatModelConfig {
+
+    /**
+     * The temperature to use when sampling the next token. Higher values make the output
+     * more random, lower values make it more deterministic.
+     */
+    OptionalDouble temperature();
+
+    /**
+     * The maximum number of tokens to generate for the response.
+     */
+    OptionalInt maxTokens();
+
+}
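For reference, these options surface as standard Quarkus configuration. The property names below are derived from the @ConfigMapping prefix plus the kebab-cased method names; the values are purely illustrative:

    # illustrative application.properties values for the runtime options above
    quarkus.langchain4j.llama3.chat-model.temperature=0.7
    quarkus.langchain4j.llama3.chat-model.max-tokens=256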
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ChatModelFixedRuntimeConfig.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ChatModelFixedRuntimeConfig.java
new file mode 100644
index 000000000..d3ad0855a
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ChatModelFixedRuntimeConfig.java
@@ -0,0 +1,20 @@
+package io.quarkiverse.langchain4j.llama3.runtime.config;
+
+import io.quarkus.runtime.annotations.ConfigGroup;
+import io.smallrye.config.WithDefault;
+
+@ConfigGroup
+public interface ChatModelFixedRuntimeConfig {
+
+    /**
+     * The name of the model to use
+     */
+    @WithDefault("mukel/Llama-3.2-3B-Instruct-GGUF")
+    String modelName();
+
+    /**
+     * The quantization of the model to use (only {@code Q4_0} and {@code Q8_0} GGUF tensors are supported)
+     */
+    @WithDefault("Q4_0")
+    String quantization();
+}
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/LangChain4jLlama3FixedRuntimeConfig.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/LangChain4jLlama3FixedRuntimeConfig.java
new file mode 100644
index 000000000..6cb213c16
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/LangChain4jLlama3FixedRuntimeConfig.java
@@ -0,0 +1,52 @@
+package io.quarkiverse.langchain4j.llama3.runtime.config;
+
+import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_AND_RUN_TIME_FIXED;
+
+import java.nio.file.Path;
+import java.util.Map;
+import java.util.Optional;
+
+import io.quarkus.runtime.annotations.ConfigDocDefault;
+import io.quarkus.runtime.annotations.ConfigDocMapKey;
+import io.quarkus.runtime.annotations.ConfigDocSection;
+import io.quarkus.runtime.annotations.ConfigGroup;
+import io.quarkus.runtime.annotations.ConfigRoot;
+import io.smallrye.config.ConfigMapping;
+import io.smallrye.config.WithDefaults;
+import io.smallrye.config.WithParentName;
+
+@ConfigRoot(phase = BUILD_AND_RUN_TIME_FIXED)
+@ConfigMapping(prefix = "quarkus.langchain4j.llama3")
+public interface LangChain4jLlama3FixedRuntimeConfig {
+
+    /**
+     * Default model config.
+     */
+    @WithParentName
+    Llama3Config defaultConfig();
+
+    /**
+     * Named model config.
+     */
+    @ConfigDocSection
+    @ConfigDocMapKey("model-name")
+    @WithParentName
+    @WithDefaults
+    Map<String, Llama3Config> namedConfig();
+
+    /**
+     * Location on the file-system which serves as a cache for the models
+     */
+    @ConfigDocDefault("${user.home}/.llama3/models")
+    Optional<Path> modelsPath();
+
+    @ConfigGroup
+    interface Llama3Config {
+
+        /**
+         * Chat model related settings
+         */
+        ChatModelFixedRuntimeConfig chatModel();
+    }
+}
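As with the runtime options, the build-time options map onto properties derived from the prefix and kebab-cased method names. A sketch of a default model plus a named one (the name "small" and the 1B model id are hypothetical values, not defaults of this patch):

    # illustrative build-time configuration
    quarkus.langchain4j.llama3.chat-model.model-name=mukel/Llama-3.2-3B-Instruct-GGUF
    quarkus.langchain4j.llama3.chat-model.quantization=Q4_0
    # named model config, keyed by model name
    quarkus.langchain4j.llama3.small.chat-model.model-name=mukel/Llama-3.2-1B-Instruct-GGUF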
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/LangChain4jLlama3RuntimeConfig.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/LangChain4jLlama3RuntimeConfig.java
new file mode 100644
index 000000000..b1f950549
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/LangChain4jLlama3RuntimeConfig.java
@@ -0,0 +1,50 @@
+package io.quarkiverse.langchain4j.llama3.runtime.config;
+
+import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME;
+
+import java.util.Map;
+
+import io.quarkus.runtime.annotations.ConfigDocMapKey;
+import io.quarkus.runtime.annotations.ConfigDocSection;
+import io.quarkus.runtime.annotations.ConfigGroup;
+import io.quarkus.runtime.annotations.ConfigRoot;
+import io.smallrye.config.ConfigMapping;
+import io.smallrye.config.WithDefault;
+import io.smallrye.config.WithDefaults;
+import io.smallrye.config.WithParentName;
+
+@ConfigRoot(phase = RUN_TIME)
+@ConfigMapping(prefix = "quarkus.langchain4j.llama3")
+public interface LangChain4jLlama3RuntimeConfig {
+
+    /**
+     * Default model config.
+     */
+    @WithParentName
+    Llama3Config defaultConfig();
+
+    /**
+     * Named model config.
+     */
+    @ConfigDocSection
+    @ConfigDocMapKey("model-name")
+    @WithParentName
+    @WithDefaults
+    Map<String, Llama3Config> namedConfig();
+
+    @ConfigGroup
+    interface Llama3Config {
+
+        /**
+         * Chat model related settings
+         */
+        ChatModelConfig chatModel();
+
+        /**
+         * Whether to enable the integration. Set to {@code false} to disable
+         * all requests.
+         */
+        @WithDefault("true")
+        Boolean enableIntegration();
+    }
+}
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ModelsPathConfigSource.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ModelsPathConfigSource.java
new file mode 100644
index 000000000..b472c1530
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/config/ModelsPathConfigSource.java
@@ -0,0 +1,77 @@
+package io.quarkiverse.langchain4j.llama3.runtime.config;
+
+import java.io.File;
+import java.net.URLDecoder;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Set;
+
+import org.eclipse.microprofile.config.spi.ConfigSource;
+
+import io.quarkus.runtime.LaunchMode;
+
+/**
+ * Sets {@code quarkus.langchain4j.llama3.models-path} to {@code quarkus-app/llama3} if it exists
+ */
+public class ModelsPathConfigSource implements ConfigSource {
+
+    private static final String SENTINEL = "#sen-val#";
+    public static final String SUPPORTED_PROPERTY_NAME = "quarkus.langchain4j.llama3.models-path";
+
+    private volatile String value = null;
+
+    @Override
+    public String getName() {
+        return "ModelsPathConfigSource";
+    }
+
+    @Override
+    public int getOrdinal() {
+        // make it overridable by users
+        return DEFAULT_ORDINAL;
+    }
+
+    @Override
+    public Set<String> getPropertyNames() {
+        return Set.of(SUPPORTED_PROPERTY_NAME);
+    }
+
+    @Override
+    public String getValue(String name) {
+        if (!SUPPORTED_PROPERTY_NAME.equals(name)) {
+            return null;
+        }
+        if (LaunchMode.current() != LaunchMode.NORMAL) {
+            return null;
+        }
+        String result = value;
+        if (result == null) {
+            result = value = produceValue();
+        }
+        if (result.equals(SENTINEL)) {
+            return null;
+        }
+        return result;
+    }
+
+    private String produceValue() {
+        try {
+            Class<?> clazz = Class.forName("io.quarkus.bootstrap.runner.QuarkusEntryPoint", false, Thread.currentThread()
+                    .getContextClassLoader());
+            String path = clazz.getProtectionDomain().getCodeSource().getLocation().getPath();
+            if (path == null) {
+                return SENTINEL;
+            }
+            String decodedPath = URLDecoder.decode(path, StandardCharsets.UTF_8);
+            Path appRoot = new File(decodedPath).toPath().getParent().getParent().getParent();
+            Path modelsRoot = appRoot.resolve("llama3");
+            if (Files.isDirectory(modelsRoot)) {
+                return modelsRoot.toAbsolutePath().toString();
+            }
+        } catch (ClassNotFoundException ignored) {
+
+        }
+        return SENTINEL;
+    }
+}
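ModelsPathConfigSource only activates for a production run: it locates the application root by walking three directories up from the jar that hosts QuarkusEntryPoint. Assuming the usual fast-jar layout (the directory names below are the standard defaults; the llama3 folder is this extension's convention), the resolution would look roughly like:

    target/quarkus-app/
        quarkus-run.jar
        app/
        lib/
            boot/<jar containing QuarkusEntryPoint>   <- three getParent() calls lead back to quarkus-app/
        llama3/                                       <- if present, published as quarkus.langchain4j.llama3.models-path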
diff --git a/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/graal/Llama3Feature.java b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/graal/Llama3Feature.java
new file mode 100644
index 000000000..fc571748b
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/runtime/graal/Llama3Feature.java
@@ -0,0 +1,19 @@
+package io.quarkiverse.langchain4j.llama3.runtime.graal;
+
+import org.graalvm.nativeimage.hosted.Feature;
+import org.graalvm.nativeimage.hosted.RuntimeClassInitialization;
+
+public class Llama3Feature implements Feature {
+
+    @Override
+    public void beforeAnalysis(BeforeAnalysisAccess access) {
+        try {
+            // needed to make the native image run at acceptable speed
+            RuntimeClassInitialization.initializeAtBuildTime(
+                    Class.forName("io.quarkiverse.langchain4j.llama3.copy.FloatTensor", false, Thread.currentThread()
+                            .getContextClassLoader()));
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/model-providers/llama3-java/runtime/src/main/resources/META-INF/beans.xml b/model-providers/llama3-java/runtime/src/main/resources/META-INF/beans.xml
new file mode 100644
index 000000000..e69de29bb
diff --git a/model-providers/llama3-java/runtime/src/main/resources/META-INF/quarkus-extension.yaml b/model-providers/llama3-java/runtime/src/main/resources/META-INF/quarkus-extension.yaml
new file mode 100644
index 000000000..99846d1b4
--- /dev/null
+++ b/model-providers/llama3-java/runtime/src/main/resources/META-INF/quarkus-extension.yaml
@@ -0,0 +1,12 @@
+name: LangChain4j Llama3 Java
+artifact: ${project.groupId}:${project.artifactId}:${project.version}
+description: Provides integration of Quarkus LangChain4j with Llama3 Java
+metadata:
+  keywords:
+    - ai
+    - langchain4j
+    - llama3-java
+  guide: "https://docs.quarkiverse.io/quarkus-langchain4j/dev/index.html"
+  categories:
+    - "ai"
+  status: "experimental"
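For reviewers, a sketch of how the pieces compose outside the CDI wiring. The builder calls mirror the ones Llama3Recorder makes; the model name, prompt, and standalone main are illustrative and not part of this patch:

    import java.util.Optional;

    import dev.langchain4j.model.chat.ChatLanguageModel;
    import io.quarkiverse.langchain4j.llama3.Llama3ChatModel;

    public class Llama3Smoke {
        public static void main(String[] args) {
            // Mirrors the recorder: model name/quantization are fixed at build time,
            // temperature/max-tokens at run time. Values below are illustrative.
            ChatLanguageModel model = Llama3ChatModel.builder()
                    .modelName("mukel/Llama-3.2-3B-Instruct-GGUF")
                    .quantization("Q4_0")
                    .modelCachePath(Optional.empty()) // fall back to the default cache location
                    .temperature(0.7f)
                    .maxTokens(256)
                    .build();
            System.out.println(model.generate("Say hello from Llama3.java"));
        }
    }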