diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java index 1bdafc4..a5d4c51 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java @@ -20,7 +20,7 @@ import com.diogonunes.jcolor.AnsiFormat; import com.diogonunes.jcolor.Attribute; import com.github.tjake.jlama.model.AbstractModel; -import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport; +import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import java.io.PrintWriter; import java.util.Optional; import java.util.Scanner; @@ -84,7 +84,7 @@ public void run() { break; } - PromptSupport.Builder builder = promptSupport.newBuilder(); + PromptSupport.Builder builder = promptSupport.builder(); if (first && systemPrompt != null) { builder.addSystemMessage(systemPrompt); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index a381c09..5cdeb30 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -22,7 +22,7 @@ import com.github.tjake.jlama.safetensors.Config; import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.safetensors.WeightLoader; -import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport; +import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import com.github.tjake.jlama.tensor.AbstractTensor; import com.github.tjake.jlama.tensor.KvBufferCache; @@ -330,7 +330,7 @@ public Response generate( promptLength = encoded.length; if (useEOS) { - promptTokens[promptTokens.length - 1] = c.eosToken; // Add EOS + promptTokens[promptTokens.length - 1] = 
getConfig().eosTokens.getLast(); // Add EOS promptLength++; } @@ -371,14 +371,14 @@ public Response generate( if (logger.isTraceEnabled()) logger.trace("Sampled token {} with temperature {}", next, temperature); output.close(); + kvmem.setMetadata(KvBufferCache.TOKEN_COUNT, i); + // Model may tell us it's done - if (next == c.eosToken) { + if (c.eosTokens.contains(next)) { reason = FinishReason.STOP_TOKEN; break; } - kvmem.setMetadata(KvBufferCache.TOKEN_COUNT, i); - try { String c = tokenizer.decode(next); genMsPerToken = (System.currentTimeMillis() - start) / (float) (tokensGenerated); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java index 9ff81bd..d98b9bb 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/bert/BertConfig.java @@ -20,6 +20,8 @@ import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; +import java.util.List; + public class BertConfig extends Config { @JsonCreator public BertConfig( @@ -41,7 +43,7 @@ public BertConfig( layerNormEps, vocabularySize, 0, - 0, + List.of(0), activationFunction, null, null); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java index 3dac12f..6cd39d8 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java @@ -27,6 +27,7 @@ public interface Generator { enum FinishReason { MAX_TOKENS, STOP_TOKEN, + TOOL_CALL, ERROR } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaConfig.java index 5171ac1..4cecc94 100644 --- 
a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma/GemmaConfig.java @@ -19,6 +19,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; + +import java.util.List; import java.util.Map; public class GemmaConfig extends Config { @@ -47,7 +49,7 @@ public GemmaConfig( layerNormEps, vocabularySize, bosToken, - eosToken, + List.of(eosToken), activationFunction, ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta, ropeScaling == null ? 1.0 : Double.parseDouble(ropeScaling.get("factor"))); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java index ce65e65..d65fdb9 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gpt2/GPT2Config.java @@ -20,6 +20,8 @@ import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; +import java.util.List; + public class GPT2Config extends Config { @JsonCreator @@ -42,7 +44,7 @@ public GPT2Config( layerNormEps, vocabularySize, bosToken, - eosToken, + List.of(eosToken), ActivationFunction.Type.GELU, null, null); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java index a5a191e..377feaa 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaConfig.java @@ -50,7 +50,7 @@ public LlamaConfig( layerNormEps, vocabularySize, bosToken, - eosToken instanceof List ? ((List)eosToken).get(((List)eosToken).size() - 1) : (Integer) eosToken, //for llama3.1 + eosToken instanceof List ? 
(List) eosToken : List.of((Integer) eosToken), activationFunction, ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta, ropeScaling == null || !("linear".equals(ropeScaling.get("rope_type"))) ? 1.0 : Double.parseDouble(ropeScaling.get("factor"))); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java index 3dc80f2..cd6d9f3 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/mistral/MistralConfig.java @@ -20,6 +20,8 @@ import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; +import java.util.List; + public class MistralConfig extends Config { @JsonCreator public MistralConfig( @@ -46,7 +48,7 @@ public MistralConfig( layerNormEps, vocabularySize, bosToken, - eosToken, + List.of(eosToken), activationFunction, ropeTheta, 1.0, diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralConfig.java index d5d7e26..fa49faf 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/mixtral/MixtralConfig.java @@ -20,6 +20,8 @@ import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.safetensors.Config; +import java.util.List; + public class MixtralConfig extends Config { public final int numberOfExperts; @@ -52,7 +54,7 @@ public MixtralConfig( layerNormEps, vocabularySize, bosToken, - eosToken, + List.of(eosToken), activationFunction, ropeTheta, 1.0); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java index db9833b..5ce1193 100644 --- 
a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java @@ -22,6 +22,7 @@ import com.google.common.base.Preconditions; import com.google.common.io.Files; import java.io.File; +import java.util.List; import java.util.Optional; public class Config { @@ -39,7 +40,7 @@ public class Config { public final float layerNormEps; public final int vocabularySize; public final int bosToken; - public final int eosToken; + public final List eosTokens; public final Optional ropeFreqs; private volatile Optional> offset; private volatile File workingDirectory; @@ -69,7 +70,7 @@ public Config( float layerNormEps, int vocabularySize, int bosToken, - int eosToken, + List eosToken, ActivationFunction.Type activationFunction, Double ropeFreqsTheta, Double ropeScalingFactor) { @@ -100,7 +101,7 @@ public Config( float layerNormEps, int vocabularySize, int bosToken, - int eosToken, + List eosTokens, ActivationFunction.Type activationFunction, Double ropeFreqsTheta, Double ropeScalingFactor, @@ -114,7 +115,7 @@ public Config( this.layerNormEps = layerNormEps; this.vocabularySize = vocabularySize; this.bosToken = bosToken; - this.eosToken = eosToken; + this.eosTokens = eosTokens; this.tensorCache = TensorCache.instance; this.headSize = headSize; this.headGroupSize = numberOfHeads / numberOfKeyValueHeads; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Function.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Function.java new file mode 100644 index 0000000..88d8afa --- /dev/null +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Function.java @@ -0,0 +1,76 @@ +package com.github.tjake.jlama.safetensors.prompt; + +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.google.common.base.Preconditions; + +import java.util.Map; + +/** + * FunctionObject + */ +@JsonPropertyOrder({Function.JSON_PROPERTY_NAME, 
Function.JSON_PROPERTY_DESCRIPTION, Function.JSON_PROPERTY_PARAMETERS}) +public class Function { + public static final String JSON_PROPERTY_NAME = "name"; + private final String name; + + public static final String JSON_PROPERTY_DESCRIPTION = "description"; + private final String description; + + public static final String JSON_PROPERTY_PARAMETERS = "parameters"; + private final Parameters parameters; + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String name; + private String description = ""; + + private Parameters.Builder parameters = Parameters.builder(); + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder addParameter(String name, String type, String description, boolean required) { + this.parameters.addProperty(name, type, description, required); + return this; + } + + public Builder addParameter(String name, Map properties, boolean required) { + this.parameters.addProperty(name, properties, required); + return this; + } + + public Function build() { + Preconditions.checkNotNull(name, "name is required"); + + return new Function(name, description, parameters.build()); + } + } + + private Function(String name, String description, Parameters parameters) { + this.name = name; + this.description = description; + this.parameters = parameters; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + public Parameters getParameters() { + return parameters; + } +} diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Parameters.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Parameters.java new file mode 100644 index 0000000..e5ec705 --- /dev/null +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Parameters.java @@ -0,0 
+1,74 @@ +package com.github.tjake.jlama.safetensors.prompt; + +import com.fasterxml.jackson.annotation.JsonPropertyOrder; + +import java.util.*; + +/** + * Parameters + */ +@JsonPropertyOrder({ + Parameters.JSON_PROPERTY_TYPE, + Parameters.JSON_PROPERTY_PROPERTIES, + Parameters.JSON_PROPERTY_REQUIRED}) +public class Parameters { + + public static final String JSON_PROPERTY_TYPE = "type"; + private final String type = "object"; + + public static final String JSON_PROPERTY_PROPERTIES = "properties"; + private final Map> properties; + + public static final String JSON_PROPERTY_REQUIRED = "required"; + private List required; + + public Parameters(Map> properties, List required) { + this.properties = properties; + this.required = required; + } + + public Parameters(Map> properties) { + this.properties = properties; + this.required = null; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Map> properties = new LinkedHashMap<>(); + private Set required = new LinkedHashSet<>(); + + public Builder addProperty(String name, String type, String description, boolean required) { + Map properties = new LinkedHashMap<>(); + properties.put("type", type); + properties.put("description", description); + return addProperty(name, properties, required); + } + + public Builder addProperty(String name, Map properties, boolean required) { + this.properties.put(name, properties); + if (required) { + this.required.add(name); + } + return this; + } + + public Parameters build() { + return new Parameters(properties, new ArrayList<>(required)); + } + } + + public String getType() { + return this.type; + } + + public Map> getProperties() { + return properties; + } + + public List getRequired() { + return required; + } +} \ No newline at end of file diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java 
b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java similarity index 51% rename from jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java rename to jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java index 8bece47..4e66317 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/PromptSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java @@ -13,14 +13,20 @@ * License for the specific language governing permissions and limitations * under the License. */ -package com.github.tjake.jlama.safetensors.tokenizer; +package com.github.tjake.jlama.safetensors.prompt; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel; import com.hubspot.jinjava.Jinjava; import com.hubspot.jinjava.JinjavaConfig; +import com.hubspot.jinjava.LegacyOverrides; +import com.hubspot.jinjava.interpret.RenderResult; import com.hubspot.jinjava.lib.fn.ELFunctionDefinition; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; + +import java.util.*; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,8 +39,13 @@ public class PromptSupport { // This matches the jinja config in huggingface private static final Jinjava jinjava = new Jinjava(JinjavaConfig.newBuilder() - .withLstripBlocks(true) .withTrimBlocks(true) + .withLstripBlocks(true) + .withLegacyOverrides(LegacyOverrides.newBuilder() + .withParseWhitespaceControlStrictly(true) + .withUseTrimmingForNotesAndExpressions(true) + .withUseSnakeCasePropertyNaming(true) + .build()) .build()); static { @@ -49,7 +60,7 @@ public PromptSupport(TokenizerModel model) { this.m = model; } - public Builder newBuilder() { + public Builder builder() { return new 
Builder(this.m); } @@ -70,25 +81,90 @@ private enum PromptType { private enum PromptRole { USER, SYSTEM, - ASSISTANT + ASSISTANT, + TOOL, + TOOL_CALL } static class Message { - private final String content; + private final Object content; private final PromptRole role; + private final ToolCallFunction toolCalls; - public Message(String content, PromptRole role) { + private Message(Object content, PromptRole role) { this.content = content; this.role = role; + this.toolCalls = null; + } + + private Message(ToolCall toolCall) { + this.content = null; + this.role = PromptRole.TOOL_CALL; + this.toolCalls = new ToolCallFunction(toolCall); } - public String getContent() { + public Object getContent() { return content; } + public Map toMap() { + Map map = new HashMap(); + map.put("role", role.name().toLowerCase()); + + if (content != null) { + map.put("content", content); + } + + if (toolCalls != null) { + map.put("tool_calls", List.of(toolCalls.toMap())); + } + + return map; + } + public String getRole() { return role.name().toLowerCase(); } + + public List toolCalls() { + if (toolCalls == null) { + return null; + } + + return List.of(toolCalls); + } + } + + static class ToolCallFunction { + private final ToolCall call; + + private ToolCallFunction(ToolCall call) { + this.call = call; + } + + public InnerToolCall function() { + return new InnerToolCall(call); + } + + public Map toMap() { + return Map.of("function", Map.of("name", call.getName(), "arguments", call.getParameters())); + } + } + + static class InnerToolCall { + private final ToolCall call; + + private InnerToolCall(ToolCall call) { + this.call = call; + } + + public Map arguments() { + return call.getParameters(); + } + + public String name() { + return call.getName(); + } } public static class Builder { @@ -97,6 +173,7 @@ public static class Builder { private boolean addGenerationPrompt = true; private List messages = new ArrayList<>(2); + private List tools = null; private Builder(TokenizerModel m) { this.m 
= m; @@ -117,6 +194,16 @@ public Builder addUserMessage(String content) { return this; } + public Builder addToolResult(Result result) { + messages.add(new Message(result.toJson(), PromptRole.TOOL)); + return this; + } + + public Builder addToolCall(ToolCall call) { + messages.add(new Message(call)); + return this; + } + public Builder addSystemMessage(String content) { messages.add(new Message(content, PromptRole.SYSTEM)); return this; @@ -127,6 +214,32 @@ public Builder addAssistantMessage(String content) { return this; } + public Builder addTools(List tools) { + if (this.tools == null) { + this.tools = new ArrayList<>(tools); + } else { + throw new IllegalArgumentException("Tools already set"); + } + return this; + } + + public Builder addTools(Tool... tools) { + if (this.tools == null) { + this.tools = Arrays.asList(tools); + } else { + throw new IllegalArgumentException("Tools already set"); + } + return this; + } + + public boolean hasTools() { + return tools != null && !tools.isEmpty(); + } + + public List getTools() { + return tools; + } + public String build() { if (messages.isEmpty()) { return ""; @@ -141,17 +254,28 @@ public String build() { .orElseThrow( () -> new UnsupportedOperationException("Prompt template not available for type: " + type)); - return jinjava.render( - template, - Map.of( + Map args = new HashMap(); + + args.putAll(Map.of( "messages", - messages, + messages.stream().map(Message::toMap).toList(), "add_generation_prompt", addGenerationPrompt, "eos_token", m.eosToken(), "bos_token", - "")); // We add the BOS ourselves + "")); // We add the BOS ourselves + + if (tools != null) { + args.put("tools", tools); + } + + RenderResult r = jinjava.renderForResult(template, args); + + if (r.hasErrors()) + logger.warn("Prompt template errors: " + r.getErrors()); + + return r.getOutput(); } } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Result.java 
b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Result.java new file mode 100644 index 0000000..f7ced16 --- /dev/null +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Result.java @@ -0,0 +1,29 @@ +package com.github.tjake.jlama.safetensors.prompt; + +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.github.tjake.jlama.util.JsonSupport; + +/** + * Result + */ +@JsonPropertyOrder({Result.JSON_PROPERTY_OUTPUT}) +public class Result { + public static final String JSON_PROPERTY_OUTPUT = "output"; + private final Object output; + + private Result(Object output) { + this.output = output; + } + + public static Result from(Object output) { + return new Result(output); + } + + public Object getOutput() { + return output; + } + + public String toJson() { + return JsonSupport.toJson(this); + } +} diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Tool.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Tool.java new file mode 100644 index 0000000..6535063 --- /dev/null +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/Tool.java @@ -0,0 +1,34 @@ +package com.github.tjake.jlama.safetensors.prompt; + +import com.fasterxml.jackson.annotation.JsonPropertyOrder; + + +/** + * Tool + */ +@JsonPropertyOrder({ + Tool.JSON_PROPERTY_TYPE, + Tool.JSON_PROPERTY_FUNCTION}) +public class Tool { + public static final String JSON_PROPERTY_TYPE = "type"; + private final String type = "function"; + + public static final String JSON_PROPERTY_FUNCTION = "function"; + private final Function function; + + private Tool(Function function) { + this.function = function; + } + + public static Tool from(Function function) { + return new Tool(function); + } + + public Function getFunction() { + return function; + } + + public String getType() { + return type; + } +} diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java 
b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java new file mode 100644 index 0000000..2cc0f77 --- /dev/null +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java @@ -0,0 +1,25 @@ +package com.github.tjake.jlama.safetensors.prompt; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Map; + +public class ToolCall { + private final String name; + private final Map parameters; + + @JsonCreator + public ToolCall(@JsonProperty("name") String name, @JsonProperty("parameters") Map parameters) { + this.name = name; + this.parameters = parameters; + } + + public String getName() { + return name; + } + + public Map getParameters() { + return parameters; + } +} diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/BPETokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/BPETokenizer.java index 9c1a5d5..ce416f8 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/BPETokenizer.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/BPETokenizer.java @@ -16,6 +16,7 @@ package com.github.tjake.jlama.safetensors.tokenizer; import com.github.tjake.jlama.safetensors.SafeTensorSupport; +import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import com.google.common.base.Preconditions; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; @@ -68,6 +69,11 @@ protected BPETokenizer(Path modelRoot) { } } + @Override + public TokenizerModel getModel() { + return model; + } + @Override public List tokenize(String sentence) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/Tokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/Tokenizer.java index 7a60400..5fe9e14 100644 --- 
a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/Tokenizer.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/Tokenizer.java @@ -15,6 +15,8 @@ */ package com.github.tjake.jlama.safetensors.tokenizer; +import com.github.tjake.jlama.safetensors.prompt.PromptSupport; + import java.util.List; import java.util.Optional; @@ -56,4 +58,11 @@ public interface Tokenizer { * @return prompt support */ Optional promptSupport(); + + + /** + * Get the model for this tokenizer (expert mode) + * @return tokenizer model + */ + TokenizerModel getModel(); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java index d47a012..1d1d998 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java @@ -19,11 +19,11 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import com.google.common.base.Preconditions; import com.google.common.collect.*; import java.util.*; import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -141,7 +141,7 @@ public void setLegacy(boolean legacy) { this.legacy = legacy; } - Optional> promptTemplates() { + public Optional> promptTemplates() { return promptTemplates; } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/WordPieceTokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/WordPieceTokenizer.java index 8c877cf..52a8107 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/WordPieceTokenizer.java +++ 
b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/WordPieceTokenizer.java @@ -16,6 +16,7 @@ package com.github.tjake.jlama.safetensors.tokenizer; import com.github.tjake.jlama.safetensors.SafeTensorSupport; +import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import com.google.common.base.Preconditions; import java.io.IOException; import java.nio.file.Path; @@ -62,6 +63,11 @@ public WordPieceTokenizer(Path modelRoot) { this.unkToken = model.vocabLookup.get(unkString); } + @Override + public TokenizerModel getModel() { + return model; + } + @Override public List tokenize(String sentence) { diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java index 1388c35..bfdf14d 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java @@ -29,4 +29,12 @@ public class JsonSupport { .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) .configure(DeserializationFeature.FAIL_ON_MISSING_CREATOR_PROPERTIES, false) .enable(MapperFeature.ACCEPT_CASE_INSENSITIVE_ENUMS); + + public static String toJson(Object o) { + try { + return om.writeValueAsString(o); + } catch (Exception e) { + throw new RuntimeException(e); + } + } } diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java index b808f9b..402c311 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java @@ -114,7 +114,7 @@ public Generator.Response generate( int promptLength = encoded.length; if (useEOS) { - promptTokens[promptTokens.length - 1] = model.getConfig().eosToken; // Add EOS + promptTokens[promptTokens.length - 1] = model.getConfig().eosTokens.getLast(); // Add EOS promptLength++; } @@ 
-138,7 +138,7 @@ public Generator.Response generate( tokensGenerated++; // Model may tell us it's done - if (next == model.getConfig().eosToken) { + if (model.getConfig().eosTokens.contains(next)) { finishReason = FinishReason.STOP_TOKEN; break; } diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java index ab3e51f..b00afab 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java @@ -4,15 +4,12 @@ import com.github.tjake.jlama.model.functions.Generator; import com.github.tjake.jlama.net.openai.model.*; -import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport; +import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import jakarta.validation.Valid; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpStatus; -import org.springframework.http.HttpStatusCode; -import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.validation.annotation.Validated; -import org.springframework.web.HttpMediaTypeNotAcceptableException; import org.springframework.web.bind.annotation.*; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -67,7 +64,7 @@ Object createChatCompletion( UUID sessionId = id; - PromptSupport.Builder builder = model.promptSupport().get().newBuilder(); + PromptSupport.Builder builder = model.promptSupport().get().builder(); for (ChatCompletionRequestMessage m : messages) { diff --git a/jlama-net/src/test/java/com/github/tjake/jlama/net/JlamaServiceTest.java b/jlama-net/src/test/java/com/github/tjake/jlama/net/JlamaServiceTest.java index 3fbd698..1308062 100644 --- a/jlama-net/src/test/java/com/github/tjake/jlama/net/JlamaServiceTest.java +++ 
b/jlama-net/src/test/java/com/github/tjake/jlama/net/JlamaServiceTest.java @@ -27,8 +27,9 @@ import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.safetensors.TensorInfo; import com.github.tjake.jlama.safetensors.WeightLoader; -import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport; +import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; +import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel; import com.github.tjake.jlama.tensor.AbstractTensor; import com.github.tjake.jlama.util.Pair; import com.google.protobuf.ByteString; @@ -141,7 +142,7 @@ public MockConfig( layerNormEps, 32000, 1, - 2, + List.of(2), ActivationFunction.Type.SILU, 10000.0, 1.0); @@ -199,6 +200,11 @@ public String decode(long[] ids) { public Optional promptSupport() { return Optional.empty(); } + + @Override + public TokenizerModel getModel() { + return null; + } } public static class MockModel extends AbstractModel { diff --git a/jlama-tests/pom.xml b/jlama-tests/pom.xml index ccd6861..2020a77 100644 --- a/jlama-tests/pom.xml +++ b/jlama-tests/pom.xml @@ -40,6 +40,7 @@ 1.35 test + diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/Mocks.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/Mocks.java index cbf6161..b0e9a70 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/Mocks.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/Mocks.java @@ -22,8 +22,9 @@ import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.safetensors.TensorInfo; import com.github.tjake.jlama.safetensors.WeightLoader; -import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport; +import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; +import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel; import 
com.github.tjake.jlama.tensor.AbstractTensor; import com.github.tjake.jlama.util.Pair; import java.util.Collections; @@ -57,7 +58,7 @@ public MockConfig( layerNormEps, 32000, 1, - 2, + List.of(2), ActivationFunction.Type.SILU, 10000.0, 1.0); @@ -115,6 +116,11 @@ public String decode(long[] ids) { public Optional promptSupport() { return Optional.empty(); } + + @Override + public TokenizerModel getModel() { + return null; + } } public static class MockModel extends AbstractModel { diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java index d412862..c46f745 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java @@ -23,7 +23,7 @@ import com.github.tjake.jlama.model.gemma.GemmaTokenizer; import com.github.tjake.jlama.model.gpt2.GPT2Tokenizer; import com.github.tjake.jlama.model.llama.LlamaTokenizer; -import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport; +import com.github.tjake.jlama.safetensors.prompt.*; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import com.github.tjake.jlama.safetensors.tokenizer.WordPieceTokenizer; import com.google.common.io.Resources; @@ -32,7 +32,9 @@ import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.concurrent.ThreadLocalRandom; + import org.junit.Assert; import org.junit.Assume; import org.junit.Test; @@ -310,7 +312,7 @@ public void testPromptSupport() { Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); - PromptSupport.Builder builder = tokenizer.promptSupport().get().newBuilder(); + PromptSupport.Builder builder = tokenizer.promptSupport().get().builder(); builder.addSystemMessage("You are a friendly chatbot who always responds in the style of a 
pirate"); builder.addUserMessage("How many helicopters can a human eat in one sitting?"); @@ -337,13 +339,13 @@ public void testPromptSupport2() { Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); - PromptSupport.Builder builder = tokenizer.promptSupport().get().newBuilder(); + PromptSupport.Builder builder = tokenizer.promptSupport().get().builder(); builder.addSystemMessage("You are a friendly chatbot who always responds in the style of a pirate."); builder.addUserMessage("How many helicopters can a human eat in one sitting?"); String prompt = builder.build(); Assert.assertEquals( - "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + "<|start_header_id|>system<|end_header_id|>\n\n" + "You are a friendly chatbot who always responds in the style of a pirate.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n" + "How many helicopters can a human eat in one sitting?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", prompt); @@ -359,4 +361,82 @@ public void testPromptSupport2() { Assert.assertEquals(prompt, out); Assert.assertArrayEquals(expected, encoded); } + + @Test + public void testPromptSupportWithTools() { + String modelPrefix = "../models/Meta-Llama-3.1-8B-Instruct"; + Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); + + Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); + PromptSupport.Builder builder = tokenizer.promptSupport().get().builder(); + builder.addSystemMessage("You always respond as a pirate"); + builder.addUserMessage("What is the weather in paris right now?"); + builder.addGenerationPrompt(true); + + Tool t = Tool.from(Function.builder() + .name("get_temperature") + .description("Simulates getting the current temperature at a location.") + .addParameter("location", "string", "The location to get the temperature for, in the format \"City, Country\".", true) + .addParameter("unit", "string", "The unit to return the 
temperature in (e.g., \"celsius\", \"fahrenheit\").", true) + .build()); + + builder.addTools(t); + + builder.addToolCall(new ToolCall("get_temperature", Map.of("location", "paris, france", "unit", "celsius"))); + + builder.addToolResult(Result.from(Map.of("temperature", 25.0, "unit", "celsius"))); + + String prompt = builder.build(); + Assert.assertEquals( + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n" + "\n" + + "Environment: ipython\n" + + "Cutting Knowledge Date: December 2023\n" + + "Today Date: 26 Jul 2024\n" + + "\n" + + "You always respond as a pirate<|eot_id|><|start_header_id|>user<|end_header_id|>\n" + + "\n" + + "Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n" + + "\n" + + "Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n" + + "\n" + + "{\n" + + " \"type\": \"function\",\n" + + " \"function\": {\n" + + " \"name\": \"get_current_temperature\",\n" + + " \"description\": \"Simulates getting the current temperature at a location.\",\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"location\": {\n" + + " \"type\": \"string\",\n" + + " \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\".\"\n" + + " },\n" + + " \"unit\": {\n" + + " \"type\": \"string\",\n" + + " \"description\": \"The unit to return the temperature in (e.g., \\\"celsius\\\", \\\"fahrenheit\\\").\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"location\",\n" + + " \"unit\"\n" + + " ]\n" + + " }\n" + + " }\n" + + "}\n" + + "\n" + + "What is the weather in paris right now?<|eot_id|>", + prompt); + + long[] encoded = tokenizer.encode(prompt); + long[] expected = new long[] { + 128000, 128006, 9125, 128007, 271, 2675, 527, 264, 11919, 6369, 6465, 889, 2744, 31680, 304, 279, 1742, 315, + 264, 55066, 
128009, 128006, 882, 128007, 271, 4438, 1690, 59432, 649, 264, 3823, 8343, 304, 832, 11961, 30, + 128009, 128006, 78191, 128007, 271 + }; // NOTE(review): these expected tokens decode testPromptSupport2's pirate-chatbot/helicopters prompt, not the tools prompt built above — regenerate this array (and the expected string) before asserting + + String out = tokenizer.decode(encoded); + Assert.assertEquals(prompt, out); + Assert.assertArrayEquals(expected, encoded); + } + } diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java index 11cf267..c7fdd3d 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java @@ -21,6 +21,7 @@ import com.github.tjake.jlama.model.bert.BertConfig; import com.github.tjake.jlama.model.bert.BertModel; import com.github.tjake.jlama.model.bert.BertTokenizer; +import com.github.tjake.jlama.model.functions.Generator; import com.github.tjake.jlama.model.gemma.GemmaConfig; import com.github.tjake.jlama.model.gemma.GemmaModel; import com.github.tjake.jlama.model.gemma.GemmaTokenizer; @@ -35,11 +36,12 @@ import com.github.tjake.jlama.model.mixtral.MixtralConfig; import com.github.tjake.jlama.model.mixtral.MixtralModel; import com.github.tjake.jlama.safetensors.*; +import com.github.tjake.jlama.safetensors.prompt.*; import com.github.tjake.jlama.safetensors.tokenizer.BPETokenizer; -import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import com.github.tjake.jlama.tensor.AbstractTensor; import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; +import com.github.tjake.jlama.util.JsonSupport; import com.github.tjake.jlama.util.Pair; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.AtomicDouble; @@ -98,16 +100,55 @@ public void GPT2Run() throws IOException { @Test public void LlamaRun() throws Exception { - String modelPrefix = "../models/Meta-Llama-3.1-8B-Instruct"; + String modelPrefix = 
"../models/Meta-Llama-3.1-8B-Instruct-jlama-Q4"; Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) { LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); Config c = om.readValue(new File(modelPrefix + "/config.json"), LlamaConfig.class); - LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.BF16, Optional.empty()); + LlamaModel model = new LlamaModel(c, weights, tokenizer, DType.F32, DType.F32, Optional.empty()); - String p = model.promptSupport().get().newBuilder().addUserMessage("Tell me a joke.").build(); - model.generate(UUID.randomUUID(), p, 0.3f, 256, false, makeOutHandler()); + PromptSupport.Builder builder = model.promptSupport().get().builder(); + + builder.addSystemMessage("You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original question."); + builder.addUserMessage("What is the temp in paris right now?"); + builder.addGenerationPrompt(true); + + Tool t = Tool.from(Function.builder() + .name("get_current_temperature") + .description("Simulates getting the current temperature at a location.") + .addParameter("location", "string", "The location to get the temperature for, in the format \"City, Country\".", true) + .addParameter("unit", "string", "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", true) + .build()); + + builder.addTools(t); + + logger.info("First prompt \n{}", builder.build()); + Generator.Response r = + model.generate(UUID.randomUUID(), builder.build(), 0.3f, 1024, false, (l,f) ->{}); + + if (r.finishReason == Generator.FinishReason.STOP_TOKEN) { + + if (builder.hasTools()) { + if (r.text.trim().startsWith("{") || r.text.startsWith("<|python_tag|>")) { + // Call the tool + // This is where you would call the tool function + // and then call generate again with the result + // of the tool 
function + + ToolCall f = JsonSupport.om.readValue(r.text.replace("<|python_tag|>", ""), ToolCall.class); + logger.info("Calling tool: {}", f.getName()); + + builder.addToolCall(f); + + builder.addToolResult(Result.from(20f)); + logger.info("Second prompt {}", builder.build()); + Generator.Response r2 = model.generate(UUID.randomUUID(), builder.build(), 0.0f, 1024, false, (l,p) ->{}); + + System.err.println(r2.text); + } + } + } } } @@ -138,7 +179,7 @@ public void MistralRun() throws Exception { String p = model.promptSupport().isEmpty() ? prompt : model.promptSupport() .get() - .newBuilder() + .builder() .addUserMessage(prompt) .build(); model.generate(UUID.randomUUID(), "[INST] Tell me a joke. [/INST]Assistant", 0.0f, 64, false, makeOutHandler()); @@ -162,7 +203,7 @@ public void MixtralRun() throws Exception { String prompt = "Tell me a joke."; String p = model.promptSupport() .get() - .newBuilder() + .builder() .addUserMessage(prompt) .build(); @@ -182,7 +223,7 @@ public void GemmaRun() throws Exception { String prompt = "Tell me a joke."; String p = model.promptSupport() .get() - .newBuilder() + .builder() .addUserMessage(prompt) .build(); model.generate(UUID.randomUUID(), p, 0.3f, 256, false, makeOutHandler()); diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java index 7972e42..8fb4485 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java @@ -42,7 +42,7 @@ public void sample() throws IOException { if (m.promptSupport().isPresent()) { prompt = m.promptSupport() .get() - .newBuilder() + .builder() .addSystemMessage("You are a helpful chatbot who writes short responses.") .addUserMessage(prompt) .build();