feat: improve grammar support (#215)
* feat: improve grammar support
* feat: improve JSON schema grammar
giladgd authored May 12, 2024
1 parent c6a80c0 commit d321fe3
Showing 32 changed files with 591 additions and 148 deletions.
2 changes: 1 addition & 1 deletion .config/typedoc.json
@@ -23,6 +23,6 @@
     "propertiesFormat": "list",
     "enumMembersFormat": "table",
     "typeDeclarationFormat": "list",
-    "hideInPageTOC": true,
+    "sort": ["source-order"],
     "docsRoot": "../docs"
 }
4 changes: 3 additions & 1 deletion .releaserc.ts
@@ -40,7 +40,9 @@ export default {
         }
     }],
     "@semantic-release/npm",
-    "@semantic-release/github",
+    ["@semantic-release/github", {
+        "discussionCategoryName": "Releases"
+    }],
     ["@semantic-release/exec", {
         "publishCmd": "echo \"${nextRelease.version}\" > .semanticRelease.npmPackage.deployedVersion.txt"
     }]
7 changes: 5 additions & 2 deletions .vitepress/config.ts
@@ -137,7 +137,10 @@ export default defineConfig({
     ...(await fs.readJSON(path.join(__dirname, "..", "tsconfig.json"))).compilerOptions,
     moduleResolution: undefined,
     paths: {
-        "node-llama-cpp": [path.resolve(__dirname, "..", "src", "index.ts")]
+        "node-llama-cpp": [
+            path.resolve(__dirname, "..", "dist", "index.d.ts"),
+            path.resolve(__dirname, "..", "src", "index.ts")
+        ]
     },
     typeRoots: [
         path.resolve(__dirname, "..", "node_modules"),
@@ -157,7 +160,7 @@
     },
     nav: [
         {text: "Guide", link: "/guide/", activeMatch: "/guide/"},
-        {text: "API Reference", link: "/api/classes/LlamaModel", activeMatch: "/api/"},
+        {text: "API Reference", link: "/api/functions/getLlama", activeMatch: "/api/"},
         {
             text: packageVersion,
             items: [{
30 changes: 15 additions & 15 deletions package-lock.json

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions package.json
@@ -62,7 +62,7 @@
     "dev:build": "npm run build && node ./dist/cli/cli.js build --noUsageExample",
     "clean": "rm -rf ./node_modules ./dist ./tsconfig.tsbuildinfo ./test/.models ./docs/api ./docs/api-overrides",
     "docs:generateTypedoc": "typedoc && rimraf ./docs/api/index.md ./docs/api/globals.md ./docs/api/functions/LlamaText.md && npm run docs:generateTypedoc:overrides",
-    "docs:generateTypedoc:overrides": "typedoc --entryPoints ./src/apiDocsOverrides.ts --out ./docs/api-overrides && copyfiles --flat \"./docs/api-overrides/classes/*.md\" ./docs/api/classes && rimraf ./docs/api-overrides",
+    "docs:generateTypedoc:overrides": "typedoc --entryPoints ./src/apiDocsOverrides.ts --out ./docs/api-overrides && copyfiles --flat \"./docs/api-overrides/classes/LlamaText.md\" ./docs/api/classes && rimraf ./docs/api-overrides",
     "docs:dev": "npm run docs:generateTypedoc && vitepress dev",
     "docs:build": "npm run docs:generateTypedoc && vitepress build",
     "docs:preview": "npm run docs:generateTypedoc && vitepress preview"
@@ -139,10 +139,10 @@
     "rimraf": "^5.0.1",
     "semantic-release": "^22.0.8",
     "tslib": "^2.6.1",
-    "typedoc": "^0.25.3",
+    "typedoc": "^0.25.13",
     "typedoc-plugin-markdown": "^4.0.0-next.55",
-    "typedoc-plugin-mdn-links": "^3.1.19",
-    "typedoc-vitepress-theme": "^1.0.0-next.10",
+    "typedoc-plugin-mdn-links": "^3.1.24",
+    "typedoc-vitepress-theme": "1.0.0-next.10",
     "typescript": "^5.2.2",
     "vite-node": "^1.4.0",
     "vitepress": "^1.1.4",
23 changes: 21 additions & 2 deletions src/apiDocsOverrides.ts
@@ -1,5 +1,24 @@
 /** @internal */
-import {_LlamaText} from "./utils/LlamaText.js";
+import {Tokenizer} from "./types.js";
+/** @internal */
+import {
+    _LlamaText,
+    type LlamaTextJSON,
+    type LlamaTextJSONValue,
+    type LlamaTextSpecialTokenJSON,
+    type LlamaTextSpecialTokensTextJSON,
+    type LlamaTextValue,
+    type LlamaTextInputValue
+} from "./utils/LlamaText.js";
 
 /** @internal */
-export {_LlamaText as LlamaText};
+export {
+    _LlamaText as LlamaText,
+    type Tokenizer,
+    type LlamaTextJSON,
+    type LlamaTextJSONValue,
+    type LlamaTextSpecialTokensTextJSON,
+    type LlamaTextSpecialTokenJSON,
+    type LlamaTextValue,
+    type LlamaTextInputValue
+};
15 changes: 15 additions & 0 deletions src/bindings/Llama.ts
@@ -3,6 +3,9 @@ import {DisposedError, EventRelay, withLock} from "lifecycle-utils";
 import {getConsoleLogPrefix} from "../utils/getConsoleLogPrefix.js";
 import {LlamaModel, LlamaModelOptions} from "../evaluator/LlamaModel.js";
 import {DisposeGuard} from "../utils/DisposeGuard.js";
+import {GbnfJsonSchema} from "../utils/gbnfJson/types.js";
+import {LlamaJsonSchemaGrammar} from "../evaluator/LlamaJsonSchemaGrammar.js";
+import {LlamaGrammar, LlamaGrammarOptions} from "../evaluator/LlamaGrammar.js";
 import {BindingModule} from "./AddonTypes.js";
 import {BuildGpu, BuildMetadataFile, LlamaLocks, LlamaLogLevel} from "./types.js";
 import {MemoryOrchestrator, MemoryReservation} from "./utils/MemoryOrchestrator.js";
@@ -229,6 +232,18 @@ export class Llama {
         });
     }
 
+    public async createGrammarForJsonSchema<const T extends Readonly<GbnfJsonSchema>>(schema: T) {
+        return new LlamaJsonSchemaGrammar<T>(this, schema);
+    }
+
+    public async getGrammarFor(type: Parameters<typeof LlamaGrammar.getFor>[1]) {
+        return await LlamaGrammar.getFor(this, type);
+    }
+
+    public async createGrammar(options: LlamaGrammarOptions) {
+        return new LlamaGrammar(this, options);
+    }
+
     /** @internal */
     public async _init() {
         await this._bindings.init();
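Taken together, the three methods added above turn a `Llama` instance into a one-stop factory for grammars. A minimal usage sketch (the example schema, the GBNF rule, and the `"json"` preset name are illustrative assumptions, not part of this diff):

```ts
import {getLlama} from "node-llama-cpp";

const llama = await getLlama();

// Build a grammar from a JSON schema; the grammar's result type
// is inferred from the schema via the `const` type parameter:
const jsonSchemaGrammar = await llama.createGrammarForJsonSchema({
    type: "object",
    properties: {
        answer: {type: "string"}
    }
});

// Load one of the bundled llama.cpp grammar presets
// (assuming "json" is among them):
const presetGrammar = await llama.getGrammarFor("json");

// Create a grammar directly from GBNF text:
const gbnfGrammar = await llama.createGrammar({
    grammar: 'root ::= "yes" | "no"'
});
```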
2 changes: 1 addition & 1 deletion src/bindings/getLlama.ts
@@ -169,7 +169,7 @@ const defaultBuildOption: Exclude<LlamaOptions["build"], undefined> = runningInE
     : "auto";
 
 /**
- * Get a llama.cpp binding.
+ * Get a `llama.cpp` binding.
  * Defaults to prefer a prebuilt binary, and fallbacks to building from source if a prebuilt binary is not found.
  * Pass `"lastCliBuild"` to default to use the last successful build created using the `download` or `build` CLI commands if one exists.
  */
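A quick sketch of the two call styles the updated doc comment describes (nothing here beyond what the comment states):

```ts
import {getLlama} from "node-llama-cpp";

// Default behavior: prefer a prebuilt binary, fall back to
// building from source if no prebuilt binary is found:
const llama = await getLlama();

// Reuse the last successful build created with the `download`
// or `build` CLI commands, if one exists:
const llamaFromCliBuild = await getLlama("lastCliBuild");
```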
9 changes: 7 additions & 2 deletions src/evaluator/LlamaChat/LlamaChat.ts
@@ -657,7 +657,7 @@ export class LlamaChat {
     pendingTokens.push(...streamRegulator.popFreeChunkTokens());
 
     const triggeredStops = functionSyntaxStartDetector.getTriggeredStops();
-    const partiallyFreeTokens = streamRegulator.getPartiallyFreeChunk();
+    const partiallyFreeTokens = streamRegulator.getPartiallyFreeChunk(model.tokenizer);
 
     const queuedTokensBeforeStopTrigger = getQueuedTokensBeforeStopTrigger(
         triggeredStops,
@@ -775,10 +775,15 @@
     if (stopGenerationDetector.hasTriggeredStops || customStopGenerationTriggersDetector.hasTriggeredStops ||
         model.isEogToken(token)
     ) {
+        stopGenerationDetector.clearInProgressStops();
+        customStopGenerationTriggersDetector.clearInProgressStops();
         pendingTokens.push(...streamRegulator.popFreeChunkTokens());
 
         const triggeredStops = stopGenerationDetector.hasTriggeredStops
             ? stopGenerationDetector.getTriggeredStops()
             : customStopGenerationTriggersDetector.getTriggeredStops();
-        const partiallyFreeTokens = streamRegulator.getPartiallyFreeChunk();
+
+        const partiallyFreeTokens = streamRegulator.getPartiallyFreeChunk(model.tokenizer);
 
         const queuedTokensBeforeStopTrigger = getQueuedTokensBeforeStopTrigger(
             triggeredStops,
3 changes: 1 addition & 2 deletions src/evaluator/LlamaChat/utils/FunctionCallGrammar.ts
@@ -20,8 +20,7 @@ export class FunctionCallGrammar<const Functions extends ChatModelFunctions> ext
     public constructor(llama: Llama, functions: Functions, chatWrapper: ChatWrapper, initialFunctionCallEngaged: boolean) {
         const grammar = getGbnfGrammarForFunctionCalls(functions, chatWrapper, initialFunctionCallEngaged);
 
-        super({
-            llama,
+        super(llama, {
             grammar,
             stopGenerationTriggers: [LlamaText(chatWrapper.settings.functions.call.suffix, "\n".repeat(4))],
             trimWhitespaceSuffix: true
2 changes: 1 addition & 1 deletion src/evaluator/LlamaCompletion.ts
@@ -642,7 +642,7 @@
 
     if (stopGenerationDetector.hasTriggeredStops || model.isEogToken(token)) {
         const triggeredStops = stopGenerationDetector.getTriggeredStops();
-        const partiallyFreeTokens = streamRegulator.getPartiallyFreeChunk();
+        const partiallyFreeTokens = streamRegulator.getPartiallyFreeChunk(model.tokenizer);
 
         const queuedTokensBeforeStopTrigger = getQueuedTokensBeforeStopTrigger(
             triggeredStops,
10 changes: 4 additions & 6 deletions src/evaluator/LlamaGrammar.ts
@@ -8,8 +8,6 @@ import {Llama} from "../bindings/Llama.js";
 
 
 export type LlamaGrammarOptions = {
-    llama: Llama,
-
     /** GBNF grammar */
     grammar: string,
 
@@ -35,10 +33,11 @@ export class LlamaGrammar {
      * > More info here: [
      * github:ggerganov/llama.cpp:grammars/README.md
      * ](https://github.com/ggerganov/llama.cpp/blob/f5fe98d11bdf9e7797bcfb05c0c3601ffc4b9d26/grammars/README.md)
+     * @param llama
      * @param options
      */
-    public constructor({
-        llama, grammar, stopGenerationTriggers = [], trimWhitespaceSuffix = false, printGrammar = false
+    public constructor(llama: Llama, {
+        grammar, stopGenerationTriggers = [], trimWhitespaceSuffix = false, printGrammar = false
     }: LlamaGrammarOptions) {
         this._llama = llama;
         this._grammar = new this._llama._bindings.AddonGrammar(grammar, {
@@ -69,8 +68,7 @@
 
         if (await fs.pathExists(grammarFile)) {
             const grammar = await fs.readFile(grammarFile, "utf8");
-            return new LlamaGrammar({
-                llama,
+            return new LlamaGrammar(llama, {
                 grammar,
                 stopGenerationTriggers: [LlamaText(["\n".repeat(10)])], // this is a workaround for the model not stopping to generate text,
                 trimWhitespaceSuffix: true
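The `Llama` instance now comes first as its own parameter instead of riding along inside the options object. A before/after sketch (the GBNF rule is illustrative):

```ts
import {getLlama, LlamaGrammar} from "node-llama-cpp";

const llama = await getLlama();

// Before this commit: new LlamaGrammar({llama, grammar, ...})
// After this commit:
const grammar = new LlamaGrammar(llama, {
    grammar: 'root ::= "positive" | "negative"', // illustrative GBNF
    trimWhitespaceSuffix: true
});
```

Separating the instance from the options mirrors the `LlamaJsonSchemaGrammar` and `FunctionCallGrammar` constructors changed in this commit, so all grammar classes now share the same `(llama, ...)` shape.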
5 changes: 2 additions & 3 deletions src/evaluator/LlamaJsonSchemaGrammar.ts
@@ -1,5 +1,5 @@
 import {GbnfJsonSchema, GbnfJsonSchemaToType} from "../utils/gbnfJson/types.js";
-import {getGbnfGrammarForGbnfJsonSchema} from "../utils/getGbnfGrammarForGbnfJsonSchema.js";
+import {getGbnfGrammarForGbnfJsonSchema} from "../utils/gbnfJson/getGbnfGrammarForGbnfJsonSchema.js";
 import {validateObjectAgainstGbnfSchema} from "../utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.js";
 import {LlamaText} from "../utils/LlamaText.js";
 import {Llama} from "../bindings/Llama.js";
@@ -11,8 +11,7 @@ export class LlamaJsonSchemaGrammar<const T extends Readonly<GbnfJsonSchema>> ex
     public constructor(llama: Llama, schema: T) {
         const grammar = getGbnfGrammarForGbnfJsonSchema(schema);
 
-        super({
-            llama,
+        super(llama, {
             grammar,
             stopGenerationTriggers: [LlamaText(["\n".repeat(4)])],
             trimWhitespaceSuffix: true
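The same `(llama, ...)` signature applies here. A usage sketch (the schema is an example; the `const T` type parameter lets the grammar infer a precise result type from it):

```ts
import {getLlama, LlamaJsonSchemaGrammar} from "node-llama-cpp";

const llama = await getLlama();

// Constrain generation to JSON matching this example schema:
const grammar = new LlamaJsonSchemaGrammar(llama, {
    type: "object",
    properties: {
        sentiment: {enum: ["positive", "negative", "neutral"]}
    }
});
```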
11 changes: 8 additions & 3 deletions src/index.ts
@@ -50,8 +50,8 @@ import {
 } from "./chatWrappers/utils/resolveChatWrapper.js";
 import {ChatModelFunctionsDocumentationGenerator} from "./chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.js";
 import {
-    LlamaText, SpecialTokensText, SpecialToken, isLlamaText, tokenizeText, type LlamaTextJSON, type LlamaTextJSONValue,
-    type LlamaTextSpecialTokensTextJSON, type LlamaTextSpecialTokenJSON
+    LlamaText, SpecialTokensText, SpecialToken, isLlamaText, tokenizeText, type LlamaTextValue, type LlamaTextInputValue,
+    type LlamaTextJSON, type LlamaTextJSONValue, type LlamaTextSpecialTokensTextJSON, type LlamaTextSpecialTokenJSON
 } from "./utils/LlamaText.js";
 import {appendUserMessageToChatHistory} from "./utils/appendUserMessageToChatHistory.js";
 import {getModuleVersion} from "./utils/getModuleVersion.js";
@@ -62,7 +62,8 @@ import {createModelDownloader, ModelDownloader, type ModelDownloaderOptions} fro
 import {
     type ChatHistoryItem, type ChatModelFunctionCall, type ChatModelFunctions, type ChatModelResponse,
     type ChatSessionModelFunction, type ChatSessionModelFunctions, type ChatSystemMessage, type ChatUserMessage,
-    type Token, isChatModelResponseFunctionCall, type LLamaContextualRepeatPenalty, type ChatWrapperSettings
+    type Token, type Tokenizer, type Detokenizer, isChatModelResponseFunctionCall, type LLamaContextualRepeatPenalty,
+    type ChatWrapperSettings
 } from "./types.js";
 import {
     type GbnfJsonArraySchema, type GbnfJsonBasicSchema, type GbnfJsonConstSchema, type GbnfJsonEnumSchema, type GbnfJsonObjectSchema,
@@ -161,6 +162,8 @@ export {
     SpecialToken,
     isLlamaText,
     tokenizeText,
+    type LlamaTextValue,
+    type LlamaTextInputValue,
     type LlamaTextJSON,
     type LlamaTextJSONValue,
     type LlamaTextSpecialTokensTextJSON,
@@ -176,6 +179,8 @@
     type ChatSystemMessage,
     type ChatUserMessage,
     type Token,
+    type Tokenizer,
+    type Detokenizer,
     isChatModelResponseFunctionCall,
     type GbnfJsonSchema,
     type GbnfJsonSchemaToType,
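With `Tokenizer` and `Detokenizer` now exported from the package root, model-agnostic helpers become easier to type. A sketch under stated assumptions: the call signatures below are assumptions about these types, not something this diff defines:

```ts
import type {Token, Tokenizer, Detokenizer} from "node-llama-cpp";

// Assumed shapes: a Tokenizer maps text to tokens and a Detokenizer
// maps tokens back to text (e.g. obtained from a loaded model).
function roundTrip(text: string, tokenize: Tokenizer, detokenize: Detokenizer): string {
    const tokens: readonly Token[] = tokenize(text, false);
    return detokenize(tokens);
}
```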
2 changes: 1 addition & 1 deletion src/utils/LlamaText.ts
@@ -1,8 +1,8 @@
 import {Token, Tokenizer} from "../types.js";
 import type {InspectOptions, inspect as InspectFunction} from "node:util";
 
-type LlamaTextInputValue = LlamaTextValue | LlamaText | number | boolean | readonly LlamaTextInputValue[];
 export type LlamaTextValue = string | SpecialTokensText | SpecialToken;
+export type LlamaTextInputValue = LlamaTextValue | LlamaText | number | boolean | readonly LlamaTextInputValue[];
 
 export type LlamaTextJSON = Array<LlamaTextJSONValue>;
 export type LlamaTextJSONValue = string | LlamaTextSpecialTokensTextJSON | LlamaTextSpecialTokenJSON;
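Since both types are now exported (and re-exported from the package root above), callers can see exactly what `LlamaText` accepts: strings, special-token values, other `LlamaText` instances, numbers, booleans, and nested readonly arrays. A small sketch (the special-token string is illustrative):

```ts
import {LlamaText, SpecialTokensText} from "node-llama-cpp";

const text = LlamaText([
    new SpecialTokensText("<|user|>"), // illustrative special-token text
    "How much is ",
    6 * 7, // numbers are valid input values
    "?"
]);
```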