diff --git a/src/gemini_wrapper.ts b/src/gemini_wrapper.ts
new file mode 100644
index 0000000..bdfd440
--- /dev/null
+++ b/src/gemini_wrapper.ts
@@ -0,0 +1,60 @@
+import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai";
+// import util from "util";
+
+const projectId = "fleet-toolbox-404319";
+const location = "us-central1";
+const model = "gemini-1.0-pro";
+
+const vertexAI = new VertexAI({ project: projectId, location: location });
+const generativeModel = vertexAI.getGenerativeModel({
+  model: model,
+});
+
+function generateMessagesGemini(
+  input_text: string,
+  system: string,
+  examples: [string, string][]
+): GenerateContentRequest {
+  // Vertex AI Gemini accepts only "user" and "model" roles (there is no
+  // "system" or "assistant" role) and turns must alternate, so fold the
+  // system prompt into the first user turn.
+  const contents = [];
+  let prefix = system ? system + "\n\n" : "";
+  for (const [user, response] of examples) {
+    contents.push(
+      { role: "user", parts: [{ text: prefix + user }] },
+      { role: "model", parts: [{ text: response }] }
+    );
+    prefix = "";
+  }
+  contents.push({ role: "user", parts: [{ text: prefix + input_text }] });
+  return { contents };
+}
+
+export async function queryGemini(
+  input_text: string,
+  system: string,
+  examples: [string, string][]
+): Promise<string> {
+  if (input_text.trim().length === 0) {
+    throw new TypeError("Input text can't be empty string");
+  }
+  const prompt = generateMessagesGemini(input_text, system, examples);
+  // Log the prompt's token count to help diagnose context-length issues.
+
+  const resp = await generativeModel.countTokens(prompt);
+  console.log("count tokens response: ", resp);
+
+  // console.log(util.inspect(prompt, false, null, true));
+  // const prompt2: GenerateContentRequest = {
+  //   contents: [{ role: "user", parts: [{ text: "Tell me 3 dad jokes." }] }],
+  // };
+
+  try {
+    const response = await generativeModel.generateContent(prompt);
+    return response.response.candidates[0].content.parts[0].text || "";
+  } catch (error) {
+    console.error(`Error with Gemini API request: ${error}`);
+  }
+  return "";
+}
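
// Usage sketch for the new wrapper, assuming prompt strings shaped like the
// SYSTEM / EXAMPLE_* constants in src/prompt.ts (the literals below are
// placeholders, not the project's real prompts):

import { queryGemini } from "./gemini_wrapper.js";

const reply = await queryGemini(
  "Some incentive program text to extract...", // input_text
  "You extract incentive details as JSON.", // system prompt
  [["example user message", "example model response"]] // few-shot pairs
);
console.log(reply); // prints "" when the request failed and was logged
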
.requiredOption("-f, --folders ", "Name of folder(s) under incentives_data/ where text data is located.") - .option("-r, --restrict ", 'Will process only the files supplied. Useful for re-dos. Give the full path including INCENTIVES_FILE_BASE') - .option("-o, --output_file ", 'Name of output file. Saved in the out/ directory.', "output.csv") - .option("-w, --wait ", "How long to wait in ms between requests to avoid rate limiting") - .addOption(new Option("-m, --model_family ", 'Name of model family to use for queries').choices(['gpt4', 'gpt', 'palm']).default('palm')); + .requiredOption( + "-f, --folders ", + "Name of folder(s) under incentives_data/ where text data is located." + ) + .option( + "-r, --restrict ", + "Will process only the files supplied. Useful for re-dos. Give the full path including INCENTIVES_FILE_BASE" + ) + .option( + "-o, --output_file ", + "Name of output file. Saved in the out/ directory.", + "output.csv" + ) + .option( + "-w, --wait ", + "How long to wait in ms between requests to avoid rate limiting" + ) + .addOption( + new Option( + "-m, --model_family ", + "Name of model family to use for queries" + ) + .choices(["gpt4", "gpt", "palm", "gemini"]) + .default("palm") + ); program.parse(); - -async function retrieveMetadata(folder: string, file: string): Promise { - const metadata_file = path.parse(file).name + "_metadata.json" - let contents: string = "" +async function retrieveMetadata( + folder: string, + file: string +): Promise { + const metadata_file = path.parse(file).name + "_metadata.json"; + let contents: string = ""; try { - contents = await fs.readFile(path.join(INCENTIVES_FILE_BASE, folder, metadata_file), { encoding: 'utf8' }) - const metadata: Metadata = JSON.parse(contents) - return metadata + contents = await fs.readFile( + path.join(INCENTIVES_FILE_BASE, folder, metadata_file), + { encoding: "utf8" } + ); + const metadata: Metadata = JSON.parse(contents); + return metadata; } catch (err) { if (err instanceof SyntaxError) { - console.log(`Error parsing metadata in ${metadata_file}: contents are ${contents}; error is ${err}`) + console.log( + `Error parsing metadata in ${metadata_file}: contents are ${contents}; error is ${err}` + ); } else { - console.log(`No metadata file found: ${path.join(INCENTIVES_FILE_BASE, folder, file)}`); + console.log( + `No metadata file found: ${path.join( + INCENTIVES_FILE_BASE, + folder, + file + )}` + ); } - return {} + return {}; } } @@ -45,8 +89,11 @@ function getParamsForLogging(opts: OptionValues) { return { model: opts.model_family, system: SYSTEM, - examples: [[EXAMPLE_1_USER, EXAMPLE_1_RESPONSE], [EXAMPLE_2_USER, EXAMPLE_2_RESPONSE]] - } + examples: [ + [EXAMPLE_1_USER, EXAMPLE_1_RESPONSE], + [EXAMPLE_2_USER, EXAMPLE_2_RESPONSE], + ], + }; } async function main() { @@ -56,31 +103,44 @@ async function main() { const output: object[] = []; const metadata_fields: Set = new Set(); - const runId = Date.now().toString() + const runId = Date.now().toString(); await fs.mkdir(path.join(OUTPUT_FILE_BASE, runId)); - await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, "parameters.json"), JSON.stringify(getParamsForLogging(opts)), { - encoding: "utf-8", - flag: "w" - }); + await fs.writeFile( + path.join(OUTPUT_FILE_BASE, runId, "parameters.json"), + JSON.stringify(getParamsForLogging(opts)), + { + encoding: "utf-8", + flag: "w", + } + ); await fs.mkdir(path.join(OUTPUT_FILE_BASE, runId, OUTPUT_SUBDIR)); const droppedFiles: string[] = []; for (const folder of opts.folders) { - await fs.mkdir(path.join(OUTPUT_FILE_BASE, 
-    await fs.mkdir(path.join(OUTPUT_FILE_BASE, runId, OUTPUT_SUBDIR, folder)).catch(err => {
-      if (err.code !== 'EEXIST') {
-        console.log(err);
-      }
-    });
+    await fs
+      .mkdir(path.join(OUTPUT_FILE_BASE, runId, OUTPUT_SUBDIR, folder))
+      .catch((err) => {
+        if (err.code !== "EEXIST") {
+          console.log(err);
+        }
+      });
     const files = await fs.readdir(path.join(INCENTIVES_FILE_BASE, folder));
     for (const file of files) {
       if (!file.endsWith(".txt")) continue;
-      if (opts.restrict && !(opts.restrict.includes(path.join(INCENTIVES_FILE_BASE, folder, file)))) {
+      if (
+        opts.restrict &&
+        !opts.restrict.includes(path.join(INCENTIVES_FILE_BASE, folder, file))
+      ) {
         continue;
       }
-      const txt = (await fs.readFile(path.join(INCENTIVES_FILE_BASE, folder, file), { encoding: 'utf8' })).trim();
+      const txt = (
+        await fs.readFile(path.join(INCENTIVES_FILE_BASE, folder, file), {
+          encoding: "utf8",
+        })
+      ).trim();
       if (txt.length == 0) {
-        console.log(`Skipping ${path.join(folder, file)} because it is empty`)
+        console.log(`Skipping ${path.join(folder, file)} because it is empty`);
         continue;
       }
 
@@ -89,50 +149,87 @@ async function main() {
         metadata_fields.add(field);
       }
       if (metadata.tags !== undefined && metadata.tags.includes("index")) {
-        console.log(`Skipping ${path.join(folder, file)} because we detected an index tag`)
+        console.log(
+          `Skipping ${path.join(folder, file)} because we detected an index tag`
+        );
         continue;
       }
 
       if (opts.wait) {
-        await new Promise(f => setTimeout(f, +opts.wait))
+        await new Promise((f) => setTimeout(f, +opts.wait));
       }
-      console.log(`Querying ${opts.model_family} with ${path.join(INCENTIVES_FILE_BASE, folder, file)}`)
-      const gpt_wrapper = new GptWrapper(opts.model_family)
-      const queryFunc = opts.model_family == 'palm' ? queryPalm : gpt_wrapper.queryGpt.bind(gpt_wrapper)
-      const promise = queryFunc(txt, SYSTEM, [[EXAMPLE_1_USER, EXAMPLE_1_RESPONSE], [EXAMPLE_2_USER, EXAMPLE_2_RESPONSE]]).then(async msg => {
+      console.log(
+        `Querying ${opts.model_family} with ${path.join(
+          INCENTIVES_FILE_BASE,
+          folder,
+          file
+        )}`
+      );
+      const gpt_wrapper = new GptWrapper(opts.model_family);
+      let queryFunc;
+      switch (opts.model_family) {
+        case "palm": {
+          queryFunc = queryPalm;
+          break;
+        }
+        case "gemini": {
+          queryFunc = queryGemini;
+          break;
+        }
+        default: {
+          queryFunc = gpt_wrapper.queryGpt.bind(gpt_wrapper);
+          break;
+        }
+      }
+      const promise = queryFunc(txt, SYSTEM, [
+        [EXAMPLE_1_USER, EXAMPLE_1_RESPONSE],
+        [EXAMPLE_2_USER, EXAMPLE_2_RESPONSE],
+      ]).then(async (msg) => {
         if (msg == "") return;
-        console.log(`Got response from ${path.join(INCENTIVES_FILE_BASE, folder, file)}`)
+        console.log(
+          `Got response from ${path.join(INCENTIVES_FILE_BASE, folder, file)}`
+        );
         try {
           let records = JSON.parse(msg);
           if (!(Symbol.iterator in Object(records))) {
-            records = [records]
+            records = [records];
           }
           let incentive_order_key = 0;
-          const file_records: object[] = []
+          const file_records: object[] = [];
           let combined: object = {};
           for (const record of records) {
-            cleanupEnumFields(record)
-            generateHomeownerRenterField(record)
+            cleanupEnumFields(record);
+            generateHomeownerRenterField(record);
 
-            if (!('folder' in metadata)) {
-              metadata['folder'] = folder;
+            if (!("folder" in metadata)) {
+              metadata["folder"] = folder;
             }
-            metadata['file'] = file; // For debugging.
-            metadata['order'] = incentive_order_key;
+            metadata["file"] = file; // For debugging.
+            metadata["order"] = incentive_order_key;
 
             combined = { ...record, ...metadata };
             output.push(combined);
-            file_records.push(combined)
+            file_records.push(combined);
             incentive_order_key += 1;
           }
-          await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, OUTPUT_SUBDIR, folder, file.replace(".txt", "_output.json")), JSON.stringify(file_records, null, 2), {
-            encoding: "utf-8",
-            flag: "w"
-          })
+          await fs.writeFile(
+            path.join(
+              OUTPUT_FILE_BASE,
+              runId,
+              OUTPUT_SUBDIR,
+              folder,
+              file.replace(".txt", "_output.json")
+            ),
+            JSON.stringify(file_records, null, 2),
+            {
+              encoding: "utf-8",
+              flag: "w",
+            }
+          );
         } catch (error) {
           console.error(`Error parsing json: ${error}, ${msg}`);
-          droppedFiles.push(path.join(INCENTIVES_FILE_BASE, folder, file))
+          droppedFiles.push(path.join(INCENTIVES_FILE_BASE, folder, file));
         }
       });
       promises.push(promise);
@@ -147,13 +244,23 @@
     const parser = new AsyncParser(CSV_OPTS);
     const csv = await parser.parse(output).promise();
-    await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, opts.output_file), csv);
+    await fs.writeFile(
+      path.join(OUTPUT_FILE_BASE, runId, opts.output_file),
+      csv
+    );
 
     if (droppedFiles.length > 0) {
-      await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, "dropped_files.json"), JSON.stringify(droppedFiles));
+      await fs.writeFile(
+        path.join(OUTPUT_FILE_BASE, runId, "dropped_files.json"),
+        JSON.stringify(droppedFiles)
+      );
     }
-    console.log(`Find your results with run ID ${runId} at ${path.join(OUTPUT_FILE_BASE, runId)}`)
-
-  })
+    console.log(
+      `Find your results with run ID ${runId} at ${path.join(
+        OUTPUT_FILE_BASE,
+        runId
+      )}`
+    );
+  });
 }
 
-main()
\ No newline at end of file
+main();
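
// The switch in main() works because all three wrappers share one call shape;
// a sketch of that implied contract (the type alias is illustrative, not part
// of the diff):

type QueryFunc = (
  input_text: string,
  system: string,
  examples: [string, string][]
) => Promise<string>;

// An assumed invocation of the runner against the new backend (the compiled
// entry-point path is a guess; adjust to the project's build setup):
//   node build/llm_runner.js --folders ca_folder --model_family gemini --wait 1000
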
+ metadata["order"] = incentive_order_key; combined = { ...record, ...metadata }; output.push(combined); - file_records.push(combined) + file_records.push(combined); incentive_order_key += 1; } - await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, OUTPUT_SUBDIR, folder, file.replace(".txt", "_output.json")), JSON.stringify(file_records, null, 2), { - encoding: "utf-8", - flag: "w" - }) + await fs.writeFile( + path.join( + OUTPUT_FILE_BASE, + runId, + OUTPUT_SUBDIR, + folder, + file.replace(".txt", "_output.json") + ), + JSON.stringify(file_records, null, 2), + { + encoding: "utf-8", + flag: "w", + } + ); } catch (error) { console.error(`Error parsing json: ${error}, ${msg}`); - droppedFiles.push(path.join(INCENTIVES_FILE_BASE, folder, file)) + droppedFiles.push(path.join(INCENTIVES_FILE_BASE, folder, file)); } }); promises.push(promise); @@ -147,13 +244,23 @@ async function main() { const parser = new AsyncParser(CSV_OPTS); const csv = await parser.parse(output).promise(); - await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, opts.output_file), csv); + await fs.writeFile( + path.join(OUTPUT_FILE_BASE, runId, opts.output_file), + csv + ); if (droppedFiles.length > 0) { - await fs.writeFile(path.join(OUTPUT_FILE_BASE, runId, "dropped_files.json"), JSON.stringify(droppedFiles)); + await fs.writeFile( + path.join(OUTPUT_FILE_BASE, runId, "dropped_files.json"), + JSON.stringify(droppedFiles) + ); } - console.log(`Find your results with run ID ${runId} at ${path.join(OUTPUT_FILE_BASE, runId)}`) - - }) + console.log( + `Find your results with run ID ${runId} at ${path.join( + OUTPUT_FILE_BASE, + runId + )}` + ); + }); } -main() \ No newline at end of file +main(); diff --git a/src/test_request_gemini.ts b/src/test_request_gemini.ts new file mode 100644 index 0000000..c6aac2a --- /dev/null +++ b/src/test_request_gemini.ts @@ -0,0 +1,40 @@ +import { VertexAI } from "@google-cloud/vertexai"; + +/** + * TODO(developer): Update these variables before running the sample. + */ +async function createNonStreamingMultipartContent( + projectId = "fleet-toolbox-404319", + location = "us-central1", + model = "gemini-1.0-pro" +) { + // Initialize Vertex with your Cloud project and location + const vertexAI = new VertexAI({ project: projectId, location: location }); + + // Instantiate the model + const generativeVisionModel = vertexAI.getGenerativeModel({ + model: model, + }); + + const textPart = { + text: "You are a helpful assistant. Give me three dad jokes.", + }; + + const request = { + contents: [{ role: "user", parts: [textPart] }], + }; + + console.log("Prompt Text:"); + console.log(request.contents[0].parts[0].text); + + console.log("Non-Streaming Response Text:"); + const response = await generativeVisionModel.generateContent(request); + + // Select the text from the response + const fullTextResponse = + response.response.candidates[0].content.parts[0].text; + + console.log(fullTextResponse); +} + +createNonStreamingMultipartContent();