Skip to content

Commit

Permalink
Merge pull request #44 from tjake/tool-calling
Browse files Browse the repository at this point in the history
Add tool calling support
  • Loading branch information
tjake authored Aug 18, 2024
2 parents 229146e + 0144505 commit d234716
Show file tree
Hide file tree
Showing 29 changed files with 592 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import com.diogonunes.jcolor.AnsiFormat;
import com.diogonunes.jcolor.Attribute;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import java.io.PrintWriter;
import java.util.Optional;
import java.util.Scanner;
Expand Down Expand Up @@ -84,7 +84,7 @@ public void run() {
break;
}

PromptSupport.Builder builder = promptSupport.newBuilder();
PromptSupport.Builder builder = promptSupport.builder();
if (first && systemPrompt != null) {
builder.addSystemMessage(systemPrompt);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
Expand Down Expand Up @@ -330,7 +330,7 @@ public Response generate(
promptLength = encoded.length;

if (useEOS) {
promptTokens[promptTokens.length - 1] = c.eosToken; // Add EOS
promptTokens[promptTokens.length - 1] = getConfig().eosTokens.getLast(); // Add EOS
promptLength++;
}

Expand Down Expand Up @@ -371,14 +371,14 @@ public Response generate(
if (logger.isTraceEnabled()) logger.trace("Sampled token {} with temperature {}", next, temperature);
output.close();

kvmem.setMetadata(KvBufferCache.TOKEN_COUNT, i);

// Model may tell us it's done
if (next == c.eosToken) {
if (c.eosTokens.contains(next)) {
reason = FinishReason.STOP_TOKEN;
break;
}

kvmem.setMetadata(KvBufferCache.TOKEN_COUNT, i);

try {
String c = tokenizer.decode(next);
genMsPerToken = (System.currentTimeMillis() - start) / (float) (tokensGenerated);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.safetensors.Config;

import java.util.List;

public class BertConfig extends Config {
@JsonCreator
public BertConfig(
Expand All @@ -41,7 +43,7 @@ public BertConfig(
layerNormEps,
vocabularySize,
0,
0,
List.of(0),
activationFunction,
null,
null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public interface Generator {
enum FinishReason {
MAX_TOKENS,
STOP_TOKEN,
TOOL_CALL,
ERROR
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.safetensors.Config;

import java.util.List;
import java.util.Map;

public class GemmaConfig extends Config {
Expand Down Expand Up @@ -47,7 +49,7 @@ public GemmaConfig(
layerNormEps,
vocabularySize,
bosToken,
eosToken,
List.of(eosToken),
activationFunction,
ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta,
ropeScaling == null ? 1.0 : Double.parseDouble(ropeScaling.get("factor")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.safetensors.Config;

import java.util.List;

public class GPT2Config extends Config {

@JsonCreator
Expand All @@ -42,7 +44,7 @@ public GPT2Config(
layerNormEps,
vocabularySize,
bosToken,
eosToken,
List.of(eosToken),
ActivationFunction.Type.GELU,
null,
null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public LlamaConfig(
layerNormEps,
vocabularySize,
bosToken,
eosToken instanceof List ? ((List<Integer>)eosToken).get(((List<Integer>)eosToken).size() - 1) : (Integer) eosToken, //for llama3.1
eosToken instanceof List<?> ? (List<Integer>) eosToken : List.of((Integer) eosToken),
activationFunction,
ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta,
ropeScaling == null || !("linear".equals(ropeScaling.get("rope_type"))) ? 1.0 : Double.parseDouble(ropeScaling.get("factor")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.safetensors.Config;

import java.util.List;

public class MistralConfig extends Config {
@JsonCreator
public MistralConfig(
Expand All @@ -46,7 +48,7 @@ public MistralConfig(
layerNormEps,
vocabularySize,
bosToken,
eosToken,
List.of(eosToken),
activationFunction,
ropeTheta,
1.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.safetensors.Config;

import java.util.List;

public class MixtralConfig extends Config {

public final int numberOfExperts;
Expand Down Expand Up @@ -52,7 +54,7 @@ public MixtralConfig(
layerNormEps,
vocabularySize,
bosToken,
eosToken,
List.of(eosToken),
activationFunction,
ropeTheta,
1.0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.common.base.Preconditions;
import com.google.common.io.Files;
import java.io.File;
import java.util.List;
import java.util.Optional;

public class Config {
Expand All @@ -39,7 +40,7 @@ public class Config {
public final float layerNormEps;
public final int vocabularySize;
public final int bosToken;
public final int eosToken;
public final List<Integer> eosTokens;
public final Optional<float[][]> ropeFreqs;
private volatile Optional<Pair<Integer, Integer>> offset;
private volatile File workingDirectory;
Expand Down Expand Up @@ -69,7 +70,7 @@ public Config(
float layerNormEps,
int vocabularySize,
int bosToken,
int eosToken,
List<Integer> eosToken,
ActivationFunction.Type activationFunction,
Double ropeFreqsTheta,
Double ropeScalingFactor) {
Expand Down Expand Up @@ -100,7 +101,7 @@ public Config(
float layerNormEps,
int vocabularySize,
int bosToken,
int eosToken,
List<Integer> eosTokens,
ActivationFunction.Type activationFunction,
Double ropeFreqsTheta,
Double ropeScalingFactor,
Expand All @@ -114,7 +115,7 @@ public Config(
this.layerNormEps = layerNormEps;
this.vocabularySize = vocabularySize;
this.bosToken = bosToken;
this.eosToken = eosToken;
this.eosTokens = eosTokens;
this.tensorCache = TensorCache.instance;
this.headSize = headSize;
this.headGroupSize = numberOfHeads / numberOfKeyValueHeads;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.github.tjake.jlama.safetensors.prompt;

import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import com.google.common.base.Preconditions;

import java.util.Map;

/**
* FunctionObject
*/
@JsonPropertyOrder({Function.JSON_PROPERTY_NAME, Function.JSON_PROPERTY_DESCRIPTION, Function.JSON_PROPERTY_PARAMETERS})
public class Function {
public static final String JSON_PROPERTY_NAME = "name";
private final String name;

public static final String JSON_PROPERTY_DESCRIPTION = "description";
private final String description;

public static final String JSON_PROPERTY_PARAMETERS = "parameters";
private final Parameters parameters;

public static Builder builder() {
return new Builder();
}

public static class Builder {
private String name;
private String description = "";

private Parameters.Builder parameters = Parameters.builder();

public Builder description(String description) {
this.description = description;
return this;
}

public Builder name(String name) {
this.name = name;
return this;
}

public Builder addParameter(String name, String type, String description, boolean required) {
this.parameters.addProperty(name, type, description, required);
return this;
}

public Builder addParameter(String name, Map<String, Object> properties, boolean required) {
this.parameters.addProperty(name, properties, required);
return this;
}

public Function build() {
Preconditions.checkNotNull(name, "name is required");

return new Function(name, description, parameters.build());
}
}

private Function(String name, String description, Parameters parameters) {
this.name = name;
this.description = description;
this.parameters = parameters;
}

public String getName() {
return name;
}

public String getDescription() {
return description;
}

public Parameters getParameters() {
return parameters;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package com.github.tjake.jlama.safetensors.prompt;

import com.fasterxml.jackson.annotation.JsonPropertyOrder;

import java.util.*;

/**
* Parameters
*/
@JsonPropertyOrder({
Parameters.JSON_PROPERTY_TYPE,
Parameters.JSON_PROPERTY_PROPERTIES,
Parameters.JSON_PROPERTY_REQUIRED})
public class Parameters {

public static final String JSON_PROPERTY_TYPE = "type";
private final String type = "object";

public static final String JSON_PROPERTY_PROPERTIES = "properties";
private final Map<String, Map<String, Object>> properties;

public static final String JSON_PROPERTY_REQUIRED = "required";
private List<String> required;

public Parameters(Map<String, Map<String, Object>> properties, List<String> required) {
this.properties = properties;
this.required = required;
}

public Parameters(Map<String, Map<String, Object>> properties) {
this.properties = properties;
this.required = null;
}

public static Builder builder() {
return new Builder();
}

public static class Builder {
private Map<String, Map<String, Object>> properties = new LinkedHashMap<>();
private Set<String> required = new LinkedHashSet<>();

public Builder addProperty(String name, String type, String description, boolean required) {
Map<String, Object> properties = new LinkedHashMap<>();
properties.put("type", type);
properties.put("description", description);
return addProperty(name, properties, required);
}

public Builder addProperty(String name, Map<String, Object> properties, boolean required) {
this.properties.put(name, properties);
if (required) {
this.required.add(name);
}
return this;
}

public Parameters build() {
return new Parameters(properties, new ArrayList<>(required));
}
}

public String getType() {
return this.type;
}

public Map<String, Map<String, Object>> getProperties() {
return properties;
}

public List<String> getRequired() {
return required;
}
}
Loading

0 comments on commit d234716

Please sign in to comment.