cleaned up a bit (#10)
fpgmaas authored Jun 23, 2024
1 parent d1f5818 commit c593ad4
Showing 5 changed files with 39 additions and 45 deletions.
1 change: 1 addition & 0 deletions frontend/app/utils/search.ts
@@ -30,6 +30,7 @@ export const handleSearch = async (
`${apiUrl}/search`,
{
query: query,
top_k: 40,
},
{
headers: {
17 changes: 15 additions & 2 deletions pypi_scout/api/data_loader.py
@@ -19,6 +19,7 @@ def load_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:
else:
raise ValueError(f"Unexpected value found for STORAGE_BACKEND: {self.config.STORAGE_BACKEND}") # noqa: TRY003

df_embeddings = self._drop_rows_from_embeddings_that_do_not_appear_in_packages(df_embeddings, df_packages)
return df_packages, df_embeddings

def _load_local_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:
@@ -56,10 +57,22 @@ def _load_blob_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:

return df_packages, df_embeddings

def _log_packages_dataset_info(self, df_packages: pl.DataFrame) -> None:
@staticmethod
def _log_packages_dataset_info(df_packages: pl.DataFrame) -> None:
logging.info(f"Finished loading the `packages` dataset. Number of rows in dataset: {len(df_packages):,}")
logging.info(df_packages.describe())

def _log_embeddings_dataset_info(self, df_embeddings: pl.DataFrame) -> None:
@staticmethod
def _log_embeddings_dataset_info(df_embeddings: pl.DataFrame) -> None:
logging.info(f"Finished loading the `embeddings` dataset. Number of rows in dataset: {len(df_embeddings):,}")
logging.info(df_embeddings.describe())

@staticmethod
def _drop_rows_from_embeddings_that_do_not_appear_in_packages(df_embeddings, df_packages):
# We only keep the packages in the vector dataset that also occur in the packages dataset.
# In theory this should never drop anything, but it is kept as a fail-safe to prevent issues in the API.
logging.info("Dropping packages in the `embeddings` dataset that do not occur in the `packages` dataset...")
logging.info(f"Number of rows before dropping: {len(df_embeddings):,}...")
df_embeddings = df_embeddings.join(df_packages, on="name", how="semi")
logging.info(f"Number of rows after dropping: {len(df_embeddings):,}...")
return df_embeddings
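
The fail-safe above relies on a Polars semi join: it keeps only the rows of the left frame whose `name` also appears in the right frame, and adds no columns from the right frame. A minimal sketch of that behaviour, using made-up package names purely for illustration:

import polars as pl

# Made-up toy frames standing in for the `embeddings` and `packages` datasets.
df_embeddings = pl.DataFrame({"name": ["requests", "numpy", "ghost-pkg"], "embedding": [[0.1], [0.2], [0.3]]})
df_packages = pl.DataFrame({"name": ["requests", "numpy"], "weekly_downloads": [1000, 2000]})

# Semi join: rows of df_embeddings whose `name` has a match in df_packages survive;
# the result keeps only the columns of df_embeddings.
filtered = df_embeddings.join(df_packages, on="name", how="semi")
print(filtered["name"].to_list())  # ['requests', 'numpy'] -- 'ghost-pkg' is dropped
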
44 changes: 4 additions & 40 deletions pypi_scout/api/main.py
@@ -1,40 +1,35 @@
import logging

import polars as pl
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from starlette.requests import Request

from pypi_scout.api.data_loader import ApiDataLoader
from pypi_scout.api.models import QueryModel, SearchResponse
from pypi_scout.config import Config
from pypi_scout.embeddings.simple_vector_database import SimpleVectorDatabase
from pypi_scout.utils.logging import setup_logging
from pypi_scout.utils.score_calculator import calculate_score

# Setup logging
setup_logging()
logging.info("Initializing backend...")

# Initialize limiter
limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Load environment variables and configuration
load_dotenv()
config = Config()

# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Temporary wildcard for testing
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
@@ -44,28 +39,9 @@
df_packages, df_embeddings = data_loader.load_dataset()

model = SentenceTransformer(config.EMBEDDINGS_MODEL_NAME)

vector_database = SimpleVectorDatabase(embeddings_model=model, df_embeddings=df_embeddings)


class QueryModel(BaseModel):
query: str
top_k: int = config.N_RESULTS_TO_RETURN


class Match(BaseModel):
name: str
summary: str
similarity: float
weekly_downloads: int


class SearchResponse(BaseModel):
matches: list[Match]
warning: bool = False
warning_message: str = None


@app.post("/api/search", response_model=SearchResponse)
@limiter.limit("4/minute")
async def search(query: QueryModel, request: Request):
@@ -75,7 +51,7 @@ async def search(query: QueryModel, request: Request):
The top_k packages with the highest score are returned.
"""

if query.top_k > 60:
if query.top_k > 100:
raise HTTPException(status_code=400, detail="top_k cannot be larger than 100.")

logging.info(f"Searching for similar projects. Query: '{query.query}'")
@@ -85,18 +61,6 @@
f"Fetched the {len(df_matches)} most similar projects. Calculating the weighted scores and filtering..."
)

warning = False
warning_message = ""
matches_missing_in_local_dataset = df_matches.filter(pl.col("weekly_downloads").is_null())["name"].to_list()
if matches_missing_in_local_dataset:
warning = True
warning_message = (
f"The following entries have 'None' for 'weekly_downloads': {matches_missing_in_local_dataset}. "
"These entries were found in the vector database but not in the local dataset and have been excluded from the results."
)
logging.error(warning_message)
df_matches = df_matches.filter(~pl.col("name").is_in(matches_missing_in_local_dataset))

df_matches = calculate_score(
df_matches, weight_similarity=config.WEIGHT_SIMILARITY, weight_weekly_downloads=config.WEIGHT_WEEKLY_DOWNLOADS
)
@@ -107,4 +71,4 @@ async def search(query: QueryModel, request: Request):

logging.info(f"Returning the {len(df_matches)} best matches.")
df_matches = df_matches.select(["name", "similarity", "summary", "weekly_downloads"])
return SearchResponse(matches=df_matches.to_dicts(), warning=warning, warning_message=warning_message)
return SearchResponse(matches=df_matches.to_dicts())
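
With the default removed from the backend and the hard limit raised from 60 to 100, the client now has to send `top_k` explicitly (the frontend pins it to 40 above). A sketch of a call against a locally running instance; the host, port, and use of `requests` are assumptions, while the `/api/search` path follows the decorator above:

import requests

# Assumed local deployment; adjust the host and port to your setup.
resp = requests.post(
    "http://localhost:8000/api/search",
    json={"query": "library for plotting dataframes", "top_k": 40},
    timeout=30,
)
resp.raise_for_status()
for match in resp.json()["matches"]:
    print(match["name"], match["similarity"], match["weekly_downloads"])

# Sending top_k above 100 would be rejected with HTTP 400 per the check in the endpoint.
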
19 changes: 19 additions & 0 deletions pypi_scout/api/models.py
@@ -0,0 +1,19 @@
from pydantic import BaseModel


class QueryModel(BaseModel):
query: str
top_k: int


class Match(BaseModel):
name: str
summary: str
similarity: float
weekly_downloads: int


class SearchResponse(BaseModel):
matches: list[Match]
warning: bool = False
warning_message: str | None = None
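
A short sketch of how these models are used: `top_k` no longer has a default, so omitting it fails validation, and the response is built from plain dicts. Field values below are illustrative only, and `model_dump` assumes Pydantic v2 (on v1 it would be `.dict()`):

from pydantic import ValidationError

from pypi_scout.api.models import Match, QueryModel, SearchResponse

query = QueryModel(query="fast dataframe library", top_k=40)

try:
    QueryModel(query="fast dataframe library")  # no top_k -> rejected
except ValidationError as exc:
    print(exc.errors()[0]["loc"])  # ('top_k',)

response = SearchResponse(
    matches=[Match(name="polars", summary="Illustrative summary", similarity=0.93, weekly_downloads=1_000_000)]
)
print(response.model_dump())  # model_dump assumes Pydantic v2; use .dict() on v1
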
3 changes: 0 additions & 3 deletions pypi_scout/config.py
@@ -38,9 +38,6 @@ class Config:
# Google Drive file ID for downloading the raw dataset.
GOOGLE_FILE_ID = "1IDJvCsq1gz0yUSXgff13pMl3nUk7zJzb"

# Number of top results to return for a query.
N_RESULTS_TO_RETURN = 40

# Fraction of the dataset to include in the vector database. This value determines the portion of top packages
# (sorted by weekly downloads) to include. Increase this value to include a larger portion of the dataset, up to 1.0 (100%).
# For reference, a value of 0.25 corresponds to including all PyPI packages with at least approximately 650 weekly downloads
