This repository has been archived by the owner on Dec 6, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
84f7d52
commit a3d9f62
Showing
3 changed files
with
229 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import logging | ||
from typing import Callable | ||
|
||
import uvicorn | ||
from dotenv import load_dotenv | ||
from fastapi import FastAPI | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from models import XGenBasedModel | ||
from routes import router as api_router | ||
|
||
# Load environment settings via python-dotenv before anything else reads them.
load_dotenv()

# Root logging configuration: INFO level, timestamp + padded level name + message.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
)
|
||
|
||
def create_start_app_handler(app: FastAPI) -> Callable[[], None]:
    """Build the zero-argument hook that is registered for the "startup" event.

    The hook calls ``XGenBasedModel.get_model()`` and discards the result.
    ``app`` is accepted but not used by the hook.
    """

    def _load_model_on_startup() -> None:
        # Called only for its side effect; the return value is ignored.
        XGenBasedModel.get_model()

    return _load_model_on_startup
|
||
|
||
def get_application() -> FastAPI:
    """Assemble and return the FastAPI application.

    Mounts the API router under ``/v1``, registers the startup hook produced
    by ``create_start_app_handler`` and installs the CORS middleware.
    """
    # NOTE(review): wildcard origins together with allow_credentials=True is
    # very permissive, and debug=True is hard-coded -- confirm both are
    # intended outside local development.
    cors_options = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }

    api = FastAPI(title="prem-chat", debug=True, version="0.0.1")
    api.include_router(api_router, prefix="/v1")
    api.add_event_handler("startup", create_start_app_handler(api))
    api.add_middleware(CORSMiddleware, **cors_options)
    return api
|
||
|
||
# Module-level application object; the "main:app" import string below refers to it.
app = get_application()


if __name__ == "__main__":
    # Direct execution: serve on every interface, port 8000.
    uvicorn.run("main:app", port=8000, host="0.0.0.0")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import os | ||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
import torch | ||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | ||
|
||
|
||
class ChatModel(ABC):
    """Abstract interface for chat model back-ends.

    Every method receives ``cls`` and is meant to be called on the class
    itself, so each one is declared as an abstract *classmethod*. Without the
    ``@classmethod`` decorator the ``cls`` parameter would be an ordinary
    first positional argument and the declared interface would not match the
    way implementations are invoked.
    """

    @classmethod
    @abstractmethod
    def get_model(cls) -> object:
        """Load the underlying model if necessary and return it."""

    @classmethod
    @abstractmethod
    def generate(
        cls,
        messages: list,
        temperature: float = 0.9,
        top_p: float = 0.9,
        n: int = 1,
        stream: bool = False,
        max_tokens: int = 128,
        stop: str = "",
        **kwargs,
    ) -> List:
        """Generate completions for a conversation.

        Args:
            messages: Conversation messages.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            n: Number of completions requested.
            stream: Whether the caller asked for a streamed response.
            max_tokens: Token budget for the completion.
            stop: Stop sequence(s); may be empty.
            **kwargs: Implementation-specific options.

        Returns:
            A list of generated completions.
        """

    @classmethod
    @abstractmethod
    def embeddings(cls, text) -> object:
        """Compute embeddings for ``text``.

        The return type is implementation-defined.
        """
