Skip to content

Commit

Permalink
🆕 add tools manager for OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
telesoho committed Sep 15, 2024
1 parent f1f29c4 commit 403bd07
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 24 deletions.
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
"max_tokens": 4096
}
],
"description": "The template for AI completion responses. See:"
"description": "The template for AI completion responses."
},
"MarkdownPaste.openaiCompletionTemplateFile": {
"type": "string",
Expand Down Expand Up @@ -337,7 +337,7 @@
"shelljs": "^0.8.5",
"turndown": "^7.1.2",
"xclip": "^1.0.5",
"openai": "^4.61.0"
"openai": "^4.61.0"
},
"devDependencies": {
"@types/glob": "^7.1.3",
Expand Down
74 changes: 74 additions & 0 deletions src/ToolsManager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import { ChatCompletionTool } from "openai/resources/chat/completions";
import Logger from "./Logger";

type ToolFunction = (...args: any[]) => any;

interface ToolInfo {
func: ToolFunction;
description: string;
parameters: Record<string, unknown>;
}

export class ToolsManager {
private tools: Map<string, ToolInfo>;

constructor() {
this.tools = new Map();
}

public registerDefaultTools() {
this.registerTool(
"get_current_weather",
({ city }: { city: string }) => {
return JSON.stringify({
city: city,
temperature: "25°C",
weather: "sunny",
});
},
"Get the current weather for a specified city",
{
type: "object",
properties: {
city: { type: "string", description: "The name of the city" },
},
required: ["city"],
}
);
}

public registerTool(
name: string,
func: ToolFunction,
description: string,
parameters: Record<string, unknown>
) {
this.tools.set(name, { func, description, parameters });
}

public executeTool(name: string, args: any): string | null {
const toolInfo = this.tools.get(name);
if (toolInfo) {
try {
return JSON.stringify(toolInfo.func(args));
} catch (error) {
Logger.log(`Error executing tool ${name}:`, error);
return null;
}
} else {
Logger.log(`Tool ${name} not found`);
return null;
}
}

public getToolsForOpenAI(): ChatCompletionTool[] {
return Array.from(this.tools.entries()).map(([toolName, toolInfo]) => ({
type: "function",
function: {
name: toolName,
description: toolInfo.description,
parameters: toolInfo.parameters,
},
}));
}
}
55 changes: 33 additions & 22 deletions src/ai_paster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@ import OpenAI from "openai";
import {
ChatCompletionMessageParam,
ChatCompletionTool,
ChatCompletionToolMessageParam,
} from "openai/resources/chat/completions";
import { Predefine } from "./predefine";

import Logger from "./Logger";
import { ToolsManager } from "./ToolsManager";

export class AIPaster {
private client: OpenAI;
private toolsManager: ToolsManager;

constructor() {
this.client = new OpenAI(this.config.openaiConnectOption);
this.toolsManager = new ToolsManager();
this.toolsManager.registerDefaultTools();
}

public destructor() {
Expand Down Expand Up @@ -42,27 +44,19 @@ export class AIPaster {
const responseMessages = chatCompletion.choices[0].message;
const toolCalls = chatCompletion.choices[0].message.tool_calls;
if (toolCalls) {
const availableFunctions = {
get_current_weather: function ({ city }: { city: string }) {
return JSON.stringify({
city: city,
temperature: "25°C",
weather: "sunny",
});
},
};
// messages.push(responseMessages);
for (const toolCall of toolCalls) {
const functionName: keyof typeof availableFunctions = toolCall
.function.name as keyof typeof availableFunctions;
const functionToCall = availableFunctions[functionName];
const functionArgs = JSON.parse(toolCall.function.arguments);
const functionResponse = functionToCall(functionArgs);
completion.messages.push({
tool_call_id: toolCall.id,
role: "tool",
content: functionResponse,
});
const functionName = toolCall.function.name;
const functionResponse = this.toolsManager.executeTool(
functionName,
JSON.parse(toolCall.function.arguments)
);
if (functionResponse !== null) {
completion.messages.push({
tool_call_id: toolCall.id,
role: "tool",
content: functionResponse,
});
}
}
completion.messages.forEach((message: ChatCompletionMessageParam) => {
Logger.log(
Expand All @@ -89,6 +83,15 @@ export class AIPaster {
}
}

private mergeToolsByFunctionName(existingTools, newTools) {
const toolMap = new Map();

existingTools.forEach((tool) => toolMap.set(tool.function.name, tool));
newTools.forEach((tool) => toolMap.set(tool.function.name, tool));

return Array.from(toolMap.values());
}

public async callAI(clipboardText: string): Promise<any> {
try {
let openaiCompletionTemplate = this.config.openaiCompletionTemplate;
Expand Down Expand Up @@ -124,6 +127,14 @@ export class AIPaster {
);
}
});
if (completion.tools && Array.isArray(completion.tools)) {
completion.tools = this.mergeToolsByFunctionName(
completion.tools,
this.toolsManager.getToolsForOpenAI()
);
} else {
completion.tools = this.toolsManager.getToolsForOpenAI();
}
let content = await this.runCompletion(completion);
Logger.log("content:", content);
result += content;
Expand Down
173 changes: 173 additions & 0 deletions test/suite/ToolsManager.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import * as assert from "assert";
import { ToolsManager } from "../../src/ToolsManager";
import { ChatCompletionTool } from "openai/resources/chat/completions";

// Defines a Mocha test suite to group tests of similar kind together
suite("ToolsManager Tests", () => {
let toolsManager: ToolsManager;

setup(() => {
toolsManager = new ToolsManager();
});

test("registerTool should add a new tool", () => {
const toolName = "test_tool";
const toolFunc = () => ({ result: "success" });
const toolDescription = "A test tool";
const toolParameters = { type: "object", properties: {} };

toolsManager.registerTool(
toolName,
toolFunc,
toolDescription,
toolParameters
);

const tools = toolsManager.getToolsForOpenAI();
assert.strictEqual(tools.length, 1);
assert.strictEqual(tools[0].function.name, toolName);
assert.strictEqual(tools[0].function.description, toolDescription);
assert.deepStrictEqual(tools[0].function.parameters, toolParameters);
});

test("executeTool should call the registered tool function", () => {
const toolName = "test_tool";
const toolFunc = (args: any) => ({ result: args.input });
const toolDescription = "A test tool";
const toolParameters = { type: "object", properties: {} };

toolsManager.registerTool(
toolName,
toolFunc,
toolDescription,
toolParameters
);

const result = toolsManager.executeTool(toolName, { input: "test" });
assert.strictEqual(result, JSON.stringify({ result: "test" }));
});

test("executeTool should return null for unregistered tool", () => {
const result = toolsManager.executeTool("nonexistent_tool", {});
assert.strictEqual(result, null);
});

test("getToolsForOpenAI should return correct format", () => {
const toolName = "test_tool";
const toolFunc = () => ({});
const toolDescription = "A test tool";
const toolParameters = {
type: "object",
properties: { arg: { type: "string" } },
};

toolsManager.registerTool(
toolName,
toolFunc,
toolDescription,
toolParameters
);

const tools = toolsManager.getToolsForOpenAI();
assert.strictEqual(tools.length, 1);
assert.deepStrictEqual(tools[0], {
type: "function",
function: {
name: toolName,
description: toolDescription,
parameters: toolParameters,
},
});
});

test("registerDefaultTools should register the weather tool", () => {
toolsManager.registerDefaultTools();
const tools = toolsManager.getToolsForOpenAI();
assert.strictEqual(tools.length, 1);
assert.strictEqual(tools[0].function.name, "get_current_weather");
});

test("registerTool should overwrite existing tool with same name", () => {
const toolName = "test_tool";
const toolFunc1 = () => ({ result: "original" });
const toolFunc2 = () => ({ result: "overwritten" });
const toolDescription = "A test tool";
const toolParameters = { type: "object", properties: {} };

toolsManager.registerTool(
toolName,
toolFunc1,
toolDescription,
toolParameters
);
toolsManager.registerTool(
toolName,
toolFunc2,
toolDescription,
toolParameters
);

const result = toolsManager.executeTool(toolName, {});
assert.strictEqual(result, JSON.stringify({ result: "overwritten" }));
});

test("executeTool should handle errors in tool function", () => {
const toolName = "error_tool";
const toolFunc = () => {
throw new Error("Test error");
};
const toolDescription = "A tool that throws an error";
const toolParameters = { type: "object", properties: {} };

toolsManager.registerTool(
toolName,
toolFunc,
toolDescription,
toolParameters
);

const result = toolsManager.executeTool(toolName, {});
assert.strictEqual(result, null);
});

test("getToolsForOpenAI should return empty array when no tools are registered", () => {
const tools = toolsManager.getToolsForOpenAI();
assert.strictEqual(tools.length, 0);
});

test("registerTool should handle complex parameter structures", () => {
const toolName = "complex_tool";
const toolFunc = () => ({});
const toolDescription = "A tool with complex parameters";
const toolParameters = {
type: "object",
properties: {
stringArg: { type: "string" },
numberArg: { type: "number" },
booleanArg: { type: "boolean" },
arrayArg: {
type: "array",
items: { type: "string" },
},
objectArg: {
type: "object",
properties: {
nestedProp: { type: "string" },
},
},
},
required: ["stringArg", "numberArg"],
};

toolsManager.registerTool(
toolName,
toolFunc,
toolDescription,
toolParameters
);

const tools = toolsManager.getToolsForOpenAI();
assert.strictEqual(tools.length, 1);
assert.deepStrictEqual(tools[0].function.parameters, toolParameters);
});
});

0 comments on commit 403bd07

Please sign in to comment.