Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In progress work to add Gemini. #8

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/gemini_wrapper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai";
// import util from "util";

const projectId = "fleet-toolbox-404319";
const location = "us-central1";
const model = "gemini-1.0-pro";

const vertexAI = new VertexAI({ project: projectId, location: location });
const generativeModel = vertexAI.getGenerativeModel({
model: model,
});

function generateMessagesGemini(
input_text: string,
system: string,
examples: [string, string][]
): GenerateContentRequest {
const examples_list = [];
for (const [user, response] of examples) {
examples_list.push(
{ role: "user", parts: [{ text: user }] },
{ role: "assistant", parts: [{ text: response }] }
);
}
return {
contents: [
{ role: "system", parts: [{ text: system }] },
...examples_list,
{ role: "user", parts: [{ text: input_text }] },
],
};
}

export async function queryGemini(
input_text: string,
system: string,
examples: [string, string][]
): Promise<string> {
if (input_text.trim().length === 0) {
throw new TypeError("Input text can't be empty string");
}
const prompt = generateMessagesGemini(input_text, system, examples);
// TODO: add token counting

const resp = await generativeModel.countTokens(prompt);
console.log("count tokens response: ", resp);

// console.log(util.inspect(prompt, false, null, true));
// const prompt2: GenerateContentRequest = {
// contents: [{ role: "user", parts: [{ text: "Tell me 3 dad jokes." }] }],
// };

try {
const response = await generativeModel.generateContent(prompt);
return response.response.candidates[0].content.parts[0].text || "";
} catch (error) {
if (error instanceof Error) {
console.error(`Error with Gemini API request, error: ${error}`);
} else {
console.error(`Error with Gemini API request: ${error}`);
}
}
return "";
}
241 changes: 174 additions & 67 deletions src/llm_runner.ts
Original file line number Diff line number Diff line change
@@ -1,52 +1,99 @@
import { AsyncParser } from '@json2csv/node';
import { program, Option, OptionValues } from 'commander';
import { AsyncParser } from "@json2csv/node";
import { program, Option, OptionValues } from "commander";

import fs = require('node:fs/promises');
import path = require('node:path');
import fs = require("node:fs/promises");
import path = require("node:path");

import { INCENTIVES_FILE_BASE, OUTPUT_FILE_BASE, OUTPUT_SUBDIR, CSV_OPTS } from './constants.js';
import { SYSTEM, EXAMPLE_1_RESPONSE, EXAMPLE_1_USER, EXAMPLE_2_RESPONSE, EXAMPLE_2_USER } from "./prompt.js"
import {
INCENTIVES_FILE_BASE,
OUTPUT_FILE_BASE,
OUTPUT_SUBDIR,
CSV_OPTS,
} from "./constants.js";
import {
SYSTEM,
EXAMPLE_1_RESPONSE,
EXAMPLE_1_USER,
EXAMPLE_2_RESPONSE,
EXAMPLE_2_USER,
} from "./prompt.js";

import { queryPalm } from "./palm_wrapper.js";
import { GptWrapper } from "./gpt_wrapper.js";
import { Metadata } from "./metadata.js"
import { generateHomeownerRenterField } from './post_processing/homeowner_renter.js';
import { cleanupEnumFields } from './post_processing/cleanup_enum_fields.js';

import { Metadata } from "./metadata.js";
import { generateHomeownerRenterField } from "./post_processing/homeowner_renter.js";
import { cleanupEnumFields } from "./post_processing/cleanup_enum_fields.js";
import { queryGemini } from "./gemini_wrapper.js";

program
.requiredOption("-f, --folders <folders...>", "Name of folder(s) under incentives_data/ where text data is located.")
.option("-r, --restrict <restrict_files...>", 'Will process only the files supplied. Useful for re-dos. Give the full path including INCENTIVES_FILE_BASE')
.option("-o, --output_file <file>", 'Name of output file. Saved in the out/ directory.', "output.csv")
.option("-w, --wait <duration_ms>", "How long to wait in ms between requests to avoid rate limiting")
.addOption(new Option("-m, --model_family <model_family>", 'Name of model family to use for queries').choices(['gpt4', 'gpt', 'palm']).default('palm'));
.requiredOption(
"-f, --folders <folders...>",
"Name of folder(s) under incentives_data/ where text data is located."
)
.option(
"-r, --restrict <restrict_files...>",
"Will process only the files supplied. Useful for re-dos. Give the full path including INCENTIVES_FILE_BASE"
)
.option(
"-o, --output_file <file>",
"Name of output file. Saved in the out/ directory.",
"output.csv"
)
.option(
"-w, --wait <duration_ms>",
"How long to wait in ms between requests to avoid rate limiting"
)
.addOption(
new Option(
"-m, --model_family <model_family>",
"Name of model family to use for queries"
)
.choices(["gpt4", "gpt", "palm", "gemini"])
.default("palm")
);

