diff --git a/cht-petals/.dockerignore b/cht-petals/.dockerignore
new file mode 100644
index 0000000..ebf7f70
--- /dev/null
+++ b/cht-petals/.dockerignore
@@ -0,0 +1,13 @@
+.editorconfig
+.gitattributes
+.github
+.gitignore
+.gitlab-ci.yml
+.idea
+.pre-commit-config.yaml
+.readthedocs.yml
+.travis.yml
+venv
+.git
+./ml/models/
+.bin
diff --git a/cht-petals/build.sh b/cht-petals/build.sh
new file mode 100755
index 0000000..0139474
--- /dev/null
+++ b/cht-petals/build.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+set -e
+export VERSION=1.0.0
+source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh"
+
+# TODO: support linux/amd64
+BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1 \
+build_cpu ghcr.io/premai-io/chat-stable-beluga-2-cpu petals-team/StableBeluga2 ${@:1}
+BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1 \
+build_cpu ghcr.io/premai-io/chat-codellama-34b-cpu premai-io/CodeLlama-34b-Instruct-hf ${@:1}
diff --git a/cht-petals/docker/cpu/Dockerfile b/cht-petals/docker/cpu/Dockerfile
new file mode 100644
index 0000000..65a558b
--- /dev/null
+++ b/cht-petals/docker/cpu/Dockerfile
@@ -0,0 +1,22 @@
+FROM python:3.10-slim-bullseye
+
+ARG MODEL_ID
+
+RUN apt update && apt install -y libopenblas-dev ninja-build build-essential wget git
+RUN python -m pip install --upgrade pip pytest cmake scikit-build setuptools
+
+WORKDIR /usr/src/app/
+
+COPY requirements.txt ./
+
+RUN pip install --no-cache-dir -r ./requirements.txt
+
+COPY download.py .
+
+RUN python3 download.py --model $MODEL_ID
+
+COPY . .
+
+ENV MODEL_ID=$MODEL_ID
+
+CMD python main.py
diff --git a/cht-petals/download.py b/cht-petals/download.py
new file mode 100644
index 0000000..4f1d938
--- /dev/null
+++ b/cht-petals/download.py
@@ -0,0 +1,23 @@
+import argparse
+
+from petals import AutoDistributedModelForCausalLM
+from tenacity import retry, stop_after_attempt, wait_fixed
+from transformers import AutoTokenizer, LlamaTokenizer
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", help="Model to download")
+args = parser.parse_args()
+
+print(f"Downloading model {args.model}")
+
+
+@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
+def download_model() -> None:
+    if "llama" in args.model.lower():
+        _ = LlamaTokenizer.from_pretrained(args.model)
+    else:
+        _ = AutoTokenizer.from_pretrained(args.model)
+    _ = AutoDistributedModelForCausalLM.from_pretrained(args.model)
+
+
+download_model()
diff --git a/cht-petals/main.py b/cht-petals/main.py
new file mode 100644
index 0000000..0b416bb
--- /dev/null
+++ b/cht-petals/main.py
@@ -0,0 +1,45 @@
+import logging
+
+import uvicorn
+from dotenv import load_dotenv
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from routes import router as api_router
+
+load_dotenv()
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s %(message)s",
+    level=logging.INFO,
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def create_start_app_handler(app: FastAPI):
+    def start_app() -> None:
+        from models import PetalsBasedModel
+
+        PetalsBasedModel.get_model()
+
+    return start_app
+
+
+def get_application() -> FastAPI:
+    application = FastAPI(title="prem-chat", debug=True, version="0.0.1")
+    application.include_router(api_router, prefix="/v1")
+    application.add_event_handler("startup", create_start_app_handler(application))
+    application.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    return application
+
+
+app = get_application()
+
+
+if __name__ == "__main__":
+    uvicorn.run("main:app", host="0.0.0.0", port=8000)
diff --git a/cht-petals/models.py b/cht-petals/models.py
new file mode 100644
index 0000000..b401836
--- /dev/null
+++ b/cht-petals/models.py
@@ -0,0 +1,67 @@
+import os
+from abc import ABC, abstractmethod
+from typing import List
+
+from petals import AutoDistributedModelForCausalLM
+from transformers import AutoTokenizer, LlamaTokenizer, logging
+
+logging.set_verbosity_error()
+
+
+class ChatModel(ABC):
+    @abstractmethod
+    def get_model(cls):
+        pass
+
+    @abstractmethod
+    def generate(
+        cls,
+        messages: list,
+        temperature: float = 0.9,
+        top_p: float = 0.9,
+        n: int = 1,
+        stream: bool = False,
+        max_tokens: int = 128,
+        stop: str = "",
+        **kwargs,
+    ):
+        pass
+
+    @abstractmethod
+    def embeddings(cls, text):
+        pass
+
+
+class PetalsBasedModel(ChatModel):
+    model = None
+    tokenizer = None
+
+    @classmethod
+    def generate(
+        cls,
+        messages: list,
+        temperature: float = 0.9,
+        top_p: float = 0.9,
+        n: int = 1,
+        stream: bool = False,
+        max_tokens: int = 128,
+        stop: str = "",
+        **kwargs,
+    ) -> List:
+        message = messages[-1]["content"]
+        inputs = cls.tokenizer(message, return_tensors="pt")["input_ids"]
+        outputs = cls.model.generate(inputs, max_new_tokens=max_tokens)
+        print(cls.tokenizer.decode(outputs[0]))
+        return [cls.tokenizer.decode(outputs[0])]
+
+    @classmethod
+    def get_model(cls):
+        if cls.model is None:
+            if "llama" in os.getenv("MODEL_ID").lower():
+                cls.tokenizer = LlamaTokenizer.from_pretrained(os.getenv("MODEL_ID"))
+            else:
+                cls.tokenizer = AutoTokenizer.from_pretrained(os.getenv("MODEL_ID"))
+            cls.model = AutoDistributedModelForCausalLM.from_pretrained(
+                os.getenv("MODEL_ID")
+            )
+        return cls.model
diff --git a/cht-petals/requirements.txt b/cht-petals/requirements.txt
new file mode 100644
index 0000000..205b6a1
--- /dev/null
+++ b/cht-petals/requirements.txt
@@ -0,0 +1,9 @@
+fastapi==0.95.0
+uvicorn==0.21.1
+pytest==7.2.2
+requests==2.28.2
+tqdm==4.65.0
+httpx==0.23.3
+python-dotenv==1.0.0
+tenacity==8.2.2
+petals==2.2.0
diff --git a/cht-petals/routes.py b/cht-petals/routes.py
new file mode 100644
index 0000000..b009f39
--- /dev/null
+++ b/cht-petals/routes.py
@@ -0,0 +1,107 @@
+import json
+import uuid
+from datetime import datetime as dt
+from typing import List, Optional, Union
+
+from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
+from models import PetalsBasedModel as model
+from pydantic import BaseModel
+
+
+class ChatCompletionInput(BaseModel):
+    model: str
+    messages: List[dict]
+    temperature: float = 1.0
+    top_p: float = 1.0
+    n: int = 1
+    stream: bool = False
+    stop: Optional[Union[str, List[str]]] = ""
+    max_tokens: int = 7
+    presence_penalty: float = 0.0
+    frequency_penalty: float = 0.0
+    logit_bias: Optional[dict] = {}
+    user: str = ""
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str = str(uuid.uuid4())
+    model: str
+    object: str = "chat.completion"
+    created: int = int(dt.now().timestamp())
+    choices: List[dict]
+    usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+
+
+class HealthResponse(BaseModel):
+    status: bool
+
+
+router = APIRouter()
+
+
+@router.get("/", response_model=HealthResponse)
+async def health():
+    return HealthResponse(status=True)
+
+
+async def generate_chunk_based_response(body, text):
+    yield "event: completion\ndata: " + json.dumps(
+        {
+            "id": str(uuid.uuid4()),
+            "model": body.model,
+            "object": "chat.completion",
+            "choices": [
+                {
+                    "role": "assistant",
+                    "index": 1,
+                    "delta": {"role": "assistant", "content": text},
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+        }
+    ) + "\n\n"
+    yield "event: done\ndata: [DONE]\n\n"
+
+
+@router.post("/chat/completions", response_model=ChatCompletionResponse)
+async def chat_completions(body: ChatCompletionInput):
+    try:
+        predictions = model.generate(
+            messages=body.messages,
+            temperature=body.temperature,
+            top_p=body.top_p,
+            n=body.n,
+            stream=body.stream,
+            max_tokens=body.max_tokens,
+            stop=body.stop,
+            presence_penalty=body.presence_penalty,
+            frequency_penalty=body.frequency_penalty,
+            logit_bias=body.logit_bias,
+        )
+        if body.stream:
+            return StreamingResponse(
+                generate_chunk_based_response(body, predictions[0]),
+                media_type="text/event-stream",
+            )
+        return ChatCompletionResponse(
+            id=str(uuid.uuid4()),
+            model=body.model,
+            object="chat.completion",
+            choices=[
+                {
+                    "role": "assistant",
+                    "index": idx,
+                    "message": {"role": "assistant", "content": text},
+                    "finish_reason": "stop",
+                }
+                for idx, text in enumerate(predictions)
+            ],
+            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+        )
+    except ValueError as error:
+        raise HTTPException(
+            status_code=400,
+            detail={"message": str(error)},
+        )
diff --git a/cht-petals/tests/__init__.py b/cht-petals/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cht-petals/tests/test_views.py b/cht-petals/tests/test_views.py
new file mode 100644
index 0000000..fdfbbd3
--- /dev/null
+++ b/cht-petals/tests/test_views.py
@@ -0,0 +1,26 @@
+from fastapi.testclient import TestClient
+from main import get_application
+
+
+def test_chat_petals() -> None:
+    app = get_application()
+    with TestClient(app) as client:
+        response = client.post(
+            "/v1/chat/completions",
+            json={
+                "model": "stable-beluga",
+                "messages": [{"role": "user", "content": "Hello!"}],
+                "n_threads": 10,
+            },
+        )
+        assert response.status_code == 200
+
+        response = client.post(
+            "/v1/chat/completions",
+            json={
+                "stream": True,
+                "model": "stable-beluga",
+                "messages": [{"role": "user", "content": "Hello!"}],
+            },
+        )
+        assert response.status_code == 200