Skip to content

Commit

Permalink
fixes: #73, #75, #76, #77
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Jan 2, 2025
1 parent 23189d4 commit 2e63e33
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 86 deletions.
7 changes: 5 additions & 2 deletions src/ax/ai/anthropic/types.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import type { AxModelConfig } from '../types.js';

export enum AxAIAnthropicModel {
Claude35Sonnet = 'claude-3-5-sonnet-20240620',
Claude3Opus = 'claude-3-opus-20240229',
Claude35Sonnet = 'claude-3-5-sonnet-latest',
Claude35Haiku = 'claude-3-5-haiku-latest',

Claude3Opus = 'claude-3-opus-latest',
Claude3Sonnet = 'claude-3-sonnet-20240229',
Claude3Haiku = 'claude-3-haiku-20240307',

Claude21 = 'claude-2.1',
ClaudeInstant12 = 'claude-instant-1.2'
}
Expand Down
2 changes: 1 addition & 1 deletion src/ax/ai/ollama/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ export class AxAIOllama extends AxAIOpenAI {
apiKey,
options,
config: _config,
apiURL: new URL('/api', url).href,
apiURL: url,
modelMap
});

Expand Down
6 changes: 3 additions & 3 deletions src/ax/dsp/datetime.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import moment from 'moment-timezone';

import { ValidationError } from './extract.js';
import type { AxField } from './sig.js';
import { AxValidationError } from './validate.js';

// eslint-disable-next-line @typescript-eslint/naming-convention
export function parseLLMFriendlyDate(
Expand All @@ -12,7 +12,7 @@ export function parseLLMFriendlyDate(
return _parseLLMFriendlyDate(dateStr);
} catch (err) {
const message = (err as Error).message;
throw new ValidationError({ field, message, value: dateStr });
throw new AxValidationError({ field, message, value: dateStr });
}
}

Expand All @@ -39,7 +39,7 @@ export function parseLLMFriendlyDateTime(
return _parseLLMFriendlyDateTime(dateStr);
} catch (err) {
const message = (err as Error).message;
throw new ValidationError({ field, message, value: dateStr });
throw new AxValidationError({ field, message, value: dateStr });
}
}

Expand Down
49 changes: 0 additions & 49 deletions src/ax/dsp/extract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import JSON5 from 'json5';

import { parseLLMFriendlyDate, parseLLMFriendlyDateTime } from './datetime.js';
import { toFieldType } from './prompt.js';
import type { AxField, AxSignature } from './sig.js';