program.parse();


async function retrieveMetadata(folder: string, file: string): Promise<Metadata> {
const metadata_file = path.parse(file).name + "_metadata.json"
let contents: string = ""
async function retrieveMetadata(
folder: string,
file: string
): Promise<Metadata> {
const metadata_file = path.parse(file).name + "_metadata.json";
let contents: string = "";
try {
contents = await fs.readFile(path.join(INCENTIVES_FILE_BASE, folder, metadata_file), { encoding: 'utf8' })
const metadata: Metadata = JSON.parse(contents)
return metadata
contents = await fs.readFile(
path.join(INCENTIVES_FILE_BASE, folder, metadata_file),
{ encoding: "utf8" }
);
const metadata: Metadata = JSON.parse(contents);
return metadata;
} catch (err) {
if (err instanceof SyntaxError) {
console.log(`Error parsing metadata in ${metadata_file}: contents are ${contents}; error is ${err}`)
console.log(
`Error parsing metadata in ${metadata_file}: contents are ${contents}; error is ${err}`
);
} else {
console.log(`No metadata file found: ${path.join(INCENTIVES_FILE_BASE, folder, file)}`);
console.log(
`No metadata file found: ${path.join(
INCENTIVES_FILE_BASE,
folder,
file
)}`
);
}
return {}
return {};
}
}

function getParamsForLogging(opts: OptionValues) {
return {
model: opts.model_family,
system: SYSTEM,
examples: [[EXAMPLE_1_USER, EXAMPLE_1_RESPONSE], [EXAMPLE_2_USER, EXAMPLE_2_RESPONSE]]
}
examples: [
[EXAMPLE_1_USER, EXAMPLE_1_RESPONSE],
[EXAMPLE_2_USER, EXAMPLE_2_RESPONSE],
],
};
}

async function main() {
Expand All @@ -56,31 +103,44 @@ async function main() {
const output: object[] = [];
const metadata_fields: Set<string> = new Set<string>();

const runId = Date.now().toString()
const runId = Date.now().toString();

await fs.mkdir(path.join(OUTPUT_FILE_BASE, runId));
await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, "parameters.json"), JSON.stringify(getParamsForLogging(opts)), {
encoding: "utf-8",
flag: "w"
});
await fs.writeFile(
path.join(OUTPUT_FILE_BASE, runId, "parameters.json"),
JSON.stringify(getParamsForLogging(opts)),
{
encoding: "utf-8",
flag: "w",
}
);

