From 8c0ade90bd39134882563c717a2205adf30a7d61 Mon Sep 17 00:00:00 2001 From: dosco <832235+dosco@users.noreply.github.com> Date: Mon, 16 Dec 2024 00:50:44 -0800 Subject: [PATCH] fix: vertex ai and streaming fixes --- src/ax/ai/google-gemini/api.ts | 26 ++++++++++++++++++++------ src/ax/dsp/generate.ts | 31 +++++++++++++++++++++++++++---- src/ax/prompts/prompts.test.ts | 4 ++-- src/ax/prompts/react.ts | 17 ++++++++--------- src/examples/food-search.ts | 16 ++++++++-------- 5 files changed, 65 insertions(+), 29 deletions(-) diff --git a/src/ax/ai/google-gemini/api.ts b/src/ax/ai/google-gemini/api.ts index 8fd01bc..43587ef 100644 --- a/src/ax/ai/google-gemini/api.ts +++ b/src/ax/ai/google-gemini/api.ts @@ -105,6 +105,7 @@ export class AxAIGoogleGemini extends AxBaseAI< private options?: AxAIGoogleGeminiArgs['options']; private config: AxAIGoogleGeminiConfig; private apiKey: string; + private isVertex: boolean; constructor({ apiKey, @@ -118,10 +119,17 @@ export class AxAIGoogleGemini extends AxBaseAI< throw new Error('GoogleGemini AI API key not set'); } - let apiURL = 'https://generativelanguage.googleapis.com/v1beta'; + const isVertex = projectId !== undefined && region !== undefined; - if (projectId && region) { - apiURL = `POST https://${region}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/{REGION}/publishers/google/`; + let apiURL; + let headers; + + if (isVertex) { + apiURL = `https://${region}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${region}/publishers/google/`; + headers = { Authorization: `Bearer ${apiKey}` }; + } else { + apiURL = 'https://generativelanguage.googleapis.com/v1beta'; + headers = {}; } const _config = { @@ -132,7 +140,7 @@ export class AxAIGoogleGemini extends AxBaseAI< super({ name: 'GoogleGeminiAI', apiURL, - headers: {}, + headers, modelInfo: axModelInfoGoogleGemini, models: { model: _config.model as AxAIGoogleGeminiModel, @@ -145,6 +153,7 @@ export class AxAIGoogleGemini extends AxBaseAI< this.options = options; this.config = _config; this.apiKey = apiKey; + this.isVertex = isVertex; } override getModelConfig(): AxModelConfig { @@ -169,10 +178,15 @@ export class AxAIGoogleGemini extends AxBaseAI< const apiConfig = { name: stream - ? `/models/${model}:streamGenerateContent?alt=sse&key=${this.apiKey}` - : `/models/${model}:generateContent?key=${this.apiKey}` + ? `/models/${model}:streamGenerateContent?alt=sse` + : `/models/${model}:generateContent` }; + if (this.isVertex === false) { + const pf = stream ? '&' : '?'; + apiConfig.name += `${pf}key=${this.apiKey}`; + } + const systemPrompts = req.chatPrompt .filter((p) => p.role === 'system') .map((p) => p.content); diff --git a/src/ax/dsp/generate.ts b/src/ax/dsp/generate.ts index c92192e..185733a 100644 --- a/src/ax/dsp/generate.ts +++ b/src/ax/dsp/generate.ts @@ -6,6 +6,7 @@ import type { AxChatResponse, AxChatResponseResult, AxFunction, + AxModelConfig, AxRateLimiterFunction } from '../ai/types.js'; import { mergeFunctionCalls } from '../ai/util.js'; @@ -209,14 +210,15 @@ export class AxGen< modelConfig, model, rateLimiter, - stream = false, + stream, functions, functionCall }: Readonly< - Omit & { + Omit & { sig: Readonly; ai: Readonly; mem: AxAIMemory; + stream: boolean; } >): Promise { const usageInfo = { @@ -406,7 +408,15 @@ export class AxGen< const maxRetries = options?.maxRetries ?? this.options?.maxRetries ?? 5; const maxSteps = options?.maxSteps ?? this.options?.maxSteps ?? 10; const mem = options?.mem ?? this.options?.mem ?? new AxMemory(); + + const modelConfig = mergeAxModelConfigs( + ai.getModelConfig(), + options?.modelConfig ?? {} + ); + const canStream = ai.getFeatures(options?.model).streaming; + const stream = + options?.stream ?? this.options?.stream ?? modelConfig.stream ?? true; let err: ValidationError | AxAssertionError | undefined; @@ -431,9 +441,9 @@ export class AxGen< mem, sessionId: options?.sessionId, traceId: options?.traceId, - modelConfig: options?.modelConfig, + modelConfig, model: options?.model, - stream: canStream && options?.stream, + stream: canStream && stream, maxSteps: options?.maxSteps, rateLimiter: options?.rateLimiter, functions: options?.functions, @@ -537,3 +547,16 @@ export class AxGen< ); } } + +function mergeAxModelConfigs( + baseConfig: Readonly, + overrideConfig: Readonly +): AxModelConfig { + return { + ...baseConfig, + ...overrideConfig, + // Merge arrays to avoid overriding entirely + stopSequences: overrideConfig.stopSequences ?? baseConfig.stopSequences, + endSequences: overrideConfig.endSequences ?? baseConfig.endSequences + }; +} diff --git a/src/ax/prompts/prompts.test.ts b/src/ax/prompts/prompts.test.ts index 135f425..6e7cded 100644 --- a/src/ax/prompts/prompts.test.ts +++ b/src/ax/prompts/prompts.test.ts @@ -44,11 +44,11 @@ const mockFetch = async (): Promise => { }; test('generate prompt', async (t) => { - const options = { fetch: mockFetch }; const ai = new AxAI({ name: 'openai', apiKey: 'no-key', - options + options: { fetch: mockFetch }, + config: { stream: false } }); // const ai = new AxAI({ name: 'ollama', config: { model: 'nous-hermes2' } }); diff --git a/src/ax/prompts/react.ts b/src/ax/prompts/react.ts index f344c97..5986d24 100644 --- a/src/ax/prompts/react.ts +++ b/src/ax/prompts/react.ts @@ -12,23 +12,22 @@ export class AxReAct< signature: Readonly, options: Readonly ) { - if (!options?.functions || options.functions.length === 0) { - throw new Error('No functions provided'); - } - - const fnNames = options.functions.map((f) => { + const fnNames = options.functions?.map((f) => { if ('toFunction' in f) { return f.toFunction().name; } return f.name; }); - const funcList = fnNames.map((fname) => `'${fname}'`).join(', '); + const funcList = fnNames?.map((fname) => `'${fname}'`).join(', '); const sig = new AxSignature(signature); - sig.setDescription( - `Use the following functions ${funcList} to complete the task and return the result. The functions must be used to resolve the final result values` - ); + + if (funcList && funcList.length > 0) { + sig.setDescription( + `Use the following functions ${funcList} to complete the task and return the result. The functions must be used to resolve the final result values` + ); + } // sig.addInputField({ // name: 'observation', diff --git a/src/examples/food-search.ts b/src/examples/food-search.ts index 108c568..8a0f064 100644 --- a/src/examples/food-search.ts +++ b/src/examples/food-search.ts @@ -140,17 +140,17 @@ const functions: AxFunction[] = [ } ]; -const ai = new AxAI({ - name: 'openai', - apiKey: process.env.OPENAI_APIKEY as string -}); - // const ai = new AxAI({ -// name: 'google-gemini', -// apiKey: process.env.GOOGLE_APIKEY as string -// // config: { model: 'gemini-2.0-flash-exp' } +// name: 'openai', +// apiKey: process.env.OPENAI_APIKEY as string // }); +const ai = new AxAI({ + name: 'google-gemini', + apiKey: process.env.GOOGLE_APIKEY as string, + config: { model: 'gemini-2.0-flash-exp' } +}); + // const ai = new AxAI({ // name: 'groq', // apiKey: process.env.GROQ_APIKEY as string