Skip to content

Commit

Permalink
fix: vertex ai and streaming fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Dec 16, 2024
1 parent 905d6c0 commit 8c0ade9
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 29 deletions.
26 changes: 20 additions & 6 deletions src/ax/ai/google-gemini/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ export class AxAIGoogleGemini extends AxBaseAI<
private options?: AxAIGoogleGeminiArgs['options'];
private config: AxAIGoogleGeminiConfig;
private apiKey: string;
private isVertex: boolean;

constructor({
apiKey,
Expand All @@ -118,10 +119,17 @@ export class AxAIGoogleGemini extends AxBaseAI<
throw new Error('GoogleGemini AI API key not set');
}

let apiURL = 'https://generativelanguage.googleapis.com/v1beta';
const isVertex = projectId !== undefined && region !== undefined;

if (projectId && region) {
apiURL = `POST https://${region}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/{REGION}/publishers/google/`;
let apiURL;
let headers;

if (isVertex) {
apiURL = `https://${region}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${region}/publishers/google/`;
headers = { Authorization: `Bearer ${apiKey}` };
} else {
apiURL = 'https://generativelanguage.googleapis.com/v1beta';
headers = {};
}

const _config = {
Expand All @@ -132,7 +140,7 @@ export class AxAIGoogleGemini extends AxBaseAI<
super({
name: 'GoogleGeminiAI',
apiURL,
headers: {},
headers,
modelInfo: axModelInfoGoogleGemini,
models: {
model: _config.model as AxAIGoogleGeminiModel,
Expand All @@ -145,6 +153,7 @@ export class AxAIGoogleGemini extends AxBaseAI<
this.options = options;
this.config = _config;
this.apiKey = apiKey;
this.isVertex = isVertex;
}

override getModelConfig(): AxModelConfig {
Expand All @@ -169,10 +178,15 @@ export class AxAIGoogleGemini extends AxBaseAI<

const apiConfig = {
name: stream
? `/models/${model}:streamGenerateContent?alt=sse&key=${this.apiKey}`
: `/models/${model}:generateContent?key=${this.apiKey}`
? `/models/${model}:streamGenerateContent?alt=sse`
: `/models/${model}:generateContent`
};

if (this.isVertex === false) {
const pf = stream ? '&' : '?';
apiConfig.name += `${pf}key=${this.apiKey}`;
}

const systemPrompts = req.chatPrompt
.filter((p) => p.role === 'system')
.map((p) => p.content);
Expand Down
31 changes: 27 additions & 4 deletions src/ax/dsp/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import type {
AxChatResponse,
AxChatResponseResult,
AxFunction,
AxModelConfig,
AxRateLimiterFunction
} from '../ai/types.js';
import { mergeFunctionCalls } from '../ai/util.js';
Expand Down Expand Up @@ -209,14 +210,15 @@ export class AxGen<
modelConfig,
model,
rateLimiter,
stream = false,
stream,
functions,
functionCall
}: Readonly<
Omit<AxProgramForwardOptions, 'ai' | 'mem'> & {
Omit<AxProgramForwardOptions, 'ai' | 'mem' | 'stream'> & {
sig: Readonly<AxSignature>;
ai: Readonly<AxAIService>;
mem: AxAIMemory;
stream: boolean;
}
>): Promise<OUT> {
const usageInfo = {
Expand Down Expand Up @@ -406,7 +408,15 @@ export class AxGen<
const maxRetries = options?.maxRetries ?? this.options?.maxRetries ?? 5;
const maxSteps = options?.maxSteps ?? this.options?.maxSteps ?? 10;
const mem = options?.mem ?? this.options?.mem ?? new AxMemory();

const modelConfig = mergeAxModelConfigs(
ai.getModelConfig(),
options?.modelConfig ?? {}
);

const canStream = ai.getFeatures(options?.model).streaming;
const stream =
options?.stream ?? this.options?.stream ?? modelConfig.stream ?? true;

let err: ValidationError | AxAssertionError | undefined;

Expand All @@ -431,9 +441,9 @@ export class AxGen<
mem,
sessionId: options?.sessionId,
traceId: options?.traceId,
modelConfig: options?.modelConfig,
modelConfig,
model: options?.model,
stream: canStream && options?.stream,
stream: canStream && stream,
maxSteps: options?.maxSteps,
rateLimiter: options?.rateLimiter,
functions: options?.functions,
Expand Down Expand Up @@ -537,3 +547,16 @@ export class AxGen<
);
}
}

function mergeAxModelConfigs(
baseConfig: Readonly<AxModelConfig>,
overrideConfig: Readonly<AxModelConfig>
): AxModelConfig {
return {
...baseConfig,
...overrideConfig,
// Merge arrays to avoid overriding entirely
stopSequences: overrideConfig.stopSequences ?? baseConfig.stopSequences,
endSequences: overrideConfig.endSequences ?? baseConfig.endSequences
};
}
4 changes: 2 additions & 2 deletions src/ax/prompts/prompts.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ const mockFetch = async (): Promise<Response> => {
};

test('generate prompt', async (t) => {
const options = { fetch: mockFetch };
const ai = new AxAI({
name: 'openai',
apiKey: 'no-key',
options
options: { fetch: mockFetch },
config: { stream: false }
});

// const ai = new AxAI({ name: 'ollama', config: { model: 'nous-hermes2' } });
Expand Down
17 changes: 8 additions & 9 deletions src/ax/prompts/react.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,22 @@ export class AxReAct<
signature: Readonly<AxSignature | string>,
options: Readonly<AxGenOptions>
) {
if (!options?.functions || options.functions.length === 0) {
throw new Error('No functions provided');
}

const fnNames = options.functions.map((f) => {
const fnNames = options.functions?.map((f) => {
if ('toFunction' in f) {
return f.toFunction().name;
}
return f.name;
});

const funcList = fnNames.map((fname) => `'${fname}'`).join(', ');
const funcList = fnNames?.map((fname) => `'${fname}'`).join(', ');

const sig = new AxSignature(signature);
sig.setDescription(
`Use the following functions ${funcList} to complete the task and return the result. The functions must be used to resolve the final result values`
);

if (funcList && funcList.length > 0) {
sig.setDescription(
`Use the following functions ${funcList} to complete the task and return the result. The functions must be used to resolve the final result values`
);
}

// sig.addInputField({
// name: 'observation',
Expand Down
16 changes: 8 additions & 8 deletions src/examples/food-search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,17 @@ const functions: AxFunction[] = [
}
];

const ai = new AxAI({
name: 'openai',
apiKey: process.env.OPENAI_APIKEY as string
});

// const ai = new AxAI({
// name: 'google-gemini',
// apiKey: process.env.GOOGLE_APIKEY as string
// // config: { model: 'gemini-2.0-flash-exp' }
// name: 'openai',
// apiKey: process.env.OPENAI_APIKEY as string
// });

const ai = new AxAI({
name: 'google-gemini',
apiKey: process.env.GOOGLE_APIKEY as string,
config: { model: 'gemini-2.0-flash-exp' }
});

// const ai = new AxAI({
// name: 'groq',
// apiKey: process.env.GROQ_APIKEY as string
Expand Down

0 comments on commit 8c0ade9

Please sign in to comment.