From 2eed0cc89b5ad4ecd6cdcd5b0c6eb92b8cdddff4 Mon Sep 17 00:00:00 2001 From: Eugen Neufeld Date: Fri, 10 Jan 2025 13:23:10 +0100 Subject: [PATCH] fix: improve cancel logic in openAi model The cancel of a request was not working correctly. With the changes the cancelToken is better taken into account. --- .../src/node/openai-language-model.ts | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/packages/ai-openai/src/node/openai-language-model.ts b/packages/ai-openai/src/node/openai-language-model.ts index dec9b61784fc2..c173c7e2f68fa 100644 --- a/packages/ai-openai/src/node/openai-language-model.ts +++ b/packages/ai-openai/src/node/openai-language-model.ts @@ -69,6 +69,9 @@ export class OpenAiModel implements LanguageModel { if (request.response_format?.type === 'json_schema' && this.supportsStructuredOutput()) { return this.handleStructuredOutputRequest(openai, request); } + if (cancellationToken?.isCancellationRequested) { + return { text: '' }; + } let runner: ChatCompletionStream; const tools = this.createTools(request); @@ -95,42 +98,57 @@ export class OpenAiModel implements LanguageModel { let runnerEnd = false; - let resolve: (part: LanguageModelStreamResponsePart) => void; + let resolve: ((part: LanguageModelStreamResponsePart) => void) | undefined; runner.on('error', error => { console.error('Error in OpenAI chat completion stream:', error); runnerEnd = true; - resolve({ content: error.message }); + resolve?.({ content: error.message }); }); // we need to also listen for the emitted errors, as otherwise any error actually thrown by the API will not be caught runner.emitted('error').then(error => { console.error('Error in OpenAI chat completion stream:', error); runnerEnd = true; - resolve({ content: error.message }); + resolve?.({ content: error.message }); }); runner.emitted('abort').then(() => { - // do nothing, as the abort event is only emitted when the runner is aborted by us + // cancel async iterator + runnerEnd = true; }); runner.on('message', message => { if (message.role === 'tool') { - resolve({ tool_calls: [{ id: message.tool_call_id, finished: true, result: this.getCompletionContent(message) }] }); + resolve?.({ tool_calls: [{ id: message.tool_call_id, finished: true, result: this.getCompletionContent(message) }] }); } console.debug('Received Open AI message', JSON.stringify(message)); }); runner.once('end', () => { runnerEnd = true; // eslint-disable-next-line @typescript-eslint/no-explicit-any - resolve(runner.finalChatCompletion as any); + resolve?.(runner.finalChatCompletion as any); }); + if (cancellationToken?.isCancellationRequested) { + return { text: '' }; + } const asyncIterator = { async *[Symbol.asyncIterator](): AsyncIterator { runner.on('chunk', chunk => { - if (chunk.choices[0]?.delta) { + if (cancellationToken?.isCancellationRequested) { + resolve = undefined; + return; + } + if (resolve && chunk.choices[0]?.delta) { resolve({ ...chunk.choices[0]?.delta }); } }); while (!runnerEnd) { + if (cancellationToken?.isCancellationRequested) { + throw new Error('Iterator canceled'); + } const promise = new Promise((res, rej) => { resolve = res; + cancellationToken?.onCancellationRequested(() => { + rej(new Error('Canceled')); + runnerEnd = true; // Stop the iterator + }); }); yield promise; }