diff --git a/cht-llama-cpp/models.py b/cht-llama-cpp/models.py
index 28f4f18..8a5554b 100644
--- a/cht-llama-cpp/models.py
+++ b/cht-llama-cpp/models.py
@@ -43,9 +43,12 @@ def generate(
         stream: bool = False,
         max_tokens: int = 256,
         stop: list = [],
+        n_threads: int = None,
         **kwargs,
     ):
         messages = cls.reduce_number_of_messages(messages[::-1], max_tokens)[::-1]
+        if n_threads is not None:
+            cls.model.n_threads = n_threads
         return cls.model.create_chat_completion(
             messages,
             temperature=temperature,
diff --git a/cht-llama-cpp/routes.py b/cht-llama-cpp/routes.py
index ec42b24..f7d301f 100644
--- a/cht-llama-cpp/routes.py
+++ b/cht-llama-cpp/routes.py
@@ -1,6 +1,7 @@
 import json
 import uuid
 from datetime import datetime as dt
+from typing import List, Optional, Union
 
 from fastapi import APIRouter, HTTPException
 from fastapi.responses import StreamingResponse
@@ -15,12 +16,13 @@ class ChatCompletionInput(BaseModel):
     top_p: float = 0.95
     n: int = 1
     stream: bool = False
-    stop: str | list | None = []
+    stop: Optional[Union[str, List[str]]] = []
     max_tokens: int = 256
     presence_penalty: float = 0.0
     frequence_penalty: float = 0.0
-    logit_bias: dict | None = {}
+    logit_bias: Optional[dict] = {}
     user: str = ""
+    n_threads: int = None
 
 
 class ChatCompletionResponse(BaseModel):
@@ -80,6 +82,7 @@ async def generate_chunk_based_response(body):
         presence_penalty=body.presence_penalty,
         frequence_penalty=body.frequence_penalty,
         logit_bias=body.logit_bias,
+        n_threads=body.n_threads,
     )
     for chunk in chunks:
         yield f"event: completion\ndata: {json.dumps(chunk)}\n\n"
@@ -104,6 +107,7 @@ async def chat_completions(body: ChatCompletionInput):
             presence_penalty=body.presence_penalty,
             frequence_penalty=body.frequence_penalty,
             logit_bias=body.logit_bias,
+            n_threads=body.n_threads,
         )
     except ValueError as error:
         raise HTTPException(
diff --git a/cht-llama-cpp/tests/test_views.py b/cht-llama-cpp/tests/test_views.py
index 13a4086..7ad0a9e 100644
--- a/cht-llama-cpp/tests/test_views.py
+++ b/cht-llama-cpp/tests/test_views.py
@@ -10,6 +10,7 @@ def test_chat_llama_cpp() -> None:
         json={
             "model": "vicuna-7b-q4",
             "messages": [{"role": "user", "content": "Hello!"}],
+            "n_threads": 10,
         },
     )
     assert response.status_code == 200
diff --git a/dfs-diffusers/docker/gpu/Dockerfile b/dfs-diffusers/docker/gpu/Dockerfile
index 45f8295..f190d5e 100644
--- a/dfs-diffusers/docker/gpu/Dockerfile
+++ b/dfs-diffusers/docker/gpu/Dockerfile
@@ -15,6 +15,6 @@ RUN python3 download.py --model $MODEL_ID
 COPY . .
 
 ENV MODEL_ID=$MODEL_ID
-ENV DEVICE=gpu
+ENV DEVICE=cuda
 
 CMD python3 main.py
diff --git a/scripts/cht_llama_cpp.sh b/scripts/cht_llama_cpp.sh
index 4ce5a0f..d9f5483 100644
--- a/scripts/cht_llama_cpp.sh
+++ b/scripts/cht_llama_cpp.sh
@@ -2,7 +2,7 @@
 
 set -e
 
-export VERSION=1.0.2
+export VERSION=1.0.3
 
 docker buildx build --push \
   --cache-from ghcr.io/premai-io/chat-gpt4all-lora-q4-cpu:latest \
diff --git a/scripts/dfs_diffusers.sh b/scripts/dfs_diffusers.sh
index 30e96bb..f9fc5aa 100644
--- a/scripts/dfs_diffusers.sh
+++ b/scripts/dfs_diffusers.sh
@@ -2,7 +2,7 @@
 
 set -e
 
-export VERSION=1.0.0
+export VERSION=1.0.1
 
 docker buildx build --push \
   --cache-from=ghcr.io/premai-io/diffuser-stable-diffusion-2-1-base-gpu:latest \
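
For reference, a minimal client-side sketch of how the new n_threads field could be exercised once this change ships. It mirrors the payload used in the updated cht-llama-cpp/tests/test_views.py; the host, port, and the OpenAI-style /v1/chat/completions route are assumptions, since the diff does not show how the service is exposed.

# Minimal sketch, assuming the service listens on localhost:8000 and exposes an
# OpenAI-style /v1/chat/completions route (neither is shown in this diff).
import requests

payload = {
    "model": "vicuna-7b-q4",
    "messages": [{"role": "user", "content": "Hello!"}],
    # New optional field added by this change; omit it to keep the model's default thread count.
    "n_threads": 10,
}

response = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=600)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])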