|
||
|
||
class XGenBasedModel(ChatModel):
    """Chat back-end that wraps a ``transformers`` pipeline in class-level state.

    The pipeline and tokenizer are created once by ``get_model`` and cached on
    the class; ``generate`` runs the pipeline on the last message's content.
    """

    # Pipeline object created by get_model(); None until the first load.
    model = None
    # Tokenizer created together with the pipeline; None until the first load.
    tokenizer = None
    # Not read or written by any method of this class.
    stopping_criteria = None

    @staticmethod
    def _strip_trailing_eos(text: str, eos_token) -> str:
        """Remove whole trailing occurrences of ``eos_token`` from ``text``."""
        # str.rstrip() treats its argument as a *set of characters*, so
        # rstrip("<|endoftext|>") would also eat ordinary trailing letters
        # such as "t", "e" or "d". Compare against the full suffix instead.
        if not eos_token:
            return text
        while text.endswith(eos_token):
            text = text[: -len(eos_token)]
        return text

    @classmethod
    def generate(
        cls,
        messages: list,
        temperature: float = 0.9,
        top_p: float = 0.9,
        n: int = 1,
        stream: bool = False,
        max_tokens: int = 128,
        stop: str = "",
        **kwargs,
    ) -> List:
        """Generate ``n`` completions for the last message in ``messages``.

        Args:
            messages: Conversation as a list of dicts; only the ``"content"``
                of the final entry is used as the prompt.
            temperature: Sampling temperature forwarded to the pipeline.
            top_p: Nucleus-sampling value forwarded to the pipeline.
            n: Number of sequences to return.
            stream: Accepted for interface compatibility; not used here.
            max_tokens: Maximum number of newly generated tokens.
            stop: A stop string, or a list of stop strings of which only the
                first is forwarded to the pipeline. Empty/None disables it.
            **kwargs: Optional ``return_full_text`` (default ``False``) and
                ``do_sample`` (default ``True``) overrides.

        Returns:
            A list with one generated text per returned sequence, with any
            trailing end-of-sequence token removed.

        Raises:
            ValueError: If ``messages`` is empty.
        """
        if not messages:
            raise ValueError("messages must contain at least one message")

        # Load lazily so generate() also works if the startup hook did not run.
        text_pipeline = cls.get_model()
        prompt = messages[-1]["content"]

        # Indexing a plain string with [0] would yield its first character,
        # so distinguish a single stop string from a list of them.
        if isinstance(stop, str):
            stop_sequence = stop or None
        else:
            stop_sequence = stop[0] if stop else None

        outputs = text_pipeline(
            prompt,
            # max_length would count the prompt tokens as well; the request's
            # max_tokens is a budget for the completion only.
            max_new_tokens=max_tokens,
            num_return_sequences=n,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=cls.tokenizer.eos_token_id,
            return_full_text=kwargs.get("return_full_text", False),
            do_sample=kwargs.get("do_sample", True),
            stop_sequence=stop_sequence,
        )
        eos_token = cls.tokenizer.eos_token
        # Return every requested sequence, not just the first one.
        return [
            cls._strip_trailing_eos(output["generated_text"], eos_token)
            for output in outputs
        ]

    @classmethod
    def get_model(cls) -> AutoModelForCausalLM:
        """Create the tokenizer and pipeline on first use and return the pipeline.

        The model id comes from the ``MODEL_ID`` environment variable
        (default ``"Salesforce/xgen-7b-8k-inst"``) and the device map from
        ``DEVICE`` (default ``"auto"``).

        NOTE(review): the return annotation says ``AutoModelForCausalLM`` but
        the cached value is whatever ``pipeline(...)`` returns.
        NOTE(review): ``trust_remote_code=True`` runs code shipped with the
        model repository -- make sure ``MODEL_ID`` only points at trusted
        sources.
        """
        if cls.model is None:
            # Read the id once so tokenizer and pipeline always agree.
            model_id = os.getenv("MODEL_ID", "Salesforce/xgen-7b-8k-inst")
            cls.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
            )
            cls.model = pipeline(
                tokenizer=cls.tokenizer,
                model=model_id,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=os.getenv("DEVICE", "auto"),
            )
        return cls.model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import json
import os
import uuid
from datetime import datetime as dt
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from models import XGenBasedModel as model
from pydantic import BaseModel, Field
|
||
|
||
class ChatCompletionInput(BaseModel):
    # Request body of POST /chat/completions.
    # The handler forwards messages, temperature, top_p, n, stream, max_tokens
    # and stop to the model; `model` is echoed in the streamed payload. The
    # remaining fields are accepted but not read by the handler in this file.
    model: str
    messages: List[dict]
    temperature: float = 0.7
    top_p: float = 0.75
    n: int = 1
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = ["User:"]
    max_tokens: int = 64
    presence_penalty: float = 0.0
    # NOTE(review): spelled "frequence_penalty"; confirm whether
    # "frequency_penalty" was intended before renaming a request field.
    frequence_penalty: float = 0.0
    logit_bias: Optional[dict] = {}
    user: str = ""
