Fix Model download and FastAPI response (#4)
* Fixed the message and status code;
* Fixed the issue with model scan and storage;
* Updated the Python version and packages
ranjan-stha authored Oct 16, 2024
1 parent b7c5e23 commit 6c6d36a
Showing 7 changed files with 1,043 additions and 901 deletions.
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
+3.12
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.10-slim-buster
+FROM python:3.12-slim-bullseye

LABEL maintainer="TC Developers"

5 changes: 3 additions & 2 deletions app.py
@@ -2,7 +2,7 @@
from typing import List, Optional, Union

from dotenv import load_dotenv
-from fastapi import FastAPI
+from fastapi import FastAPI, Response, status
from pydantic import BaseModel

from embedding_models import (
@@ -37,7 +37,8 @@ class RequestSchemaForEmbeddings(BaseModel):

@app.get("/")
async def home():
return "Embedding handler using models for texts", 200
"""Returns a message"""
return Response(content="Embedding handler using models for texts", status_code=status.HTTP_200_OK)


@app.post("/get_embeddings")
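Note on the app.py change: FastAPI does not interpret a ("body", 200) tuple Flask-style; it serializes the whole tuple into the JSON response body. Returning a Response object sets the plain-text body and the status code explicitly. A minimal sanity check for the fixed route, not part of this commit, could use FastAPI's TestClient, assuming the application object in app.py is importable as app:

from fastapi.testclient import TestClient

from app import app  # FastAPI instance defined in app.py

client = TestClient(app)

response = client.get("/")
assert response.status_code == 200
assert response.text == "Embedding handler using models for texts"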
6 changes: 3 additions & 3 deletions embedding_models.py
@@ -8,7 +8,7 @@
from sentence_transformers import SentenceTransformer
from torch import Tensor

-from utils import download_models
+from utils import check_models


@dataclass
@@ -24,8 +24,8 @@ def __post_init__(self):
"""
Post initialization
"""
models_info = download_models(sent_embedding_model=self.model)
self.st_embedding_model = SentenceTransformer(model_name_or_path=models_info["model_path"])
model_path = check_models(sent_embedding_model=self.model)
self.st_embedding_model = SentenceTransformer(model_name_or_path=model_path)

    def embed_documents(self, texts: list) -> np.ndarray:
        """
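For context, the updated flow resolves a local snapshot path via check_models before constructing the SentenceTransformer, instead of reading it out of the old models_info dictionary. A rough usage sketch, with an illustrative model id that is not taken from this diff:

from sentence_transformers import SentenceTransformer

from utils import check_models

# Illustrative repo id; the project may be configured with a different model.
model_path = check_models(sent_embedding_model="sentence-transformers/all-MiniLM-L6-v2")
st_model = SentenceTransformer(model_name_or_path=model_path)

embeddings = st_model.encode(["Some document text to embed"])
print(embeddings.shape)  # (1, embedding_dimension) numpy array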
1,888 changes: 1,009 additions & 879 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -7,12 +7,12 @@ license = "GNU Affero General Public License v3.0"
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
python = "^3.12"
langchain = "^0.2.15"
langchain-community = "^0.2.14"
langchain-openai = "^0.1.23"
sentence-transformers = "^3.1.0"
-torch = {url = "https://download.pytorch.org/whl/cpu/torch-2.0.0%2Bcpu-cp310-cp310-linux_x86_64.whl"}
+torch = {url = "https://download.pytorch.org/whl/cpu/torch-2.2.0%2Bcpu-cp312-cp312-linux_x86_64.whl#sha256=8258824bec0181e01a086aef5809016116a97626af2dcbf932d4e0b192d9c1b8"}
fastapi = "^0.114.0"
uvicorn = "^0.22.0"
sentry-sdk = "^1.5.8"
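The torch dependency is now pinned to a CPU-only 2.2.0 wheel built for CPython 3.12. A quick sanity check, outside the scope of this commit, that the CPU-only build is the one installed:

import torch

print(torch.__version__)          # expected to end in "+cpu", e.g. "2.2.0+cpu"
print(torch.version.cuda)         # None for CPU-only builds
print(torch.cuda.is_available())  # False, since no CUDA runtime ships with this wheel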
38 changes: 24 additions & 14 deletions utils.py
@@ -9,29 +9,39 @@
logger.setLevel(logging.INFO)


-def download_models(sent_embedding_model: str):
+def download_model(embedding_model: str, models_path: str):
+    """Downloads the model"""
+    logger.info("Downloading the model")
+    embedding_model_local_path = snapshot_download(repo_id=embedding_model, cache_dir=models_path)
+    return embedding_model_local_path
+
+
+def check_models(sent_embedding_model: str):
    """Check if the model already exists"""
    models_path = Path("/opt/models")
    models_info_path = models_path / "model_info.json"

    if not os.path.exists(models_path):
        os.makedirs(models_path)

    if not any(os.listdir(models_path)):
-        logger.info("Downloading the model")
-        embedding_model_local_path = snapshot_download(repo_id=sent_embedding_model, cache_dir=models_path)
+        embedding_model_local_path = download_model(embedding_model=sent_embedding_model, models_path=models_path)
        models_info = {
-            "model": sent_embedding_model,
-            "model_path": embedding_model_local_path,
+            sent_embedding_model: embedding_model_local_path,
        }

        with open(models_info_path, "w", encoding="utf-8") as m_info_f:
            json.dump(models_info, m_info_f)

-    else:
-        if os.path.exists(models_info_path):
-            logger.info("Models already exists.")
-            logger.info(models_info_path)
-            with open(models_info_path, "r", encoding="utf-8") as m_info_f:
-                models_info = json.load(m_info_f)
-
-    return models_info
+        return embedding_model_local_path
+
+    if os.path.exists(models_info_path):
+        with open(models_info_path, "r", encoding="utf-8") as m_info_f:
+            models_info_dict = json.load(m_info_f)
+        if sent_embedding_model not in models_info_dict.keys():
+            embedding_model_local_path = download_model(embedding_model=sent_embedding_model, models_path=models_path)
+            models_info_dict[sent_embedding_model] = embedding_model_local_path
+            with open(models_info_path, "w", encoding="utf-8") as m_info_f:
+                json.dump(models_info_dict, m_info_f)
+            return embedding_model_local_path
+
+        logger.info("Model is already available.")
+        return models_info_dict[sent_embedding_model]
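Net effect of the utils.py rework: model_info.json now maps each model id directly to its local snapshot path, so a given model is downloaded at most once and subsequent calls resolve the cached path. A hedged usage sketch, assuming /opt/models is writable and using an illustrative model id:

from utils import check_models

model_id = "sentence-transformers/all-MiniLM-L6-v2"  # illustrative id, not taken from the diff

first_path = check_models(sent_embedding_model=model_id)   # empty cache: downloads and records the path
second_path = check_models(sent_embedding_model=model_id)  # cache hit: path read back from model_info.json
assert first_path == second_path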
