Skip to content

Commit

Permalink
fix: several issues with agents
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Jun 25, 2024
1 parent 47f78d6 commit 5800ff0
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 152 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@ax-llm/ax",
"version": "9.0.10",
"version": "9.0.11",
"type": "module",
"description": "The best library to work with LLMs",
"typings": "build/module/src/index.d.ts",
Expand Down
41 changes: 27 additions & 14 deletions src/ai/groq/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { AxRateLimiterTokenUsage } from '../../util/rate-limit.js';
import { axBaseAIDefaultConfig } from '../base.js';
import { AxAIOpenAI } from '../openai/api.js';
import type { AxAIOpenAIConfig } from '../openai/types.js';
import type { AxAIServiceOptions } from '../types.js';
import type { AxAIServiceOptions, AxRateLimiterFunction } from '../types.js';

import { AxAIGroqModel } from './types.js';

Expand Down Expand Up @@ -35,23 +35,11 @@ export class AxAIGroq extends AxAIOpenAI {
...config
};

let rateLimiter = options?.rateLimiter;
if (!rateLimiter) {
const tokensPerMin = options?.tokensPerMinute ?? 5800;
const rt = new AxRateLimiterTokenUsage(tokensPerMin, tokensPerMin / 60);

rateLimiter = async (func, info) => {
const totalTokens = info.modelUsage?.totalTokens || 0;
await rt.acquire(totalTokens);
return func();
};
}

const _options = {
...options,
rateLimiter,
streamingUsage: false
};

super({
apiKey,
config: _config,
Expand All @@ -61,5 +49,30 @@ export class AxAIGroq extends AxAIOpenAI {
});

super.setName('Groq');
this.setOptions(_options);
}

override setOptions = (options: Readonly<AxAIServiceOptions>) => {
const rateLimiter = this.newRateLimiter(options);
super.setOptions({ ...options, rateLimiter });
};

private newRateLimiter = (options: Readonly<AxAIGroqArgs['options']>) => {
if (options?.rateLimiter) {
return options.rateLimiter;
}

const tokensPerMin = options?.tokensPerMinute ?? 4800;
const rt = new AxRateLimiterTokenUsage(tokensPerMin, tokensPerMin / 60, {
debug: options?.debug
});

const rtFunc: AxRateLimiterFunction = async (func, info) => {
const totalTokens = info.modelUsage?.totalTokens || 0;
await rt.acquire(totalTokens);
return await func();
};

return rtFunc;
};
}
2 changes: 1 addition & 1 deletion src/dsp/generate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ test('extractValues with json', (t) => {
});