|
||
|
||
class ChatCompletionResponse(BaseModel):
    # Response body of a non-streaming chat completion.
    #
    # `id` and `created` use default factories so that each instance gets a
    # fresh value. A plain default expression is evaluated once, when the
    # class is defined, which would give every response the same id and the
    # module import time as its timestamp (and the id default would be a UUID
    # object rather than the declared str).
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    model: str
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(dt.now().timestamp()))
    choices: List[dict]
    usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||
|
||
class HealthResponse(BaseModel):
    # Body returned by the health endpoint: a single boolean flag.
    status: bool
|
||
|
||
# Router collecting the endpoints declared below in this module.
router = APIRouter()
|
||
|
||
@router.get("/", response_model=HealthResponse)
async def health() -> HealthResponse:
    # Health check: unconditionally reports status=True.
    healthy = HealthResponse(status=True)
    return healthy
|
||
|
||
async def generate_chunk_based_response(body, text) -> AsyncGenerator[str, None]:
    """Yield a finished completion as two server-sent-event frames.

    The first frame is an ``event: completion`` whose data is a JSON payload
    carrying ``text`` as the assistant delta; the second is the terminating
    ``event: done`` frame.

    Args:
        body: Request object; only its ``model`` attribute is read.
        text: The generated completion text to send.

    Yields:
        SSE-formatted strings, each terminated by a blank line.
    """
    payload = {
        "id": str(uuid.uuid4()),
        "model": body.model,
        "object": "chat.completion",
        # Included for parity with the non-streaming response body.
        "created": int(dt.now().timestamp()),
        "choices": [
            {
                "role": "assistant",
                # Choice indices are 0-based, matching the enumerate()-based
                # indices of the non-streaming response.
                "index": 0,
                "delta": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
    yield "event: completion\ndata: " + json.dumps(payload) + "\n\n"
    yield "event: done\ndata: [DONE]\n\n"
|
||
|
||
@router.post("/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(
    body: ChatCompletionInput,
) -> Union[ChatCompletionResponse, StreamingResponse]:
    # Chat-completion endpoint.
    #
    # Runs the model on the request and returns either a JSON
    # ChatCompletionResponse or, when body.stream is set, an SSE
    # StreamingResponse carrying the first prediction.
    #
    # Responds with HTTP 400 when the request has no usable last message or
    # when generation raises ValueError.
    #
    # NOTE(review): model.generate is a synchronous call inside an async
    # handler, so the event loop is blocked for the duration of generation.

    # Reject requests the model cannot use up front; otherwise the missing
    # message/key would surface as an unhandled IndexError/KeyError (HTTP 500).
    if not body.messages or "content" not in body.messages[-1]:
        raise HTTPException(
            status_code=400,
            detail={
                "message": "messages must be a non-empty list whose last item has a 'content' key"
            },
        )

    # Only the generation call is guarded, so unrelated failures are not
    # reported to the client as a bad request.
    try:
        predictions = model.generate(
            messages=body.messages,
            temperature=body.temperature,
            top_p=body.top_p,
            n=body.n,
            stream=body.stream,
            max_tokens=body.max_tokens,
            stop=body.stop,
        )
    except ValueError as error:
        raise HTTPException(
            status_code=400,
            detail={"message": str(error)},
        ) from error

    if body.stream:
        return StreamingResponse(
            generate_chunk_based_response(body, predictions[0]),
            media_type="text/event-stream",
        )

    return ChatCompletionResponse(
        id=str(uuid.uuid4()),
        model=os.getenv("MODEL_ID", "Salesforce/xgen-7b-8k-inst"),
        object="chat.completion",
        created=int(dt.now().timestamp()),
        choices=[
            {
                "role": "assistant",
                "index": idx,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
            for idx, text in enumerate(predictions)
        ],
        usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    )