
Commit

Merge pull request #104 from filopedraz/feat/petals-services
Added petals models
filopedraz authored Sep 20, 2023
2 parents 472fb39 + 01c62ac commit 0e3fdd4
Showing 10 changed files with 322 additions and 0 deletions.
13 changes: 13 additions & 0 deletions cht-petals/.dockerignore
@@ -0,0 +1,13 @@
.editorconfig
.gitattributes
.github
.gitignore
.gitlab-ci.yml
.idea
.pre-commit-config.yaml
.readthedocs.yml
.travis.yml
venv
.git
./ml/models/
.bin
10 changes: 10 additions & 0 deletions cht-petals/build.sh
@@ -0,0 +1,10 @@
#!/bin/bash
set -e
export VERSION=1.0.0
source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh"

# TODO: support linux/amd64
BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1 \
  build_cpu ghcr.io/premai-io/chat-stable-beluga-2-cpu petals-team/StableBeluga2 "${@:1}"
BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1 \
  build_cpu ghcr.io/premai-io/chat-codellama-34b-cpu premai-io/CodeLlama-34b-Instruct-hf "${@:1}"
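
For context, build_cpu comes from ../utils.sh and is not part of this diff; the script just builds the two CPU images for linux/arm64 and forwards any extra CLI flags to build_cpu. A hedged example of invoking it from the repository root (extra flags, if any, depend on utils.sh):

# Assumed invocation from the repository root; any extra arguments are passed through to build_cpu.
bash cht-petals/build.sh
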
22 changes: 22 additions & 0 deletions cht-petals/docker/cpu/Dockerfile
@@ -0,0 +1,22 @@
FROM python:3.10-slim-bullseye

ARG MODEL_ID

RUN apt update && apt install -y libopenblas-dev ninja-build build-essential wget git
RUN python -m pip install --upgrade pip pytest cmake scikit-build setuptools

WORKDIR /usr/src/app/

COPY requirements.txt ./

RUN pip install --no-cache-dir -r requirements.txt

COPY download.py .

RUN python3 download.py --model $MODEL_ID

COPY . .

ENV MODEL_ID=$MODEL_ID

CMD python main.py
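
The image bakes the tokenizer and model weights in at build time via download.py, so MODEL_ID must be supplied as a build argument. A rough sketch of building and running it locally (the tag name and port mapping are assumptions, not taken from this commit):

# Sketch only: the build context must be cht-petals/ so the COPY steps resolve; port 8000 matches main.py.
docker build -t chat-petals-cpu \
  --build-arg MODEL_ID=petals-team/StableBeluga2 \
  -f cht-petals/docker/cpu/Dockerfile cht-petals
docker run --rm -p 8000:8000 chat-petals-cpu
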
23 changes: 23 additions & 0 deletions cht-petals/download.py
@@ -0,0 +1,23 @@
import argparse

from petals import AutoDistributedModelForCausalLM
from tenacity import retry, stop_after_attempt, wait_fixed
from transformers import AutoTokenizer, LlamaTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model", help="Model to download")
args = parser.parse_args()

print(f"Downloading model {args.model}")


@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def download_model() -> None:
    if "llama" in args.model.lower():
        _ = LlamaTokenizer.from_pretrained(args.model)
    else:
        _ = AutoTokenizer.from_pretrained(args.model)
    _ = AutoDistributedModelForCausalLM.from_pretrained(args.model)


download_model()
45 changes: 45 additions & 0 deletions cht-petals/main.py
@@ -0,0 +1,45 @@
import logging

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from routes import router as api_router

load_dotenv()

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


def create_start_app_handler(app: FastAPI):
    def start_app() -> None:
        from models import PetalsBasedModel

        PetalsBasedModel.get_model()

    return start_app


def get_application() -> FastAPI:
    application = FastAPI(title="prem-chat", debug=True, version="0.0.1")
    application.include_router(api_router, prefix="/v1")
    application.add_event_handler("startup", create_start_app_handler(application))
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return application


app = get_application()


if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000)
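
Once the server is up it speaks an OpenAI-style chat-completions dialect on port 8000 (the default in the uvicorn call above). A minimal request, assuming the container is published on localhost; the model field is only echoed back, since the served model is fixed by MODEL_ID:

# Assumes the service is reachable at localhost:8000.
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "stable-beluga", "messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 64}'
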
67 changes: 67 additions & 0 deletions cht-petals/models.py
@@ -0,0 +1,67 @@
import os
from abc import ABC, abstractmethod
from typing import List

from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer, LlamaTokenizer, logging

logging.set_verbosity_error()


class ChatModel(ABC):
    @abstractmethod
    def get_model(cls):
        pass

    @abstractmethod
    def generate(
        cls,
        messages: list,
        temperature: float = 0.9,
        top_p: float = 0.9,
        n: int = 1,
        stream: bool = False,
        max_tokens: int = 128,
        stop: str = "",
        **kwargs,
    ):
        pass

    @abstractmethod
    def embeddings(cls, text):
        pass


class PetalsBasedModel(ChatModel):
    model = None
    tokenizer = None

    @classmethod
    def generate(
        cls,
        messages: list,
        temperature: float = 0.9,
        top_p: float = 0.9,
        n: int = 1,
        stream: bool = False,
        max_tokens: int = 128,
        stop: str = "",
        **kwargs,
    ) -> List:
        message = messages[-1]["content"]
        inputs = cls.tokenizer(message, return_tensors="pt")["input_ids"]
        # Honour the caller's max_tokens rather than a hard-coded token budget.
        outputs = cls.model.generate(inputs, max_new_tokens=max_tokens)
        text = cls.tokenizer.decode(outputs[0])
        print(text)
        return [text]

    @classmethod
    def get_model(cls):
        if cls.model is None:
            if "llama" in os.getenv("MODEL_ID").lower():
                cls.tokenizer = LlamaTokenizer.from_pretrained(os.getenv("MODEL_ID"))
            else:
                cls.tokenizer = AutoTokenizer.from_pretrained(os.getenv("MODEL_ID"))
            cls.model = AutoDistributedModelForCausalLM.from_pretrained(
                os.getenv("MODEL_ID")
            )
        return cls.model
9 changes: 9 additions & 0 deletions cht-petals/requirements.txt
@@ -0,0 +1,9 @@
fastapi==0.95.0
uvicorn==0.21.1
pytest==7.2.2
requests==2.28.2
tqdm==4.65.0
httpx==0.23.3
python-dotenv==1.0.0
tenacity==8.2.2
petals==2.2.0
107 changes: 107 additions & 0 deletions cht-petals/routes.py
@@ -0,0 +1,107 @@
import json
import uuid
from datetime import datetime as dt
from typing import List, Optional, Union

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from models import PetalsBasedModel as model
from pydantic import BaseModel, Field


class ChatCompletionInput(BaseModel):
    model: str
    messages: List[dict]
    temperature: float = 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = ""
    max_tokens: int = 7
    presence_penalty: float = 0.0
    frequence_penalty: float = 0.0
    logit_bias: Optional[dict] = {}
    user: str = ""


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    model: str
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(dt.now().timestamp()))
    choices: List[dict]
    usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}


class HealthResponse(BaseModel):
    status: bool


router = APIRouter()


@router.get("/", response_model=HealthResponse)
async def health():
    return HealthResponse(status=True)


async def generate_chunk_based_response(body, text):
    yield "event: completion\ndata: " + json.dumps(
        {
            "id": str(uuid.uuid4()),
            "model": body.model,
            "object": "chat.completion",
            "choices": [
                {
                    "role": "assistant",
                    "index": 1,
                    "delta": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }
    ) + "\n\n"
    yield "event: done\ndata: [DONE]\n\n"


@router.post("/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(body: ChatCompletionInput):
    try:
        predictions = model.generate(
            messages=body.messages,
            temperature=body.temperature,
            top_p=body.top_p,
            n=body.n,
            stream=body.stream,
            max_tokens=body.max_tokens,
            stop=body.stop,
            presence_penalty=body.presence_penalty,
            frequence_penalty=body.frequence_penalty,
            logit_bias=body.logit_bias,
        )
        if body.stream:
            return StreamingResponse(
                generate_chunk_based_response(body, predictions[0]),
                media_type="text/event-stream",
            )
        return ChatCompletionResponse(
            id=str(uuid.uuid4()),
            model=body.model,
            object="chat.completion",
            choices=[
                {
                    "role": "assistant",
                    "index": idx,
                    "message": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
                for idx, text in enumerate(predictions)
            ],
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )
    except ValueError as error:
        raise HTTPException(
            status_code=400,
            detail={"message": str(error)},
        )
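
With stream set to true the route returns the whole completion as a single server-sent "completion" event followed by a "done" event, rather than incremental deltas; it can still be consumed as an event stream. A sketch with curl, assuming the same local endpoint as above:

# -N disables buffering so the "event: completion" and "event: done" frames print as they arrive.
curl -N http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"stream": true, "model": "stable-beluga", "messages": [{"role": "user", "content": "Hello!"}]}'
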
Empty file added cht-petals/tests/__init__.py
26 changes: 26 additions & 0 deletions cht-petals/tests/test_views.py
@@ -0,0 +1,26 @@
from fastapi.testclient import TestClient
from main import get_application


def test_chat_petals() -> None:
    app = get_application()
    with TestClient(app) as client:
        response = client.post(
            "/v1/chat/completions",
            json={
                "model": "stable-beluga",
                "messages": [{"role": "user", "content": "Hello!"}],
                "n_threads": 10,
            },
        )
        assert response.status_code == 200

        response = client.post(
            "/v1/chat/completions",
            json={
                "stream": True,
                "model": "stable-beluga",
                "messages": [{"role": "user", "content": "Hello!"}],
            },
        )
        assert response.status_code == 200
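
Note that these tests exercise the real startup handler, so they join the Petals swarm and need MODEL_ID set; how they are wired into CI is outside this diff. A hedged local invocation:

# Assumption: run from inside cht-petals/ so `from main import ...` resolves; expect a slow first run.
cd cht-petals && MODEL_ID=petals-team/StableBeluga2 pytest tests -v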
