Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit

Permalink
Merge pull request #136 from nsosio/feat/embedding-binary
Browse files: browse the repository at this point in the history
  • Loading branch information
casperdcl authored Nov 6, 2023
2 parents 290a2d7 + 52c5969 commit 4ea6a43
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion ebd-all-minilm/build-aarch64-apple-darwin.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env bash
set -e
export VERSION=1.0.4
export VERSION=1.0.5

test -f venv/bin/activate || python -m venv venv
source venv/bin/activate
Expand Down
2 changes: 1 addition & 1 deletion ebd-all-minilm/build.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
set -e
export VERSION=1.0.4
export VERSION=1.0.5
source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh"

build_cpu ghcr.io/premai-io/embeddings-all-minilm-l6-v2-cpu all-MiniLM-L6-v2 ${@:1}
Expand Down
15 changes: 13 additions & 2 deletions ebd-all-minilm/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import argparse
import logging
import os

import uvicorn
from dotenv import load_dotenv
Expand All @@ -9,6 +11,15 @@

load_dotenv()

MODEL_DIR = os.getenv("MODEL_ID", "all-MiniLM-L6-v2")

if __name__ == "__main__":
    # Command-line entry point: read the server options before the app starts.
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", help="Port to run model server on", type=int, default=8444)
    parser.add_argument("--model-dir", help="Path to model dir", default=MODEL_DIR)
    args = parser.parse_args()
    MODEL_DIR = args.model_dir
    # The server is launched with the import string "main:app", so uvicorn
    # imports this file a second time under the module name "main". In that
    # copy this branch is skipped and MODEL_DIR is recomputed from the
    # MODEL_ID environment variable. Export the chosen directory there too;
    # otherwise --model-dir never reaches the app that serves requests.
    os.environ["MODEL_ID"] = MODEL_DIR

logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
Expand All @@ -18,7 +29,7 @@

def create_start_app_handler(app: FastAPI):
def start_app() -> None:
SentenceTransformerBasedModel.get_model()
SentenceTransformerBasedModel.get_model(MODEL_DIR)

return start_app

Expand All @@ -41,4 +52,4 @@ def get_application() -> FastAPI:


if __name__ == "__main__":
uvicorn.run("main:app", host="0.0.0.0", port=8000)
uvicorn.run("main:app", host="0.0.0.0", port=args.port)
7 changes: 2 additions & 5 deletions ebd-all-minilm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@ def embeddings(cls, texts):
return values.tolist()

@classmethod
def get_model(cls):
def get_model(cls, model_path):
if cls.model is None:
cls.model = SentenceTransformer(
os.getenv("MODEL_ID", "all-MiniLM-L6-v2"),
device=os.getenv("DEVICE", "cpu"),
)
cls.model = SentenceTransformer(model_path, device=os.getenv("DEVICE", "cpu"))
return cls.model

0 comments on commit 4ea6a43

Please sign in to comment.