export const extractValues = (
Expand Down Expand Up @@ -180,54 +179,6 @@ function validateAndParseFieldValue(
return value;
}

export class ValidationError extends Error {
private field: AxField;
private value: string;

constructor({
message,
field,
value
}: Readonly<{
message: string;
field: AxField;
value: string;
}>) {
super(message);
this.field = field;
this.value = value;
this.name = this.constructor.name;
Error.captureStackTrace(this, this.constructor);
}

public getField = () => this.field;
public getValue = () => this.value;

public getFixingInstructions = () => {
const f = this.field;

const extraFields = [
// {
// name: `past_${f.name}`,
// title: `Past ${f.title}`,
// description: this.value
// },
{
name: `invalidField`,
title: `Invalid Field`,
description: `The field \`${f.title}\` is invalid. Got value: \`${this.value}\`, expected ${toFieldType(f.type)}`
}
// {
// name: 'instructions',
// title: 'Instructions',
// description: this.message
// }
];

return extraFields;
};
}

export const extractBlock = (input: string): string => {
const jsonBlockPattern = /```([A-Za-z]+)?\s*([\s\S]*?)\s*```/g;
const match = jsonBlockPattern.exec(input);
Expand Down
8 changes: 4 additions & 4 deletions src/ax/dsp/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ import {
type extractionState,
extractValues,
streamingExtractFinalValue,
streamingExtractValues,
ValidationError
streamingExtractValues
} from './extract.js';
import {
type AxChatResponseFunctionCall,
Expand All @@ -43,6 +42,7 @@ import {
} from './program.js';
import { AxPromptTemplate } from './prompt.js';
import { AxSignature } from './sig.js';
import { AxValidationError } from './validate.js';

export interface AxGenOptions {
maxCompletions?: number;
Expand Down Expand Up @@ -378,7 +378,7 @@ export class AxGen<
const stream =
options?.stream ?? this.options?.stream ?? modelConfig.stream ?? true;

let err: ValidationError | AxAssertionError | undefined;
let err: AxValidationError | AxAssertionError | undefined;

if (options?.functions && options?.functions.length > 0) {
const promptTemplate = this.options?.promptTemplate ?? AxPromptTemplate;
Expand Down Expand Up @@ -430,7 +430,7 @@ export class AxGen<
let extraFields;
span?.recordAxSpanException(e as Error);

if (e instanceof ValidationError) {
if (e instanceof AxValidationError) {
extraFields = e.getFixingInstructions();
err = e;
} else if (e instanceof AxAssertionError) {
Expand Down
48 changes: 24 additions & 24 deletions src/ax/dsp/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -386,21 +386,17 @@ export class AxPromptTemplate {

private renderFields = (fields: readonly AxField[]) => {
// Header
const header =
'Field Name | Type | Required/Optional | Format | Description';
const header = 'Field Name | Field Type | Required/Optional | Description';
const separator = '|';

// Transform each field into table row
const rows = fields.map((field) => {
const name = field.title;
const type = field.type?.name ? toFieldType(field.type) : 'string';
const required = field.isOptional ? 'optional' : 'required';
const format = field.type?.isArray ? 'array' : 'single';
const description = field.description ?? '';

return [name, type, required, format, description]
.join(` ${separator} `)
.trim();
return [name, type, required, description].join(` ${separator} `).trim();
});

// Combine header and rows
Expand Down Expand Up @@ -442,24 +438,28 @@ const processValue = (

// eslint-disable-next-line @typescript-eslint/naming-convention
export const toFieldType = (type: Readonly<AxField['type']>) => {
switch (type?.name) {
case 'string':
return 'string';
case 'number':
return 'number';
case 'boolean':
return 'boolean';
case 'date':
return 'date ("YYYY-MM-DD" format)';
case 'datetime':
return 'date time ("YYYY-MM-DD HH:mm Timezone" format)';
case 'json':
return 'JSON object';
case 'class':
return `classification class (allowed classes: ${type.classes?.join(', ')})`;
default:
return 'string';
}
const baseType = (() => {
switch (type?.name) {
case 'string':
return 'string';
case 'number':
return 'number';
case 'boolean':
return 'boolean';
case 'date':
return 'date ("YYYY-MM-DD" format)';
case 'datetime':
return 'date time ("YYYY-MM-DD HH:mm Timezone" format)';
case 'json':
return 'JSON object';
case 'class':
return `classification class (allowed classes: ${type.classes?.join(', ')})`;
default:
return 'string';
}
})();

return type?.isArray ? `json array of ${baseType} items` : baseType;
};

function combineConsecutiveStrings(separator: string) {
Expand Down
40 changes: 40 additions & 0 deletions src/ax/dsp/validate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import { toFieldType } from './prompt.js';
import type { AxField } from './sig.js';

export class AxValidationError extends Error {
private field: AxField;
private value: string;

constructor({
message,
field,
value
}: Readonly<{
message: string;
field: AxField;
value: string;
}>) {
super(message);
this.field = field;
this.value = value;
this.name = this.constructor.name;
Error.captureStackTrace(this, this.constructor);
}

public getField = () => this.field;
public getValue = () => this.value;

public getFixingInstructions = () => {
const f = this.field;

const extraFields = [
{
name: `invalidField`,
title: `Invalid Field`,
description: `Ensure the field \`${f.title}\` is of type \`${toFieldType(f.type)}\``
}
];

return extraFields;
};
}
2 changes: 1 addition & 1 deletion src/examples/balancer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const ai2 = new AxAI({
}
});

const gen = new AxChainOfThought(
const gen = new AxChainOfThought<{ textToSummarize: string }>(
`textToSummarize -> shortSummary "summarize in 5 to 10 words"`
);

Expand Down
2 changes: 1 addition & 1 deletion src/examples/streaming1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { AxAI, AxChainOfThought } from '@ax-llm/ax';
// });

// setup the prompt program
const gen = new AxChainOfThought(
const gen = new AxChainOfThought<{ startNumber: number }>(
`startNumber:number -> next10Numbers:number[]`
);

Expand Down
4 changes: 3 additions & 1 deletion src/examples/streaming2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import { AxAI, AxChainOfThought } from '@ax-llm/ax';
// });

// setup the prompt program
const gen = new AxChainOfThought(`question:string -> answerInPoints:string`);
const gen = new AxChainOfThought<{ question: string }>(
`question:string -> answerInPoints:string`
);

// add a assertion to ensure all lines start with a number and a dot.
gen.addStreamingAssert(
Expand Down
63 changes: 63 additions & 0 deletions src/examples/streaming3.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import { AxAI, AxChainOfThought } from '@ax-llm/ax';

// Setup the prompt program for movie reviews
const gen = new AxChainOfThought<{ movieTitle: string }>(
`movieTitle:string ->
rating:number,
genres:string[],
strengths:string[],
weaknesses:string[],
recommendedAudience:string,
verdict:string`
);

// Assert rating is between 1 and 10
gen.addAssert(({ rating }: Readonly<{ rating: number }>) => {
if (!rating) return undefined;
return rating >= 1 && rating <= 10;
}, 'Rating must be between 1 and 10');

// Assert there are between 1-3 genres
gen.addAssert(({ genres }: Readonly<{ genres: string[] }>) => {
if (!genres) return undefined;
return genres.length >= 1 && genres.length <= 3;
}, 'Must specify between 1-3 genres');

// Assert strengths and weaknesses are balanced (similar length arrays)
gen.addAssert(
({
strengths,
weaknesses
}: Readonly<{ strengths: string[]; weaknesses: string[] }>) => {
if (!strengths || !weaknesses) return undefined;
const diff = Math.abs(strengths.length - weaknesses.length);
return diff <= 1;
},
'Review should be balanced with similar number of strengths and weaknesses'
);

// Assert verdict is not too short
gen.addAssert(({ verdict }: Readonly<{ verdict: string }>) => {
if (!verdict) return undefined;
return verdict.length >= 50;
}, 'Verdict must be at least 50 characters');

// Assert recommended audience doesn't mention specific age numbers
gen.addAssert(
({ recommendedAudience }: Readonly<{ recommendedAudience: string }>) => {
if (!recommendedAudience) return undefined;
return !/\d+/.test(recommendedAudience);
},
'Use age groups (e.g. "teens", "adults") instead of specific ages'
);

const ai = new AxAI({
name: 'google-gemini',
apiKey: process.env.GOOGLE_APIKEY as string
});
// ai.setOptions({ debug: true });

// Run the program
const res = await gen.forward(ai, { movieTitle: 'The Grand Budapest Hotel' });

console.log('>', res);

0 comments on commit 2e63e33

Please sign in to comment.