test('extractValues with text values', (t) => {
const sig = new AxSignature(`text -> title, keyPoints, description`);
const sig = new AxSignature(`someText -> title, keyPoints, description`);
const v1 = {};
extractValues(
sig,
Expand Down
27 changes: 20 additions & 7 deletions src/dsp/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ import type {
AxChatRequest,
AxChatResponse,
AxChatResponseResult,
AxFunction
AxFunction,
AxRateLimiterFunction
} from '../ai/types.js';
import { mergeFunctionCalls } from '../ai/util.js';
import {
type AxChatResponseFunctionCall,
AxFunctionProcessor
} from '../funcs/functions.js';
import { type AxAIMemory, AxMemory } from '../mem/index.js';
import { type AxSpan, AxSpanKind } from '../trace/index.js';
import { type AxSpan, AxSpanKind, type AxTracer } from '../trace/index.js';

import {
assertAssertions,
Expand Down Expand Up @@ -41,6 +42,15 @@ import { AxSignature } from './sig.js';
import { validateValue } from './util.js';

export interface AxGenerateOptions {
maxCompletions?: number;
maxRetries?: number;
maxSteps?: number;
mem?: AxAIMemory;
tracer?: AxTracer;
rateLimiter?: AxRateLimiterFunction;
stream?: boolean;
debug?: boolean;

functions?: AxFunction[];
functionCall?: AxChatRequest['functionCall'];
promptTemplate?: typeof AxPromptTemplate;
Expand Down Expand Up @@ -363,8 +373,9 @@ export class AxGenerate<
options?: Readonly<AxProgramForwardOptions>,
span?: AxSpan
): Promise<OUT> {
const maxRetries = options?.maxRetries ?? 5;
const mem = options?.mem ?? new AxMemory();
const maxRetries = this.options?.maxRetries ?? options?.maxRetries ?? 5;
const maxSteps = this.options?.maxSteps ?? options?.maxSteps ?? 10;
const mem = this.options?.mem ?? options?.mem ?? new AxMemory();
const canStream = this.ai.getFeatures().streaming;

let err: ValidationError | AxAssertionError | undefined;
Expand All @@ -384,7 +395,7 @@ export class AxGenerate<

for (let i = 0; i < maxRetries; i++) {
try {
for (let n = 0; n < (options?.maxSteps ?? 10); n++) {
for (let n = 0; n < maxSteps; n++) {
const {
sessionId,
traceId,
Expand Down Expand Up @@ -453,7 +464,9 @@ export class AxGenerate<
values: IN,
options?: Readonly<AxProgramForwardOptions>
): Promise<OUT> {
if (!options?.tracer) {
const tracer = this.options?.tracer ?? options?.tracer;

if (!tracer) {
return await this._forward(values, options);
}

Expand All @@ -462,7 +475,7 @@ export class AxGenerate<
['generate.functions']: this.functionList ?? 'none'
};

return await options?.tracer.startActiveSpan(
return await tracer.startActiveSpan(
'Generate',
{
kind: AxSpanKind.SERVER,
Expand Down
51 changes: 47 additions & 4 deletions src/dsp/sig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,6 @@ export class AxSignature {
};

private parseField = (field: Readonly<AxField>): AxIField => {
if (!field.name || field.name.length === 0) {
throw new Error('Field name is required.');
}

const title =
!field.title || field.title.length === 0
? this.toTitle(field.name)
Expand Down Expand Up @@ -157,7 +153,14 @@ export class AxSignature {
};

private updateHash = (): [string, string] => {
this.getInputFields().forEach((field) => {
validateField(field);
if (field.type?.name === 'image') {
throw new Error('Image type is not supported in output fields.');
}
});
this.getOutputFields().forEach((field) => {
validateField(field);
if (field.type?.name === 'image') {
throw new Error('Image type is not supported in output fields.');
}
Expand Down Expand Up @@ -218,3 +221,43 @@ function renderSignature(
// Combine all parts into the final signature.
return `${descriptionPart} ${inputFieldsRendered} -> ${outputFieldsRendered}`;
}

function isValidCase(inputString: string): boolean {
const camelCaseRegex = /^[a-z][a-zA-Z0-9]*$/;
const snakeCaseRegex = /^[a-z]+(_[a-z0-9]+)*$/;

return camelCaseRegex.test(inputString) || snakeCaseRegex.test(inputString);
}

function validateField(field: Readonly<AxField>): void {
if (!field.name || field.name.length === 0) {
throw new Error('Field name cannot be blank');
}

if (!isValidCase(field.name)) {
throw new Error(
'Field name must be in camel case or snake case: ' + field.name
);
}

if (
[
'text',
'object',
'image',
'string',
'number',
'boolean',
'json',
'array',
'date',
'time',
'type'
].includes(field.name)
) {
throw new Error(
'Invalid field name, please make it more descriptive (eg. noteText): ' +
field.name
);
}
}
34 changes: 23 additions & 11 deletions src/examples/agent.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,44 @@
import { AxAgent, AxAI } from '../index.js';

// const ai = new AxAI({
// name: 'openai',
// apiKey: process.env.OPENAI_APIKEY as string
// });

// const ai = new AxAI({
// name: 'google-gemini',
// apiKey: process.env.GOOGLE_APIKEY as string
// });

const ai = new AxAI({
name: 'openai',
apiKey: process.env.OPENAI_APIKEY as string
name: 'groq',
apiKey: process.env.GROQ_APIKEY as string
});

ai.setOptions({ debug: true });

const researcher = new AxAgent(ai, {
name: 'researcher',
description: 'Researcher agent',
name: 'Physics Researcher',
description:
'Researcher for physics questions can answer questions about advanced physics',
signature: `physicsQuestion "physics questions" -> answer "reply in bullet points"`
});

const summarizer = new AxAgent(ai, {
name: 'summarizer',
description: 'Summarizer agent',
signature: `text "text so summarize" -> shortSummary "summarize in 5 to 10 words"`
name: 'Science Summarizer',
description:
'Summarizer can write short summaries of advanced science topics',
signature: `textToSummarize "text to summarize" -> shortSummary "summarize in 5 to 10 words"`
});

const agent = new AxAgent(ai, {
name: 'agent',
description: 'Agent',
name: 'Scientist',
description: 'An agent that can answer advanced science questions',
signature: `question -> answer`,
agents: [researcher, summarizer]
});

const question = `Why is gravity not a real force? Why is light pure energy? Why is physics scale invariant? Why is the centrifugal force talked about so much if it's not real?
`;
const question = `Why is gravity not a real force? Why is light pure energy? Why is physics scale invariant? Why is the centrifugal force talked about so much if it's not real?`;

const res = await agent.forward({ question });

Expand Down
8 changes: 4 additions & 4 deletions src/examples/food-search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ const functions: AxFunction[] = [
type: 'object',
properties: {
location: {
type: 'string',
description: 'location to get weather for'
description: 'location to get weather for',
type: 'string'
},
units: {
type: 'string',
Expand All @@ -121,8 +121,8 @@ const functions: AxFunction[] = [
type: 'object',
properties: {
location: {
type: 'string',
description: 'location to find restaurants in'
description: 'location to find restaurants in',
type: 'string'
},
outdoor: {
type: 'boolean',
Expand Down
40 changes: 37 additions & 3 deletions src/prompts/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ export class AxAgent<IN extends AxGenIn, OUT extends AxGenOut>
];

const opt = {
promptTemplate: options?.promptTemplate,
asserts: options?.asserts,
...options,
functions: funcs
};

Expand All @@ -70,11 +69,24 @@ export class AxAgent<IN extends AxGenIn, OUT extends AxGenOut>
? new AxReAct<IN, OUT>(ai, sig, opt)
: new AxChainOfThought<IN, OUT>(ai, sig, opt);

if (!name || name.length < 5) {
throw new Error(
'Agent name must be at least 10 characters (more descriptive): ' + name
);
}

if (!description || description.length < 20) {
throw new Error(
'Agent description must be at least 20 characters (explain in detail what the agent does): ' +
description
);
}

this.name = name;
this.description = description;
this.subAgentList = agents?.map((a) => a.getFunction().name).join(', ');
this.func = {
name: this.name,
name: toCamelCase(this.name),
description: this.description,
parameters: sig.toJSONSchema(),
func: () => this.forward
Expand Down Expand Up @@ -119,3 +131,25 @@ export class AxAgent<IN extends AxGenIn, OUT extends AxGenOut>
);
}
}

function toCamelCase(inputString: string): string {
// Split the string by any non-alphanumeric character (including underscores, spaces, hyphens)
const words = inputString.split(/[^a-zA-Z0-9]/);

// Map through each word, capitalize the first letter of each word except the first word
const camelCaseString = words
.map((word, index) => {
// Lowercase the word to handle cases like uppercase letters in input
const lowerWord = word.toLowerCase();

// Capitalize the first letter of each word except the first one
if (index > 0 && lowerWord && lowerWord[0]) {
return lowerWord[0].toUpperCase() + lowerWord.slice(1);
}

return lowerWord;
})
.join('');

return camelCaseString;
}
8 changes: 4 additions & 4 deletions src/prompts/prompts.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { AxSignature } from '../dsp/sig.js';

import { AxChainOfThought } from './cot.js';

const text = `The technological singularity—or simply the singularity[1]—is a hypothetical future point in time at which technological growth becomes uncontrollable and irreversible.`;
const someText = `The technological singularity—or simply the singularity[1]—is a hypothetical future point in time at which technological growth becomes uncontrollable and irreversible.`;

const examples = [
{
Expand Down Expand Up @@ -56,11 +56,11 @@ test('generate prompt', async (t) => {

const gen = new AxChainOfThought(
ai,
`text -> shortSummary "summarize in 5 to 10 words"`
`someText -> shortSummary "summarize in 5 to 10 words"`
);
gen.setExamples(examples);

const res = await gen.forward({ text });
const res = await gen.forward({ someText });

t.deepEqual(res, {
reason: 'Blah blah blah',
Expand All @@ -69,5 +69,5 @@ test('generate prompt', async (t) => {
});

test('generate prompt: invalid signature', async (t) => {
t.throws(() => new AxSignature(`text -> output:image`));
t.throws(() => new AxSignature(`someText -> output:image`));
});
Loading

0 comments on commit 5800ff0

Please sign in to comment.