-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
530 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
FROM python:3.11 | ||
|
||
RUN mkdir /home/app | ||
WORKDIR /home | ||
|
||
COPY requirements.txt /home/app/requirements.txt | ||
COPY main.py /home/app/main.py | ||
|
||
RUN python3 -m pip install -r /home/app/requirements.txt | ||
|
||
# Preload the model into the image | ||
RUN python3 -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-miniLM-L6-v2')" | ||
|
||
CMD ["python3", "/home/app/main.py"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
FROM otel/opentelemetry-collector-contrib:0.101.0 | ||
|
||
COPY collector-config.yaml /etc/otelcol-contrib/config.yaml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
receivers: | ||
otlp: | ||
protocols: | ||
grpc: | ||
http: | ||
|
||
processors: | ||
batch: | ||
# batch metrics before sending to reduce API usage | ||
send_batch_max_size: 200 | ||
send_batch_size: 200 | ||
timeout: 5s | ||
|
||
memory_limiter: | ||
# drop metrics if memory usage gets too high | ||
check_interval: 1s | ||
limit_percentage: 65 | ||
spike_limit_percentage: 20 | ||
|
||
# automatically detect Cloud Run resource metadata | ||
resourcedetection: | ||
detectors: [env, gcp] | ||
timeout: 2s | ||
override: false | ||
|
||
resource: | ||
attributes: | ||
# add instance_id as a resource attribute | ||
- key: service.instance.id | ||
from_attribute: faas.id | ||
action: upsert | ||
# parse service name from K_SERVICE Cloud Run variable | ||
- key: service.name | ||
value: ${env:K_SERVICE} | ||
action: insert | ||
|
||
exporters: | ||
googlemanagedprometheus: # Note: this is intentionally left blank | ||
|
||
extensions: | ||
health_check: | ||
|
||
service: | ||
extensions: [health_check] | ||
pipelines: | ||
metrics: | ||
receivers: [otlp] | ||
processors: [batch, memory_limiter, resourcedetection, resource] | ||
exporters: [googlemanagedprometheus] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from sentence_transformers import SentenceTransformer | ||
from google.cloud import aiplatform | ||
from pyspark.sql import functions as F | ||
import pyspark.sql.types as T | ||
import pandas as pd | ||
|
||
df = spark.read.format("bigquery").load("bigquery-public-data.breathe.nature") | ||
data = df.where(F.length("body") > 100).sample(.1).collect() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
### INCOMPLETE | ||
|
||
PROJECT_ID=<YOUR-PROJECT-ID> | ||
REGION=<YOUR-REGION> | ||
GCAR_REPO=my-repo | ||
APP_NAME=gemini-app | ||
|
||
python <<EOF | ||
import yaml | ||
PROJECT_ID="${PROJECT_ID}" | ||
REGION="${REGION}" | ||
REPO="${GCAR_REPO}" | ||
APP_NAME="${APP_NAME}" | ||
with open("service.yaml", "r+") as f: | ||
y = yaml.safe_load(f) | ||
y["metadata"]["name"] = APP_NAME | ||
y["spec"]["template"]["spec"]["containers"][0]["image"] = f"{REGION}-docker.pkg.dev/{PROJECT_ID}/{REPO}/{APP_NAME}" | ||
y["spec"]["template"]["spec"]["containers"][0]["env"][1]["value"] = PROJECT_ID | ||
y["spec"]["template"]["spec"]["containers"][0]["env"][2]["value"] = REGION | ||
y["spec"]["template"]["spec"]["containers"][1]["image"] = f"{REGION}-docker.pkg.dev/{PROJECT_ID}/{REPO}/otel-collector-metrics" | ||
f.seek(0) | ||
yaml.dump(y, f, sort_keys=False) | ||
f.truncate() | ||
EOF | ||
|
||
# Create artifact registry if it doesn't already exist | ||
gcloud artifacts repositories describe ${GCAR_REPO} \ | ||
--location=${REGION} \ | ||
--project=${PROJECT_ID} >/dev/null 2>&1 || | ||
gcloud artifacts repositories create ${GCAR_REPO} \ | ||
--repository-format=docker \ | ||
--location=${REGION} \ | ||
--project=${PROJECT_ID} | ||
|
||
# Create app image | ||
gcloud builds submit \ | ||
--tag ${REGION}-docker.pkg.dev/${PROJECT_ID}/${GCAR_REPO}/${APP_NAME} \ | ||
--region=${REGION} \ | ||
--project=${PROJECT_ID} | ||
|
||
# Create collector image | ||
gcloud builds submit collector \ | ||
--tag ${REGION}-docker.pkg.dev/${PROJECT_ID}/${GCAR_REPO}/otel-collector-metrics \ | ||
--region=${REGION} \ | ||
--project=${PROJECT_ID} | ||
|
||
# Create the Cloud Run app | ||
gcloud run services replace service.yaml \ | ||
--region=${REGION} \ | ||
--project=${PROJECT_ID} | ||
|
||
# (Optional) Allow unauthenticated calls | ||
# gcloud run services set-iam-policy ${APP_NAME} unauthenticated_policy.yaml \ | ||
# --region=${REGION} | ||
# --project=${PROJECT_ID} | ||
|
||
# Need to also create BigQuery dataset, embeddings, and firestore db |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
CREATE VECTOR INDEX my_index ON `rag_data.embeddings`(embeddings) | ||
OPTIONS(distance_type='COSINE', index_type='IVF', ivf_options='{"num_lists": 1000}'); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
import os | ||
import time | ||
import uuid | ||
|
||
from google.cloud import bigquery | ||
from google.cloud import firestore | ||
from nicegui import app, run, ui | ||
from opentelemetry import trace | ||
from opentelemetry.sdk.trace import TracerProvider | ||
from opentelemetry.sdk.trace.export import ( | ||
BatchSpanProcessor, | ||
ConsoleSpanExporter, | ||
) | ||
import pandas as pd | ||
from sentence_transformers import SentenceTransformer | ||
import vertexai | ||
from vertexai.generative_models import GenerativeModel | ||
from vertexai.evaluation import EvalTask, PointwiseMetric | ||
|
||
# Define constants | ||
# Instantiate OpenTelemetry | ||
provider = TracerProvider() | ||
processor = BatchSpanProcessor(ConsoleSpanExporter()) | ||
provider.add_span_processor(processor) | ||
|
||
# Sets the global default tracer provider | ||
trace.set_tracer_provider(provider) | ||
|
||
# Creates a tracer from the global tracer provider | ||
TRACER = trace.get_tracer("my.tracer.name") | ||
|
||
PROJECT_ID = os.environ.get("PROJECT") | ||
REGION = os.environ.get("REGION", "us-central1") | ||
|
||
# Initialize VertexAI | ||
vertexai.init(project=PROJECT_ID, location=REGION) | ||
|
||
BIGQUERY_CLIENT = bigquery.Client(project=PROJECT_ID) | ||
FIRESTORE_CLIENT = firestore.Client(project=PROJECT_ID, database="gemini-hackathon") | ||
GEMINI_ENDPOINT = GenerativeModel("gemini-1.5-flash") | ||
TRANSFORMER_MODEL = None | ||
|
||
|
||
def get_embedding(prompt): | ||
global TRANSFORMER_MODEL | ||
|
||
if not TRANSFORMER_MODEL: | ||
TRANSFORMER_MODEL = SentenceTransformer("all-miniLM-L6-v2") | ||
embeddings = TRANSFORMER_MODEL.encode(prompt) | ||
return embeddings.tolist() | ||
|
||
|
||
def make_gemini_prediction(prompt: str) -> str: | ||
try: | ||
return GEMINI_ENDPOINT.generate_content(prompt).text | ||
except Exception as e: | ||
print(e) | ||
raise() | ||
|
||
|
||
def prompt_maker(input) -> tuple[str, str, str]: | ||
context = get_rag_context(input) | ||
version = os.environ.get("PROMPT_VERSION", "v20240926.1") | ||
ref = FIRESTORE_CLIENT.collection(f"prompts").document(document_id=version).get() | ||
prompt = ref.to_dict()["prompt"].format(input=input, context=context) | ||
return prompt, version | ||
|
||
|
||
def write_to_database(client_id: str, data: dict): | ||
FIRESTORE_CLIENT.collection("requests").add(data, document_id=client_id) | ||
|
||
|
||
def multiturn_quality(history, prompt, response): | ||
# Define a pointwise multi-turn chat quality metric | ||
pointwise_chat_quality_metric_prompt = """Evaluate the AI's contribution to a meaningful conversation, considering coherence, fluency, groundedness, and conciseness. | ||
Review the chat history for context. Rate the response on a 1-5 scale, with explanations for each criterion and its overall impact. | ||
# Conversation History | ||
{history} | ||
# Current User Prompt | ||
{prompt} | ||
# AI-generated Response | ||
{response} | ||
""" | ||
|
||
freeform_multi_turn_chat_quality_metric = PointwiseMetric( | ||
metric="multi_turn_chat_quality_metric", | ||
metric_prompt_template=pointwise_chat_quality_metric_prompt, | ||
) | ||
|
||
eval_dataset = pd.DataFrame( | ||
{ | ||
"history": [history], | ||
"prompt": prompt, | ||
"response": response | ||
} | ||
) | ||
|
||
# Run evaluation using the defined metric | ||
eval_task = EvalTask( | ||
dataset=eval_dataset, | ||
metrics=[freeform_multi_turn_chat_quality_metric], | ||
) | ||
|
||
result = eval_task.evaluate() | ||
|
||
return { | ||
"score": result.metrics_table["multi_turn_chat_quality_metric/score"].item(), | ||
"explanation": result.metrics_table["multi_turn_chat_quality_metric/explanation"].item(), | ||
"mean": result.summary_metrics["multi_turn_chat_quality_metric/mean"].item() | ||
} | ||
|
||
|
||
def get_rag_context(input): | ||
embedding = get_embedding(input) | ||
|
||
query = f""" | ||
SELECT * | ||
FROM `{PROJECT_ID}.rag_data.rag_data` | ||
WHERE id IN ( | ||
SELECT s.base.id | ||
FROM VECTOR_SEARCH( | ||
TABLE `{PROJECT_ID}.rag_data.embeddings`, | ||
"embeddings", | ||
(SELECT {embedding}), | ||
top_k => 5) as s); | ||
""" | ||
print("waiting") | ||
rows = BIGQUERY_CLIENT.query_and_wait(query) | ||
print("waited") | ||
bodies = [row["body"] for row in rows] | ||
return " ### ".join(bodies) | ||
|
||
|
||
@ui.page('/') | ||
def index(): | ||
async def update_prompt(): | ||
print("prompt received") | ||
input_time = time.time() | ||
user_input = user_input_raw.value | ||
|
||
with chat_container: | ||
ui.chat_message(user_input, name='Me') | ||
|
||
print("spinner") | ||
spinner = ui.spinner('audio', size='lg', color='green') | ||
|
||
client_id = str(uuid.uuid4()) | ||
request_id = f"{client_id}-{str(uuid.uuid4())[:8]}" | ||
|
||
prompt, prompt_version = await run.cpu_bound(prompt_maker, user_input) | ||
|
||
app.storage.client["count"] = app.storage.client.get("count", 0) + 1 | ||
app.storage.client["history"] = app.storage.client.get("history", "") + "### User: " + prompt | ||
|
||
with TRACER.start_as_current_span("child") as span: | ||
span.set_attribute( | ||
"operation.count", app.storage.client["count"]) | ||
span.set_attribute("prompt", user_input) | ||
span.set_attribute("prompt_id", prompt_version) | ||
span.set_attribute("client_id", client_id) | ||
span.set_attribute("request_id", request_id) | ||
|
||
request_time = time.time() | ||
|
||
response = await run.io_bound(make_gemini_prediction, prompt) | ||
# response = make_prediction(user_input) | ||
response_time = time.time() | ||
app.storage.client["history"] = app.storage.client.get("history") + "### Agent: " + response | ||
span.set_attribute("response", response) | ||
|
||
spinner.delete() | ||
|
||
ui.chat_message(response, | ||
name='Robot', | ||
stamp='now', | ||
avatar='https://robohash.org/ui',) \ | ||
.style('font-family: Comic Sans, sans-serif; font-size: 16px;') | ||
|
||
query = { | ||
"request_id": request_id, | ||
"prompt": user_input, | ||
"response": response, | ||
"input_time": input_time, | ||
"request_time": request_time, | ||
"response_time": response_time, | ||
"prompt_version": prompt_version | ||
} | ||
print(f"Count: {app.storage.client['count']}") | ||
write_to_database(client_id, query) | ||
# print(multiturn_quality( | ||
# app.storage.client.get("history"), | ||
# prompt, | ||
# response | ||
# )) | ||
|
||
ui.markdown("<h2>Welcome to predictions bot!</h2>") | ||
with ui.row().classes('flex flex-col h-screen'): | ||
chat_container = ui.column().classes('w-full max-w-3xl mx-auto my-6') | ||
|
||
with ui.footer().classes('bg-black'), ui.column().classes('w-full max-w-3xl mx-auto my-6'): | ||
with ui.row().classes('w-full no-wrap items-center'): | ||
user_input_raw = ui.input("Prompt").on('keydown.enter', update_prompt) \ | ||
.props('rounded outlined input-class=mx-3').classes('flex-grow') | ||
|
||
|
||
ui.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8080)), storage_secret="1234", dark=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
google-cloud-bigquery | ||
google-cloud-firestore | ||
nicegui | ||
opentelemetry-distro | ||
pandas | ||
sentence_transformers | ||
vertexai |
Oops, something went wrong.