
Commit

Merge branch 'main' into inardini--batch-pipeline-rag
holtskinner authored Sep 23, 2024
2 parents 3e3220e + 94f8200 commit 3f20004
Showing 7 changed files with 2,422 additions and 19 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
@@ -71,3 +71,4 @@
/generative-ai/open-models/serving @polong-lin @GoogleCloudPlatform/generative-ai-devrel
/generative-ai/open-models/serving/cloud_run_ollama_gemma2_rag_qa.ipynb @eliasecchig @GoogleCloudPlatform/generative-ai-devrel
/generative-ai/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb @alvarobartt @philschmid @pagezyhf @jeffboudier
/generative-ai/gemini/use-cases/applying-llms-to-data/semantic-search-in-bigquery/stackoverflow_questions_semantic_search.ipynb @sethijaideep @GoogleCloudPlatform/generative-ai-devrel
1 change: 1 addition & 0 deletions .github/actions/spelling/allow.txt
@@ -426,6 +426,7 @@ ipynb
isa
itables
iterrows
ivf
jegadesh
jetbrains
jsonify
102 changes: 93 additions & 9 deletions gemini/prompts/prompt_optimizer/vapo_lib.py
@@ -13,6 +13,7 @@
# limitations under the License.

"""Utility functions and classes for the VAPO notebook."""
import csv
import io
import json
import re
@@ -61,17 +62,89 @@ def is_run_target_required(eval_metric_types: list[str], source_model: str) -> b
_TARGET_KEY = "target"


def load_file_from_gcs(dataset: str) -> str:
"""Loads the file from GCS and returns it as a string."""
if dataset.startswith("gs://"):
with gfile.GFile(dataset, "r") as f:
return f.read()
else:
raise ValueError(
"Unsupported file location. Only GCS paths starting with 'gs://' are"
" supported."
)


def parse_jsonl(data_str: str) -> list[dict[str, str]]:
"""Parses the content of a JSONL file and returns a list of dictionaries."""
data = []
lines = data_str.splitlines()
for line in lines:
if line:
try:
data.append(json.loads(line))
except json.JSONDecodeError as e:
raise ValueError(
f"Error decoding JSON on line: {line}. Error: {e}"
) from e
return data
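# Illustrative usage sketch (not part of this commit): parse_jsonl turns each
# non-empty line into a dict and surfaces the offending line on a decode error.
# The sample records are hypothetical.
#
#   parse_jsonl('{"question": "1+1?", "target": "2"}\n{"question": "2+2?", "target": "4"}')
#   # -> [{"question": "1+1?", "target": "2"}, {"question": "2+2?", "target": "4"}]
#   parse_jsonl("not json")  # raises ValueError wrapping the json.JSONDecodeError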


def parse_and_validate_csv(data_str: str) -> list[dict[str, str]]:
"""Parses and validates the content of a CSV file and returns a list of dictionaries."""
data = []
csv_reader = csv.reader(io.StringIO(data_str))

# Extract and validate headers
try:
headers = next(csv_reader)
if not headers:
raise ValueError("The CSV file has an empty or invalid header row.")
except StopIteration as e:
raise ValueError("The CSV file is empty.") from e

# Validate and process rows
for row_number, row in enumerate(csv_reader, start=2):
if len(row) != len(headers):
raise ValueError(
f"Row {row_number} has an inconsistent number of fields. "
f"Expected {len(headers)} fields but found {len(row)}."
)
# Create dictionary for each row using headers as keys
item = dict(zip(headers, row))
data.append(item)

return data
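# Illustrative usage sketch (not part of this commit): the header row supplies the
# dict keys and every data row must have the same number of fields. The sample data
# is hypothetical.
#
#   parse_and_validate_csv("question,target\n1+1?,2\n2+2?,4")
#   # -> [{"question": "1+1?", "target": "2"}, {"question": "2+2?", "target": "4"}]
#   parse_and_validate_csv("")                          # raises ValueError: the CSV file is empty
#   parse_and_validate_csv("question,target\nlonely")   # raises ValueError: row 2 has 1 field, expected 2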


def load_dataset(dataset: str) -> list[dict[str, str]]:
"""Loads and parses the dataset based on its file type ('.jsonl' or '.csv')."""
# Load the file from GCS
data_str = load_file_from_gcs(dataset)

# Parse based on file type
if dataset.endswith(".jsonl"):
return parse_jsonl(data_str)

if dataset.endswith(".csv"):
return parse_and_validate_csv(data_str)

raise ValueError(
"Unsupported file type. Please provide a file with '.jsonl' or '.csv'"
" extension."
)
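# Illustrative usage sketch (not part of this commit): load_dataset fetches the object
# from GCS and dispatches on the file extension; the bucket and object names below are
# hypothetical.
#
#   load_dataset("gs://my-bucket/eval_examples.jsonl")  # parsed by parse_jsonl
#   load_dataset("gs://my-bucket/eval_examples.csv")    # parsed by parse_and_validate_csv
#   load_dataset("/local/path/eval_examples.jsonl")     # ValueError: only gs:// paths are supported
#   load_dataset("gs://my-bucket/eval_examples.txt")    # ValueError: unsupported file type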


def validate_prompt_and_data(
template: str,
dataset_path: str,
placeholder_to_content: str,
label_enforced: bool,
) -> None:
"""Validates the prompt template and the dataset."""
placeholder_to_content = json.loads(placeholder_to_content)
with gfile.GFile(dataset_path, "r") as f:
data = [json.loads(line) for line in f.readlines()]

data = load_dataset(dataset_path)
placeholder_to_content_json = json.loads(placeholder_to_content)
template = re.sub(r"(?<!{){(?!{)", "{{", template)
template = re.sub(r"(?<!})}(?!})", "}}", template)
env = jinja2.Environment()
try:
parsed_content = env.parse(template)
@@ -81,7 +81,7 @@ def validate_prompt_and_data(
template_variables = jinja2.meta.find_undeclared_variables(parsed_content)
extra_keys = set()
for ex in data:
ex.update(placeholder_to_content)
ex.update(placeholder_to_content_json)
missing_keys = [key for key in template_variables if key not in ex]
extra_keys.update([key for key in ex if key not in template_variables])
if label_enforced:
@@ -99,7 +172,7 @@
)
if extra_keys:
raise Warning(
"Warning: extra keys in the examples not used in the context/task"
"Warning: extra keys in the examples not used in the prompt template"
f" template {extra_keys}"
)

@@ -189,8 +262,10 @@ def generate_dataframe(filename: str) -> pd.DataFrame:
return pd.DataFrame()

with gfile.GFile(filename, "r") as f:
data = json.load(f)

try:
data = json.load(f)
except json.JSONDecodeError:
return pd.DataFrame()
return pd.json_normalize(data)
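# Illustrative note (not part of this commit): with the guard above, a results file
# that exists but is not valid JSON now yields an empty DataFrame instead of raising,
# so the caller can keep polling. The path below is hypothetical.
#
#   generate_dataframe("gs://my-bucket/vapo-run/templates.json")  # -> pd.DataFrame (possibly empty)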


@@ -227,6 +302,15 @@ class ProgressForm:

def __init__(self, params: dict[str, str]) -> None:
"""Initialize the progress form."""
self.instruction_progress_bar = None
self.instruction_display = None
self.instruction_best = None
self.instruction_score = None
self.demo_progress_bar = None
self.demo_display = None
self.demo_best = None
self.demo_score = None

self.job_state_display = display(
HTML("<span>Job State: Not Started!</span>"), display_id=True
)
@@ -262,7 +346,7 @@ def __init__(self, params: dict[str, str]) -> None:
# pylint: disable=too-many-arguments
def update_progress(
self,
progress_bar: widgets.IntProgress,
progress_bar: widgets.IntProgress | None,
templates_file: str,
df: pd.DataFrame | None,
df_display: DisplayHandle,
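For context on the vapo_lib.py change above, here is a minimal sketch (not part of the commit) of the validation path: validate_prompt_and_data now loads examples through load_dataset, escapes lone braces so a {placeholder}-style template parses as Jinja2, and checks the template variables against each example's keys. The template string is the one used in the notebook; the rest is assumed.

import re
from jinja2 import Environment, meta

template = "Question: {question}\n\nAnswer:{target}"
# Lone braces become Jinja2-style double braces: "Question: {{question}} ..."
template = re.sub(r"(?<!{){(?!{)", "{{", template)
template = re.sub(r"(?<!})}(?!})", "}}", template)

parsed = Environment().parse(template)
# -> {'question', 'target'}; every example returned by load_dataset must supply these keys.
print(meta.find_undeclared_variables(parsed))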
@@ -113,9 +113,7 @@
"outputs": [],
"source": [
"SYSTEM_INSTRUCTION = \"Answer the following question. Let's think step by step.\\n\" # @param {type:\"string\"}\n",
"PROMPT_TEMPLATE = (\n",
" \"Question: {{question}}\\n\\nAnswer:{{target}}\" # @param {type:\"string\"}\n",
")"
"PROMPT_TEMPLATE = \"Question: {question}\\n\\nAnswer:{target}\" # @param {type:\"string\"}"
]
},
{
@@ -203,9 +201,9 @@
"# @markdown * Number of the demonstrations to include in each prompt.\n",
"\n",
"# @markdown **Model Configs**: <br/>\n",
"TARGET_MODEL_QPS = 3 # @param {type:\"integer\"}\n",
"SOURCE_MODEL_QPS = 3 # @param {type:\"integer\"}\n",
"EVAL_MODEL_QPS = 3 # @param {type:\"integer\"}\n",
"TARGET_MODEL_QPS = 3.0 # @param {type:\"number\"}\n",
"SOURCE_MODEL_QPS = 3.0 # @param {type:\"number\"}\n",
"EVAL_QPS = 3.0 # @param {type:\"number\"}\n",
"# @markdown * The QPS for calling the eval model, which is currently gemini-1.5-pro-001.\n",
"\n",
"# @markdown **Multi-metric Configs**: <br/>\n",
@@ -280,15 +278,15 @@
"params = {\n",
" \"project\": PROJECT_ID,\n",
" \"num_steps\": NUM_INST_OPTIMIZATION_STEPS,\n",
" \"prompt_template\": SYSTEM_INSTRUCTION,\n",
" \"demo_and_query_template\": PROMPT_TEMPLATE,\n",
" \"system_instruction\": SYSTEM_INSTRUCTION,\n",
" \"prompt_template\": PROMPT_TEMPLATE,\n",
" \"target_model\": TARGET_MODEL,\n",
" \"target_model_qps\": TARGET_MODEL_QPS,\n",
" \"target_model_location\": LOCATION,\n",
" \"source_model\": SOURCE_MODEL,\n",
" \"source_model_qps\": SOURCE_MODEL_QPS,\n",
" \"source_model_location\": LOCATION,\n",
" \"eval_model_qps\": EVAL_MODEL_QPS,\n",
" \"eval_qps\": EVAL_QPS,\n",
" \"eval_model_location\": LOCATION,\n",
" \"optimization_mode\": OPTIMIZATION_MODE,\n",
" \"num_demo_set_candidates\": NUM_DEMO_OPTIMIZATION_STEPS,\n",
@@ -346,7 +344,7 @@
"source": [
"from IPython.display import HTML, display\n",
"\n",
"RESULT_PATH = \"gs://prompt_design_demo\" # @param {type:\"string\"}\n",
"RESULT_PATH = \"[OUTPUT_PATH]\" # @param {type:\"string\"}\n",
"# @markdown * Specify a GCS path that contains artifacts of a single or multiple VAPO runs.\n",
"\n",
"results_ui = vapo_lib.ResultsUI(RESULT_PATH)\n",
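To pull together the notebook parameter changes shown above, a short sketch (not part of the diff); the values come from the cells above, and the comments note the old key names:

SYSTEM_INSTRUCTION = "Answer the following question. Let's think step by step.\n"
PROMPT_TEMPLATE = "Question: {question}\n\nAnswer:{target}"
TARGET_MODEL_QPS = 3.0  # QPS settings now accept floats (Colab type:"number" instead of "integer")
SOURCE_MODEL_QPS = 3.0
EVAL_QPS = 3.0          # was EVAL_MODEL_QPS

params = {
    "system_instruction": SYSTEM_INSTRUCTION,  # was "prompt_template"
    "prompt_template": PROMPT_TEMPLATE,        # was "demo_and_query_template"
    "target_model_qps": TARGET_MODEL_QPS,
    "source_model_qps": SOURCE_MODEL_QPS,
    "eval_qps": EVAL_QPS,                      # was "eval_model_qps"
    # ...the remaining keys (project, models, locations, optimization mode) are unchanged
}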
