This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit

add: endpoints and models for xgen
biswaroop1547 committed Jul 3, 2023
1 parent 84f7d52 commit a3d9f62
Showing 3 changed files with 229 additions and 0 deletions.
45 changes: 45 additions & 0 deletions cht-xgen/main.py
@@ -0,0 +1,45 @@
import logging
from typing import Callable

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from models import XGenBasedModel
from routes import router as api_router

load_dotenv()

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


def create_start_app_handler(app: FastAPI) -> Callable[[], None]:
    def start_app() -> None:
        # Preload the model at startup so the first request does not pay the load cost.
        XGenBasedModel.get_model()

    return start_app


def get_application() -> FastAPI:
    application = FastAPI(title="prem-chat", debug=True, version="0.0.1")
    application.include_router(api_router, prefix="/v1")
    application.add_event_handler("startup", create_start_app_handler(application))
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return application


app = get_application()


if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000)
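
Since main.py calls load_dotenv(), configuration can be supplied through a local .env file. A minimal example using only the variables this commit actually reads (both optional; the values shown are the defaults from models.py):

MODEL_ID=Salesforce/xgen-7b-8k-inst
DEVICE=auto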
78 changes: 78 additions & 0 deletions cht-xgen/models.py
@@ -0,0 +1,78 @@
import os
from abc import ABC, abstractmethod
from typing import List, Union

import torch
from transformers import AutoTokenizer, Pipeline, pipeline


class ChatModel(ABC):
    @classmethod
    @abstractmethod
    def get_model(cls) -> None:
        pass

    @classmethod
    @abstractmethod
    def generate(
        cls,
        messages: list,
        temperature: float = 0.9,
        top_p: float = 0.9,
        n: int = 1,
        stream: bool = False,
        max_tokens: int = 128,
        stop: Union[str, List[str], None] = "",
        **kwargs,
    ) -> None:
        pass

    @classmethod
    @abstractmethod
    def embeddings(cls, text) -> None:
        pass


class XGenBasedModel(ChatModel):
    model = None
    tokenizer = None
    stopping_criteria = None

    @classmethod
    def generate(
        cls,
        messages: list,
        temperature: float = 0.9,
        top_p: float = 0.9,
        n: int = 1,
        stream: bool = False,
        max_tokens: int = 128,
        stop: Union[str, List[str], None] = "",
        **kwargs,
    ) -> List:
        # Only the latest message is used as the prompt; earlier turns are ignored.
        message = messages[-1]["content"]
        # The route may pass one stop string or a list of them; the pipeline's
        # stop_sequence parameter takes a single string.
        stop_sequence = stop[0] if isinstance(stop, list) and stop else (stop or None)
        outputs = cls.model(
            message,
            max_length=max_tokens,  # caps total length, prompt included
            num_return_sequences=n,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=cls.tokenizer.eos_token_id,
            return_full_text=kwargs.get("return_full_text", False),
            do_sample=kwargs.get("do_sample", True),
            stop_sequence=stop_sequence,
        )
        # removesuffix (not rstrip) so only a trailing EOS token is dropped, not
        # every character that happens to occur in it; return all n candidates.
        eos = cls.tokenizer.eos_token or ""
        return [out["generated_text"].removesuffix(eos) for out in outputs]

    @classmethod
    def get_model(cls) -> Pipeline:
        # Load the tokenizer and text-generation pipeline once and cache them on the class.
        if cls.model is None:
            model_id = os.getenv("MODEL_ID", "Salesforce/xgen-7b-8k-inst")
            cls.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
            )
            cls.model = pipeline(
                "text-generation",
                tokenizer=cls.tokenizer,
                model=model_id,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=os.getenv("DEVICE", "auto"),
            )
        return cls.model
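
For reference, a minimal sketch of driving XGenBasedModel directly, outside the API layer; it assumes enough GPU/CPU memory is available to load the 7B checkpoint:

from models import XGenBasedModel

XGenBasedModel.get_model()  # loads and caches the tokenizer and pipeline
replies = XGenBasedModel.generate(
    messages=[{"role": "user", "content": "Summarize XGen in one sentence."}],
    max_tokens=64,
    stop=["User:"],
)
print(replies[0])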
106 changes: 106 additions & 0 deletions cht-xgen/routes.py
@@ -0,0 +1,106 @@
import json
import os
import uuid
from datetime import datetime as dt
from typing import AsyncGenerator, List, Optional, Union

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from models import XGenBasedModel as model
from pydantic import BaseModel, Field


class ChatCompletionInput(BaseModel):
    model: str
    messages: List[dict]
    temperature: float = 0.7
    top_p: float = 0.75
    n: int = 1
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = ["User:"]
    max_tokens: int = 64
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    logit_bias: Optional[dict] = {}
    user: str = ""


class ChatCompletionResponse(BaseModel):
    # default_factory gives every response a fresh id and timestamp instead of
    # values frozen once at import time
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    model: str
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(dt.now().timestamp()))
    choices: List[dict]
    usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}


class HealthResponse(BaseModel):
    status: bool


router = APIRouter()


@router.get("/", response_model=HealthResponse)
async def health() -> HealthResponse:
    return HealthResponse(status=True)


async def generate_chunk_based_response(
    body: ChatCompletionInput, text: str
) -> AsyncGenerator[str, None]:
    # Emit the whole completion as one server-sent event, then a done event.
    yield "event: completion\ndata: " + json.dumps(
        {
            "id": str(uuid.uuid4()),
            "model": body.model,
            "object": "chat.completion",
            "choices": [
                {
                    "role": "assistant",
                    "index": 1,
                    "delta": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }
    ) + "\n\n"
    yield "event: done\ndata: [DONE]\n\n"


@router.post("/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(
    body: ChatCompletionInput,
) -> Union[ChatCompletionResponse, StreamingResponse]:
    try:
        predictions = model.generate(
            messages=body.messages,
            temperature=body.temperature,
            top_p=body.top_p,
            n=body.n,
            stream=body.stream,
            max_tokens=body.max_tokens,
            stop=body.stop,
        )
        if body.stream:
            return StreamingResponse(
                generate_chunk_based_response(body, predictions[0]),
                media_type="text/event-stream",
            )
        return ChatCompletionResponse(
            id=str(uuid.uuid4()),
            model=os.getenv("MODEL_ID", "Salesforce/xgen-7b-8k-inst"),
            object="chat.completion",
            created=int(dt.now().timestamp()),
            choices=[
                {
                    "role": "assistant",
                    "index": idx,
                    "message": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
                for idx, text in enumerate(predictions)
            ],
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )
    except ValueError as error:
        raise HTTPException(
            status_code=400,
            detail={"message": str(error)},
        ) from error
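
Once the server from main.py is running, a request against the new endpoint might look like this sketch (the requests dependency and the model value are illustrative assumptions, not part of this commit):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "xgen-7b-8k-inst",  # required field; the non-streaming response reports MODEL_ID instead
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64,
    },
)
print(resp.json()["choices"][0]["message"]["content"])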
