This repository has been archived by the owner on Dec 6, 2023. It is now read-only.
Commit: Merge pull request #104 from filopedraz/feat/petals-services

Added petals models
Showing 10 changed files with 322 additions and 0 deletions.
@@ -0,0 +1,13 @@ (new file; likely a .dockerignore for the service image)
.editorconfig
.gitattributes
.github
.gitignore
.gitlab-ci.yml
.idea
.pre-commit-config.yaml
.readthedocs.yml
.travis.yml
venv
.git
./ml/models/
.bin
@@ -0,0 +1,10 @@ (new file; a build script sourcing ../utils.sh)
#!/bin/bash
set -e
export VERSION=1.0.0
source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh"

# TODO: support linux/amd64
BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1 \
    build_cpu ghcr.io/premai-io/chat-stable-beluga-2-cpu petals-team/StableBeluga2 "${@:1}"
BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1 \
    build_cpu ghcr.io/premai-io/chat-codellama-34b-cpu premai-io/CodeLlama-34b-Instruct-hf "${@:1}"
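Presumably `build_cpu` is a helper defined in the sourced `utils.sh` that wraps the actual `docker buildx` invocation, reading `BUILDX_PLATFORM` and `TESTS_SKIP_CPU` from the environment; the quoted `"${@:1}"` forwards any extra CLI flags to it. The exact signature of `build_cpu` is not shown in this diff.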
@@ -0,0 +1,22 @@ (new file: the service Dockerfile)
FROM python:3.10-slim-bullseye

ARG MODEL_ID

RUN apt-get update && apt-get install -y libopenblas-dev ninja-build build-essential wget git
RUN python -m pip install --upgrade pip pytest cmake scikit-build setuptools

WORKDIR /usr/src/app/

COPY requirements.txt ./

# pip was already upgraded above, so install only the pinned requirements
RUN pip install --no-cache-dir -r ./requirements.txt

COPY download.py .

RUN python3 download.py --model $MODEL_ID

COPY . .

ENV MODEL_ID=$MODEL_ID

CMD ["python", "main.py"]
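Note the ordering here: `download.py` runs at build time, before `COPY . .`, so the Hugging Face cache is warmed inside an image layer and a container can start without repeating the download; `ARG MODEL_ID` is then re-exported as `ENV MODEL_ID` for the runtime code in models.py.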
@@ -0,0 +1,23 @@ (new file: download.py, copied and run by the Dockerfile)
import argparse

from petals import AutoDistributedModelForCausalLM
from tenacity import retry, stop_after_attempt, wait_fixed
from transformers import AutoTokenizer, LlamaTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model", help="Model to download", required=True)
args = parser.parse_args()

print(f"Downloading model {args.model}")


@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def download_model() -> None:
    # Llama checkpoints are loaded with LlamaTokenizer explicitly;
    # every other model goes through AutoTokenizer.
    if "llama" in args.model.lower():
        _ = LlamaTokenizer.from_pretrained(args.model)
    else:
        _ = AutoTokenizer.from_pretrained(args.model)
    _ = AutoDistributedModelForCausalLM.from_pretrained(args.model)


download_model()
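A note on the retry: Hub downloads can fail transiently during image builds, and tenacity simply re-invokes the decorated function until it succeeds or the attempt limit is reached. A minimal, self-contained sketch of that behavior (the `flaky_download` function and counter are illustrative only, not part of the service):

import tenacity
from tenacity import retry, stop_after_attempt, wait_fixed

calls = {"n": 0}


@retry(stop=stop_after_attempt(3), wait=wait_fixed(0))  # zero wait keeps the demo fast
def flaky_download() -> str:
    # Fail twice, then succeed, mimicking a transient network error.
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient hub error")
    return "ok"


assert flaky_download() == "ok"
assert calls["n"] == 3  # two failed attempts plus the successful third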
@@ -0,0 +1,45 @@ (new file: main.py, the FastAPI entrypoint run by the Dockerfile's CMD)
import logging

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from routes import router as api_router

load_dotenv()

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


def create_start_app_handler(app: FastAPI):
    def start_app() -> None:
        # Import lazily so the slow Petals import happens at startup,
        # not at module import time.
        from models import PetalsBasedModel

        PetalsBasedModel.get_model()

    return start_app


def get_application() -> FastAPI:
    application = FastAPI(title="prem-chat", debug=True, version="0.0.1")
    application.include_router(api_router, prefix="/v1")
    application.add_event_handler("startup", create_start_app_handler(application))
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return application


app = get_application()


if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000)
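Once the container is up, the router's root route (mounted under /v1) doubles as a health probe. A minimal sketch with httpx (pinned in requirements.txt), assuming the default host and port from the `uvicorn.run` call above:

import httpx

# The router is mounted under /v1, so the health route is GET /v1/
response = httpx.get("http://localhost:8000/v1/")
assert response.status_code == 200
assert response.json() == {"status": True}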
@@ -0,0 +1,67 @@ (new file: models.py, imported by main.py and routes.py)
import os
from abc import ABC, abstractmethod
from typing import List

from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer, LlamaTokenizer, logging

logging.set_verbosity_error()


class ChatModel(ABC):
    @classmethod
    @abstractmethod
    def get_model(cls):
        pass

    @classmethod
    @abstractmethod
    def generate(
        cls,
        messages: list,
        temperature: float = 0.9,
        top_p: float = 0.9,
        n: int = 1,
        stream: bool = False,
        max_tokens: int = 128,
        stop: str = "",
        **kwargs,
    ):
        pass

    @classmethod
    @abstractmethod
    def embeddings(cls, text):
        pass


class PetalsBasedModel(ChatModel):
    model = None
    tokenizer = None

    @classmethod
    def generate(
        cls,
        messages: list,
        temperature: float = 0.9,
        top_p: float = 0.9,
        n: int = 1,
        stream: bool = False,
        max_tokens: int = 128,
        stop: str = "",
        **kwargs,
    ) -> List:
        # Only the latest message is used as the prompt.
        message = messages[-1]["content"]
        inputs = cls.tokenizer(message, return_tensors="pt")["input_ids"]
        # Honour the caller's token budget and sampling parameters.
        outputs = cls.model.generate(
            inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
        return [cls.tokenizer.decode(outputs[0])]

    @classmethod
    def get_model(cls):
        if cls.model is None:
            model_id = os.getenv("MODEL_ID")
            if "llama" in model_id.lower():
                cls.tokenizer = LlamaTokenizer.from_pretrained(model_id)
            else:
                cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
            cls.model = AutoDistributedModelForCausalLM.from_pretrained(model_id)
        return cls.model
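`get_model()` is a lazy classmethod singleton: the first call (triggered by the FastAPI startup handler in main.py) loads the tokenizer and joins the Petals swarm; later calls return the cached model. A minimal sketch of that flow, assuming MODEL_ID is set and a public swarm for the model is reachable:

import os

os.environ["MODEL_ID"] = "petals-team/StableBeluga2"  # one of the models built above

from models import PetalsBasedModel

PetalsBasedModel.get_model()  # slow: downloads the tokenizer, connects to the swarm
PetalsBasedModel.get_model()  # fast: returns the cached model

replies = PetalsBasedModel.generate(
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=16,
)
print(replies[0])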
@@ -0,0 +1,9 @@ (new file: requirements.txt, copied by the Dockerfile)
fastapi==0.95.0
uvicorn==0.21.1
pytest==7.2.2
requests==2.28.2
tqdm==4.65.0
httpx==0.23.3
python-dotenv==1.0.0
tenacity==8.2.2
petals==2.2.0
@@ -0,0 +1,107 @@ (new file: routes.py, imported by main.py)
import json
import uuid
from datetime import datetime as dt
from typing import List, Optional, Union

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from models import PetalsBasedModel as model
from pydantic import BaseModel, Field


class ChatCompletionInput(BaseModel):
    model: str
    messages: List[dict]
    temperature: float = 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = ""
    max_tokens: int = 7
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    logit_bias: Optional[dict] = {}
    user: str = ""


class ChatCompletionResponse(BaseModel):
    # Default factories give every response a fresh id and timestamp,
    # rather than values frozen at import time.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    model: str
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(dt.now().timestamp()))
    choices: List[dict]
    usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}


class HealthResponse(BaseModel):
    status: bool


router = APIRouter()


@router.get("/", response_model=HealthResponse)
async def health():
    return HealthResponse(status=True)


async def generate_chunk_based_response(body, text):
    # Emit the whole completion as one server-sent event, then a done marker.
    yield "event: completion\ndata: " + json.dumps(
        {
            "id": str(uuid.uuid4()),
            "model": body.model,
            "object": "chat.completion",
            "choices": [
                {
                    "role": "assistant",
                    "index": 1,
                    "delta": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }
    ) + "\n\n"
    yield "event: done\ndata: [DONE]\n\n"


@router.post("/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(body: ChatCompletionInput):
    try:
        predictions = model.generate(
            messages=body.messages,
            temperature=body.temperature,
            top_p=body.top_p,
            n=body.n,
            stream=body.stream,
            max_tokens=body.max_tokens,
            stop=body.stop,
            presence_penalty=body.presence_penalty,
            frequency_penalty=body.frequency_penalty,
            logit_bias=body.logit_bias,
        )
        if body.stream:
            return StreamingResponse(
                generate_chunk_based_response(body, predictions[0]),
                media_type="text/event-stream",
            )
        return ChatCompletionResponse(
            id=str(uuid.uuid4()),
            model=body.model,
            object="chat.completion",
            choices=[
                {
                    "role": "assistant",
                    "index": idx,
                    "message": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
                for idx, text in enumerate(predictions)
            ],
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )
    except ValueError as error:
        raise HTTPException(
            status_code=400,
            detail={"message": str(error)},
        )
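End to end, the route speaks a subset of the OpenAI chat-completions schema. A client sketch with requests (pinned in requirements.txt), assuming the server from main.py is running locally on its default port:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "stable-beluga",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 32,
    },
    timeout=120,  # distributed generation over a swarm can be slow
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])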
Empty file.
@@ -0,0 +1,26 @@ (new file; likely a pytest module for the chat route)
from fastapi.testclient import TestClient
from main import get_application


def test_chat_petals() -> None:
    app = get_application()
    # Entering the TestClient context runs the startup handler,
    # which loads the Petals model before any request is served.
    with TestClient(app) as client:
        response = client.post(
            "/v1/chat/completions",
            json={
                "model": "stable-beluga",
                "messages": [{"role": "user", "content": "Hello!"}],
            },
        )
        assert response.status_code == 200

        response = client.post(
            "/v1/chat/completions",
            json={
                "stream": True,
                "model": "stable-beluga",
                "messages": [{"role": "user", "content": "Hello!"}],
            },
        )
        assert response.status_code == 200