Commit cdfb8bb: initial commit

bradmiro committed Sep 27, 2024
1 parent 18883c4
Showing 13 changed files with 530 additions and 0 deletions.
14 changes: 14 additions & 0 deletions ai-ml/spark-gemini-rag/Dockerfile
@@ -0,0 +1,14 @@
FROM python:3.11

RUN mkdir /home/app
WORKDIR /home

COPY requirements.txt /home/app/requirements.txt
COPY main.py /home/app/main.py

RUN python3 -m pip install -r /home/app/requirements.txt

# Preload the model into the image
RUN python3 -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"

CMD ["python3", "/home/app/main.py"]
3 changes: 3 additions & 0 deletions ai-ml/spark-gemini-rag/collector/Dockerfile
@@ -0,0 +1,3 @@
FROM otel/opentelemetry-collector-contrib:0.101.0

COPY collector-config.yaml /etc/otelcol-contrib/config.yaml
49 changes: 49 additions & 0 deletions ai-ml/spark-gemini-rag/collector/collector-config.yaml
@@ -0,0 +1,49 @@
receivers:
  otlp:
    protocols:
      grpc:
      http:

processors:
  batch:
    # batch metrics before sending to reduce API usage
    send_batch_max_size: 200
    send_batch_size: 200
    timeout: 5s

  memory_limiter:
    # drop metrics if memory usage gets too high
    check_interval: 1s
    limit_percentage: 65
    spike_limit_percentage: 20

  # automatically detect Cloud Run resource metadata
  resourcedetection:
    detectors: [env, gcp]
    timeout: 2s
    override: false

  resource:
    attributes:
      # add instance_id as a resource attribute
      - key: service.instance.id
        from_attribute: faas.id
        action: upsert
      # parse service name from K_SERVICE Cloud Run variable
      - key: service.name
        value: ${env:K_SERVICE}
        action: insert

exporters:
  googlemanagedprometheus: # Note: this is intentionally left blank

extensions:
  health_check:

service:
  extensions: [health_check]
  pipelines:
    metrics:
      receivers: [otlp]
      processors: [batch, memory_limiter, resourcedetection, resource]
      exporters: [googlemanagedprometheus]
8 changes: 8 additions & 0 deletions ai-ml/spark-gemini-rag/create_embeddings.py
@@ -0,0 +1,8 @@
from sentence_transformers import SentenceTransformer
from google.cloud import aiplatform
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import pyspark.sql.types as T
import pandas as pd

# Assumes an environment where the Spark BigQuery connector is available
# (for example Dataproc Serverless); otherwise supply the connector jar explicitly.
spark = SparkSession.builder.appName("create-embeddings").getOrCreate()

df = spark.read.format("bigquery").load("bigquery-public-data.breathe.nature")
data = df.where(F.length("body") > 100).sample(0.1).collect()
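
# --- Hedged sketch, not part of the original commit ------------------------
# One possible way to finish this job: encode the sampled article bodies on
# the driver and write (id, embeddings) rows back to BigQuery so that
# generate_index.sql can build a vector index over `rag_data.embeddings`.
# The column names ("id", "body"), the "rag_data" dataset, and the temporary
# GCS bucket are assumptions based on the rest of this commit; a similar write
# would populate `rag_data.rag_data` with the raw id/body rows.
model = SentenceTransformer("all-MiniLM-L6-v2")

bodies = [row["body"] for row in data]
vectors = model.encode(bodies).tolist()

embeddings_df = spark.createDataFrame(
    [(row["id"], vec) for row, vec in zip(data, vectors)],
    schema="id STRING, embeddings ARRAY<FLOAT>",
)

(
    embeddings_df.write.format("bigquery")
    .option("table", "rag_data.embeddings")
    .option("temporaryGcsBucket", "<YOUR-TEMP-BUCKET>")
    .mode("overwrite")
    .save()
)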
59 changes: 59 additions & 0 deletions ai-ml/spark-gemini-rag/deploy.sh
@@ -0,0 +1,59 @@
### INCOMPLETE

PROJECT_ID=<YOUR-PROJECT-ID>
REGION=<YOUR-REGION>
GCAR_REPO=my-repo
APP_NAME=gemini-app

python <<EOF
import yaml
PROJECT_ID="${PROJECT_ID}"
REGION="${REGION}"
REPO="${GCAR_REPO}"
APP_NAME="${APP_NAME}"
with open("service.yaml", "r+") as f:
y = yaml.safe_load(f)
y["metadata"]["name"] = APP_NAME
y["spec"]["template"]["spec"]["containers"][0]["image"] = f"{REGION}-docker.pkg.dev/{PROJECT_ID}/{REPO}/{APP_NAME}"
y["spec"]["template"]["spec"]["containers"][0]["env"][1]["value"] = PROJECT_ID
y["spec"]["template"]["spec"]["containers"][0]["env"][2]["value"] = REGION
y["spec"]["template"]["spec"]["containers"][1]["image"] = f"{REGION}-docker.pkg.dev/{PROJECT_ID}/{REPO}/otel-collector-metrics"
f.seek(0)
yaml.dump(y, f, sort_keys=False)
f.truncate()
EOF

# Create artifact registry if it doesn't already exist
gcloud artifacts repositories describe ${GCAR_REPO} \
--location=${REGION} \
--project=${PROJECT_ID} >/dev/null 2>&1 ||
gcloud artifacts repositories create ${GCAR_REPO} \
--repository-format=docker \
--location=${REGION} \
--project=${PROJECT_ID}

# Create app image
gcloud builds submit \
--tag ${REGION}-docker.pkg.dev/${PROJECT_ID}/${GCAR_REPO}/${APP_NAME} \
--region=${REGION} \
--project=${PROJECT_ID}

# Create collector image
gcloud builds submit collector \
--tag ${REGION}-docker.pkg.dev/${PROJECT_ID}/${GCAR_REPO}/otel-collector-metrics \
--region=${REGION} \
--project=${PROJECT_ID}

# Create the Cloud Run app
gcloud run services replace service.yaml \
--region=${REGION} \
--project=${PROJECT_ID}

# (Optional) Allow unauthenticated calls
# gcloud run services set-iam-policy ${APP_NAME} unauthenticated_policy.yaml \
#   --region=${REGION} \
#   --project=${PROJECT_ID}

# Still TODO: create the BigQuery dataset, the embeddings table/index, and the Firestore database; see the sketch below.
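
# Hedged sketch of those remaining steps (not part of the original script).
# The dataset and database names come from generate_index.sql and main.py;
# the Firestore location and the bq invocation details are assumptions:
#
# bq mk --dataset ${PROJECT_ID}:rag_data
# gcloud firestore databases create \
#   --database=gemini-hackathon \
#   --location=${REGION} \
#   --project=${PROJECT_ID}
# bq query --use_legacy_sql=false --project_id=${PROJECT_ID} < generate_index.sql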
2 changes: 2 additions & 0 deletions ai-ml/spark-gemini-rag/generate_index.sql
@@ -0,0 +1,2 @@
CREATE VECTOR INDEX my_index ON `rag_data.embeddings`(embeddings)
OPTIONS(distance_type='COSINE', index_type='IVF', ivf_options='{"num_lists": 1000}');
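
-- Hedged usage example (not part of the original commit): the query shape
-- main.py runs against this index at serving time, with a placeholder query
-- vector (all-MiniLM-L6-v2 produces 384-dimensional embeddings in practice).
--
-- SELECT base.id, distance
-- FROM VECTOR_SEARCH(
--   TABLE `rag_data.embeddings`,
--   'embeddings',
--   (SELECT [0.12, 0.34, 0.56]),
--   top_k => 5);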
210 changes: 210 additions & 0 deletions ai-ml/spark-gemini-rag/main.py
@@ -0,0 +1,210 @@
import os
import time
import uuid

from google.cloud import bigquery
from google.cloud import firestore
from nicegui import app, run, ui
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
    BatchSpanProcessor,
    ConsoleSpanExporter,
)
import pandas as pd
from sentence_transformers import SentenceTransformer
import vertexai
from vertexai.generative_models import GenerativeModel
from vertexai.evaluation import EvalTask, PointwiseMetric

# Instantiate OpenTelemetry tracing
provider = TracerProvider()
processor = BatchSpanProcessor(ConsoleSpanExporter())
provider.add_span_processor(processor)

# Sets the global default tracer provider
trace.set_tracer_provider(provider)

# Creates a tracer from the global tracer provider
TRACER = trace.get_tracer("my.tracer.name")
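
# --- Hedged sketch, not part of the original commit ------------------------
# The sidecar in collector/collector-config.yaml only wires up a *metrics*
# pipeline (OTLP in, Google Managed Prometheus out), but nothing in this file
# emits OTLP metrics yet. One possible way to do so, assuming
# `opentelemetry-exporter-otlp` were added to requirements.txt and the
# collector listens on localhost:4317 inside the Cloud Run instance:
#
#   from opentelemetry import metrics
#   from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
#   from opentelemetry.sdk.metrics import MeterProvider
#   from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
#
#   reader = PeriodicExportingMetricReader(
#       OTLPMetricExporter(endpoint="http://localhost:4317", insecure=True))
#   metrics.set_meter_provider(MeterProvider(metric_readers=[reader]))
#   METER = metrics.get_meter("gemini-app")
#   PROMPT_COUNTER = METER.create_counter("prompt_requests")
#   # ...then call PROMPT_COUNTER.add(1) inside update_prompt().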

PROJECT_ID = os.environ.get("PROJECT")
REGION = os.environ.get("REGION", "us-central1")

# Initialize VertexAI
vertexai.init(project=PROJECT_ID, location=REGION)

BIGQUERY_CLIENT = bigquery.Client(project=PROJECT_ID)
FIRESTORE_CLIENT = firestore.Client(project=PROJECT_ID, database="gemini-hackathon")
GEMINI_ENDPOINT = GenerativeModel("gemini-1.5-flash")
TRANSFORMER_MODEL = None


def get_embedding(prompt):
    """Encode the prompt with a lazily loaded SentenceTransformer model."""
    global TRANSFORMER_MODEL

    if not TRANSFORMER_MODEL:
        TRANSFORMER_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = TRANSFORMER_MODEL.encode(prompt)
    return embeddings.tolist()


def make_gemini_prediction(prompt: str) -> str:
    try:
        return GEMINI_ENDPOINT.generate_content(prompt).text
    except Exception as e:
        print(e)
        raise


def prompt_maker(input) -> tuple[str, str]:
    """Fill the versioned prompt template from Firestore with the user input and RAG context."""
    context = get_rag_context(input)
    version = os.environ.get("PROMPT_VERSION", "v20240926.1")
    ref = FIRESTORE_CLIENT.collection("prompts").document(document_id=version).get()
    prompt = ref.to_dict()["prompt"].format(input=input, context=context)
    return prompt, version


def write_to_database(client_id: str, data: dict):
    FIRESTORE_CLIENT.collection("requests").add(data, document_id=client_id)


def multiturn_quality(history, prompt, response):
    # Define a pointwise multi-turn chat quality metric
    pointwise_chat_quality_metric_prompt = """Evaluate the AI's contribution to a meaningful conversation, considering coherence, fluency, groundedness, and conciseness.
Review the chat history for context. Rate the response on a 1-5 scale, with explanations for each criterion and its overall impact.
# Conversation History
{history}
# Current User Prompt
{prompt}
# AI-generated Response
{response}
"""

    freeform_multi_turn_chat_quality_metric = PointwiseMetric(
        metric="multi_turn_chat_quality_metric",
        metric_prompt_template=pointwise_chat_quality_metric_prompt,
    )

    eval_dataset = pd.DataFrame(
        {
            "history": [history],
            "prompt": [prompt],
            "response": [response],
        }
    )

    # Run evaluation using the defined metric
    eval_task = EvalTask(
        dataset=eval_dataset,
        metrics=[freeform_multi_turn_chat_quality_metric],
    )

    result = eval_task.evaluate()

    return {
        "score": result.metrics_table["multi_turn_chat_quality_metric/score"].item(),
        "explanation": result.metrics_table["multi_turn_chat_quality_metric/explanation"].item(),
        "mean": result.summary_metrics["multi_turn_chat_quality_metric/mean"].item(),
    }


def get_rag_context(input):
    """Retrieve the most similar article bodies from BigQuery via vector search."""
    embedding = get_embedding(input)

    query = f"""
    SELECT *
    FROM `{PROJECT_ID}.rag_data.rag_data`
    WHERE id IN (
      SELECT s.base.id
      FROM VECTOR_SEARCH(
        TABLE `{PROJECT_ID}.rag_data.embeddings`,
        "embeddings",
        (SELECT {embedding}),
        top_k => 5) as s);
    """
    print("Running BigQuery vector search...")
    rows = BIGQUERY_CLIENT.query_and_wait(query)
    print("Vector search complete.")
    bodies = [row["body"] for row in rows]
    return " ### ".join(bodies)


@ui.page('/')
def index():
    async def update_prompt():
        print("prompt received")
        input_time = time.time()
        user_input = user_input_raw.value

        with chat_container:
            ui.chat_message(user_input, name='Me')

            print("spinner")
            spinner = ui.spinner('audio', size='lg', color='green')

            client_id = str(uuid.uuid4())
            request_id = f"{client_id}-{str(uuid.uuid4())[:8]}"

            # Build the prompt (Firestore template + BigQuery RAG context) off the event loop
            prompt, prompt_version = await run.cpu_bound(prompt_maker, user_input)

            app.storage.client["count"] = app.storage.client.get("count", 0) + 1
            app.storage.client["history"] = app.storage.client.get("history", "") + "### User: " + prompt

            # Trace the Gemini call and its metadata
            with TRACER.start_as_current_span("child") as span:
                span.set_attribute("operation.count", app.storage.client["count"])
                span.set_attribute("prompt", user_input)
                span.set_attribute("prompt_id", prompt_version)
                span.set_attribute("client_id", client_id)
                span.set_attribute("request_id", request_id)

                request_time = time.time()

                response = await run.io_bound(make_gemini_prediction, prompt)
                # response = make_prediction(user_input)
                response_time = time.time()
                app.storage.client["history"] = app.storage.client.get("history") + "### Agent: " + response
                span.set_attribute("response", response)

            spinner.delete()

            ui.chat_message(response,
                            name='Robot',
                            stamp='now',
                            avatar='https://robohash.org/ui') \
                .style('font-family: Comic Sans, sans-serif; font-size: 16px;')

            # Persist the full request/response record to Firestore
            query = {
                "request_id": request_id,
                "prompt": user_input,
                "response": response,
                "input_time": input_time,
                "request_time": request_time,
                "response_time": response_time,
                "prompt_version": prompt_version,
            }
            print(f"Count: {app.storage.client['count']}")
            write_to_database(client_id, query)
            # print(multiturn_quality(
            #     app.storage.client.get("history"),
            #     prompt,
            #     response
            # ))

    ui.markdown("<h2>Welcome to predictions bot!</h2>")
    with ui.row().classes('flex flex-col h-screen'):
        chat_container = ui.column().classes('w-full max-w-3xl mx-auto my-6')

    with ui.footer().classes('bg-black'), ui.column().classes('w-full max-w-3xl mx-auto my-6'):
        with ui.row().classes('w-full no-wrap items-center'):
            user_input_raw = ui.input("Prompt").on('keydown.enter', update_prompt) \
                .props('rounded outlined input-class=mx-3').classes('flex-grow')


ui.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8080)), storage_secret="1234", dark=True)
7 changes: 7 additions & 0 deletions ai-ml/spark-gemini-rag/requirements.txt
@@ -0,0 +1,7 @@
google-cloud-bigquery
google-cloud-firestore
nicegui
opentelemetry-distro
pandas
sentence_transformers
vertexai