Ollama LLM provider tools support #14610
dhuebner committed Jan 10, 2025
1 parent d8022a1 commit 218c7f8
packages/ai-ollama/src/node/ollama-language-model.ts (170 changes: 134 additions & 36 deletions)
@@ -20,7 +20,9 @@ import {
LanguageModelRequest,
LanguageModelRequestMessage,
LanguageModelResponse,
LanguageModelStreamResponse,
LanguageModelStreamResponsePart,
ToolCall,
ToolRequest
} from '@theia/ai-core';
import { CancellationToken } from '@theia/core';
@@ -31,7 +33,9 @@ export const OllamaModelIdentifier = Symbol('OllamaModelIdentifier');
export class OllamaModel implements LanguageModel {

protected readonly DEFAULT_REQUEST_SETTINGS: Partial<Omit<ChatRequest, 'stream' | 'model'>> = {
keep_alive: '15m'
keep_alive: '15m',
// options see: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
options: {}
};

readonly providerId = 'ollama';
@@ -50,62 +54,125 @@ export class OllamaModel implements LanguageModel {
public defaultRequestSettings?: { [key: string]: unknown }
) { }

async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
const settings = this.getSettings(request);
const ollama = this.initializeOllama();

const ollamaRequest: ExtendedChatRequest = {
model: this.model,
...this.DEFAULT_REQUEST_SETTINGS,
...settings,
messages: request.messages.map(this.toOllamaMessage),
tools: request.tools?.map(this.toOllamaTool)
};
const structured = request.response_format?.type === 'json_schema';
return this.dispatchRequest(ollama, ollamaRequest, structured, cancellationToken);
}

/**
* Retrieves the settings for the chat request, merging the request-specific settings with the default settings.
* @param request The language model request containing specific settings.
* @returns A partial ChatRequest object containing the merged settings.
*/
protected getSettings(request: LanguageModelRequest): Partial<ChatRequest> {
const settings = request.settings ?? this.defaultRequestSettings ?? {};
return {
options: settings as Partial<Options>
};
}

async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
const settings = this.getSettings(request);
const ollama = this.initializeOllama();
protected async dispatchRequest(ollama: Ollama, ollamaRequest: ExtendedChatRequest, structured: boolean, cancellation?: CancellationToken): Promise<LanguageModelResponse> {

// Handle structured output request
if (structured) {
return this.handleStructuredOutputRequest(ollama, ollamaRequest);
}

if (request.response_format?.type === 'json_schema') {
return this.handleStructuredOutputRequest(ollama, request);
// Handle tool request - response may call tools
if (ollamaRequest.tools && ollamaRequest.tools?.length > 0) {
return this.handleToolsRequest(ollama, ollamaRequest);
}

// Handle standard chat request
const response = await ollama.chat({
model: this.model,
...this.DEFAULT_REQUEST_SETTINGS,
...settings,
messages: request.messages.map(this.toOllamaMessage),
stream: true,
tools: request.tools?.map(this.toOllamaTool),
...ollamaRequest,
stream: true
});
return this.handleCancellationAndWrapIterator(response, cancellation);
}

cancellationToken?.onCancellationRequested(() => {
response.abort();
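/**
 * Handles a chat request whose response may contain tool calls. Each requested
 * tool handler is invoked, its result is appended to the chat history as a
 * 'tool' message, and the model is queried again until it answers with plain content.
 */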
protected async handleToolsRequest(ollama: Ollama, chatRequest: ExtendedChatRequest, prevResponse?: ChatResponse): Promise<LanguageModelResponse> {
const response = prevResponse || await ollama.chat({
...chatRequest,
stream: false
});

async function* wrapAsyncIterator<T>(inputIterable: AsyncIterable<ChatResponse>): AsyncIterable<LanguageModelStreamResponsePart> {
for await (const item of inputIterable) {
// TODO handle tool calls
yield { content: item.message.content };
if (response.message.tool_calls) {
const tools: ToolWithHandler[] = chatRequest.tools ?? [];
// Add response message to chat history
chatRequest.messages.push(response.message);
const tool_calls: ToolCall[] = [];
for (const [idx, toolCall] of response.message.tool_calls.entries()) {
const functionToCall = tools.find(tool => tool.function.name === toolCall.function.name);
if (functionToCall) {
const args = JSON.stringify(toolCall.function?.arguments);
const funcResult = await functionToCall.handler(args);
chatRequest.messages.push({
role: 'tool',
content: `Tool call ${functionToCall.function.name} returned: ${String(funcResult)}`,
});
let resultString = String(funcResult);
if (resultString.length > 1000) {
// truncate result string if it is too long
resultString = resultString.substring(0, 1000) + '...';
}
tool_calls.push({
id: `ollama_${response.created_at}_${idx}`,
function: {
name: functionToCall.function.name,
arguments: Object.values(toolCall.function?.arguments ?? {}).join(', ')
},
result: resultString,
finished: true
});
}
}
// Get final response from model with function outputs
const finalResponse = await ollama.chat({ ...chatRequest, stream: false });
if (finalResponse.message.tool_calls) {
// If the final response also calls tools, recursively handle them
return this.handleToolsRequest(ollama, chatRequest, finalResponse);
}
return { stream: this.createAsyncIterable([{ tool_calls }, { content: finalResponse.message.content }]) };
}
return { stream: wrapAsyncIterator(response) };
return { text: response.message.content };
}

protected async handleStructuredOutputRequest(ollama: Ollama, request: LanguageModelRequest): Promise<LanguageModelParsedResponse> {
const settings = this.getSettings(request);
const result = await ollama.chat({
...settings,
...this.DEFAULT_REQUEST_SETTINGS,
model: this.model,
messages: request.messages.map(this.toOllamaMessage),
protected createAsyncIterable<T>(items: T[]): AsyncIterable<T> {
return {
[Symbol.asyncIterator]: async function* (): AsyncIterableIterator<T> {
for (const item of items) {
yield item;
}
}
};
}

protected async handleStructuredOutputRequest(ollama: Ollama, chatRequest: ChatRequest): Promise<LanguageModelParsedResponse> {
const response = await ollama.chat({
...chatRequest,
format: 'json',
stream: false,
});
try {
return {
content: result.message.content,
parsed: JSON.parse(result.message.content)
content: response.message.content,
parsed: JSON.parse(response.message.content)
};
} catch (error) {
// TODO use ILogger
console.log('Failed to parse structured response from the language model.', error);
return {
content: result.message.content,
content: response.message.content,
parsed: {}
};
}
@@ -119,11 +186,21 @@ export class OllamaModel implements LanguageModel {
return new Ollama({ host: host });
}

protected toOllamaTool(tool: ToolRequest): Tool {
const transform = (props: Record<string, {
[key: string]: unknown;
type: string;
}> | undefined) => {
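/**
 * Wraps the streaming Ollama response into a LanguageModelStreamResponse and
 * aborts the underlying request when the given cancellation token fires.
 */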
protected handleCancellationAndWrapIterator(response: AbortableAsyncIterable<ChatResponse>, token?: CancellationToken): LanguageModelStreamResponse {
token?.onCancellationRequested(() => {
// maybe it is better to use ollama.abort() as we are using one client per request
response.abort();
});
async function* wrapAsyncIterator<T>(inputIterable: AsyncIterable<ChatResponse>): AsyncIterable<LanguageModelStreamResponsePart> {
for await (const item of inputIterable) {
yield { content: item.message.content };
}
}
return { stream: wrapAsyncIterator(response) };
}

protected toOllamaTool(tool: ToolRequest): ToolWithHandler {
const transform = (props: Record<string, { [key: string]: unknown; type: string; }> | undefined) => {
if (!props) {
return undefined;
}
@@ -148,7 +225,8 @@ export class OllamaModel implements LanguageModel {
required: Object.keys(tool.parameters?.properties ?? {}),
properties: transform(tool.parameters?.properties) ?? {}
},
}
},
handler: tool.handler
};
}

@@ -165,3 +243,23 @@ export class OllamaModel implements LanguageModel {
return { role: 'system', content: '' };
}
}

/**
* Extended Tool containing a handler
* @see Tool
*/
type ToolWithHandler = Tool & { handler: (arg_string: string) => Promise<unknown> };

/**
* Extended chat request with mandatory messages and ToolWithHandler tools
*
* @see ChatRequest
* @see ToolWithHandler
*/
type ExtendedChatRequest = ChatRequest & {
messages: Message[]
tools?: ToolWithHandler[]
};

// Ollama doesn't export this type, so we have to define it here
type AbortableAsyncIterable<T> = AsyncIterable<T> & { abort: () => void };
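
To illustrate the round trip that handleToolsRequest implements, here is a minimal, self-contained sketch written directly against the ollama npm client. The model name, the get_current_time tool, and its handler are hypothetical examples; the commit resolves follow-up tool calls recursively, whereas the sketch uses an equivalent loop.

import { Ollama } from 'ollama';
import type { Message, Tool } from 'ollama';

// Tool definition extended with a handler, mirroring ToolWithHandler above.
type ToolWithHandler = Tool & { handler: (argString: string) => Promise<unknown> };

// Hypothetical example tool; name, description, and handler are illustrative only.
const getCurrentTime: ToolWithHandler = {
    type: 'function',
    function: {
        name: 'get_current_time',
        description: 'Returns the current time as an ISO 8601 string',
        parameters: { type: 'object', required: [], properties: {} }
    },
    handler: async () => new Date().toISOString()
};

async function chatWithTools(ollama: Ollama, model: string, messages: Message[], tools: ToolWithHandler[]): Promise<string> {
    // First turn: the model decides whether to answer directly or call tools.
    let response = await ollama.chat({ model, messages, tools, stream: false });

    // Keep resolving tool calls until the model answers with plain content.
    while (response.message.tool_calls && response.message.tool_calls.length > 0) {
        // Echo the assistant message into the history, as handleToolsRequest does.
        messages.push(response.message);
        for (const toolCall of response.message.tool_calls) {
            const tool = tools.find(t => t.function.name === toolCall.function.name);
            if (!tool) {
                continue;
            }
            // Handlers receive the arguments serialized as a JSON string.
            const result = await tool.handler(JSON.stringify(toolCall.function.arguments));
            // Feed the tool output back as a 'tool' message.
            messages.push({
                role: 'tool',
                content: `Tool call ${tool.function.name} returned: ${String(result)}`
            });
        }
        // Ask the model again, now with the tool outputs in the history.
        response = await ollama.chat({ model, messages, tools, stream: false });
    }
    return response.message.content;
}

// Example usage (host and model name are assumptions):
// const answer = await chatWithTools(new Ollama({ host: 'http://localhost:11434' }),
//     'llama3.1', [{ role: 'user', content: 'What time is it?' }], [getCurrentTime]);

Feeding each result back into the history as a 'tool' message keeps the full call record visible to the model, which is what allows it to chain several tool calls before producing its final answer.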