await fs.mkdir(path.join(OUTPUT_FILE_BASE, runId, OUTPUT_SUBDIR));
const droppedFiles: string[] = [];
for (const folder of opts.folders) {
await fs.mkdir(path.join(OUTPUT_FILE_BASE, runId, OUTPUT_SUBDIR, folder)).catch(err => {
if (err.code !== 'EEXIST') {
console.log(err);
}
});
await fs
.mkdir(path.join(OUTPUT_FILE_BASE, runId, OUTPUT_SUBDIR, folder))
.catch((err) => {
if (err.code !== "EEXIST") {
console.log(err);
}
});
const files = await fs.readdir(path.join(INCENTIVES_FILE_BASE, folder));
for (const file of files) {
if (!file.endsWith(".txt")) continue;
if (opts.restrict && !(opts.restrict.includes(path.join(INCENTIVES_FILE_BASE, folder, file)))) {
if (
opts.restrict &&
!opts.restrict.includes(path.join(INCENTIVES_FILE_BASE, folder, file))
) {
continue;
}
const txt = (await fs.readFile(path.join(INCENTIVES_FILE_BASE, folder, file), { encoding: 'utf8' })).trim();
const txt = (
await fs.readFile(path.join(INCENTIVES_FILE_BASE, folder, file), {
encoding: "utf8",
})
).trim();
if (txt.length == 0) {
console.log(`Skipping ${path.join(folder, file)} because it is empty`)
console.log(`Skipping ${path.join(folder, file)} because it is empty`);
continue;
}

Expand All @@ -89,50 +149,87 @@ async function main() {
metadata_fields.add(field);
}
if (metadata.tags !== undefined && metadata.tags.includes("index")) {
console.log(`Skipping ${path.join(folder, file)} because we detected an index tag`)
console.log(
`Skipping ${path.join(folder, file)} because we detected an index tag`
);
continue;
}

if (opts.wait) {
await new Promise(f => setTimeout(f, +opts.wait))
await new Promise((f) => setTimeout(f, +opts.wait));
}

console.log(`Querying ${opts.model_family} with ${path.join(INCENTIVES_FILE_BASE, folder, file)}`)
const gpt_wrapper = new GptWrapper(opts.model_family)
const queryFunc = opts.model_family == 'palm' ? queryPalm : gpt_wrapper.queryGpt.bind(gpt_wrapper)
const promise = queryFunc(txt, SYSTEM, [[EXAMPLE_1_USER, EXAMPLE_1_RESPONSE], [EXAMPLE_2_USER, EXAMPLE_2_RESPONSE]]).then(async msg => {
console.log(
`Querying ${opts.model_family} with ${path.join(
INCENTIVES_FILE_BASE,
folder,
file
)}`
);
const gpt_wrapper = new GptWrapper(opts.model_family);
let queryFunc;
switch (opts.model_family) {
case "palm": {
queryFunc = queryPalm;
break;
}
case "gemini": {
queryFunc = queryGemini;
break;
}
default: {
queryFunc = gpt_wrapper.queryGpt.bind(gpt_wrapper);
break;
}
}
const promise = queryFunc(txt, SYSTEM, [
[EXAMPLE_1_USER, EXAMPLE_1_RESPONSE],
[EXAMPLE_2_USER, EXAMPLE_2_RESPONSE],
]).then(async (msg) => {
if (msg == "") return;
console.log(`Got response from ${path.join(INCENTIVES_FILE_BASE, folder, file)}`)
console.log(
`Got response from ${path.join(INCENTIVES_FILE_BASE, folder, file)}`
);
try {
let records = JSON.parse(msg);
if (!(Symbol.iterator in Object(records))) {
records = [records]
records = [records];
}

let incentive_order_key = 0;
const file_records: object[] = []
const file_records: object[] = [];
let combined: object = {};
for (const record of records) {
cleanupEnumFields(record)
generateHomeownerRenterField(record)
cleanupEnumFields(record);
generateHomeownerRenterField(record);

if (!('folder' in metadata)) {
metadata['folder'] = folder;
if (!("folder" in metadata)) {
metadata["folder"] = folder;
}
metadata['file'] = file; // For debugging.
metadata['order'] = incentive_order_key;
metadata["file"] = file; // For debugging.
metadata["order"] = incentive_order_key;
combined = { ...record, ...metadata };
output.push(combined);
file_records.push(combined)
file_records.push(combined);
incentive_order_key += 1;
}
await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, OUTPUT_SUBDIR, folder, file.replace(".txt", "_output.json")), JSON.stringify(file_records, null, 2), {
encoding: "utf-8",
flag: "w"
})
await fs.writeFile(
path.join(
OUTPUT_FILE_BASE,
runId,
OUTPUT_SUBDIR,
folder,
file.replace(".txt", "_output.json")
),
JSON.stringify(file_records, null, 2),
{
encoding: "utf-8",
flag: "w",
}
);
} catch (error) {
console.error(`Error parsing json: ${error}, ${msg}`);
droppedFiles.push(path.join(INCENTIVES_FILE_BASE, folder, file))
droppedFiles.push(path.join(INCENTIVES_FILE_BASE, folder, file));
}
});
promises.push(promise);
Expand All @@ -147,13 +244,23 @@ async function main() {

const parser = new AsyncParser(CSV_OPTS);
const csv = await parser.parse(output).promise();
await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, opts.output_file), csv);
await fs.writeFile(
path.join(OUTPUT_FILE_BASE, runId, opts.output_file),
csv
);
if (droppedFiles.length > 0) {
await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, "dropped_files.json"), JSON.stringify(droppedFiles));
await fs.writeFile(
path.join(OUTPUT_FILE_BASE, runId, "dropped_files.json"),
JSON.stringify(droppedFiles)
);
}
console.log(`Find your results with run ID ${runId} at ${path.join(OUTPUT_FILE_BASE, runId)}`)

})
console.log(
`Find your results with run ID ${runId} at ${path.join(
OUTPUT_FILE_BASE,
runId
)}`
);
});
}

main()
main();
Loading