diff --git a/cht-llama-cpp/main.py b/cht-llama-cpp/main.py
index 2059423..6978eb5 100644
--- a/cht-llama-cpp/main.py
+++ b/cht-llama-cpp/main.py
@@ -10,7 +10,7 @@
 
 load_dotenv()
 
-MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'mistral-7b-instruct-v0.1.Q5_0')}.gguf"
+MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'yarn-mistral-7b-128k.Q4_K_M')}.gguf"
 # Mistral gguf follows ChatML syntax
 # https://github.com/openai/openai-python/blob/main/chatml.md
 PROMPT_TEMPLATE_STRING = '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are an helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}'  # noqa
@@ -19,8 +19,10 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--model-path", help="Path to GGUF", default=MODEL_PATH)
 parser.add_argument("--port", help="Port to run model server on", type=int, default=8000)
+parser.add_argument("--ctx", help="Context dimension", type=int, default=4096)
 args = parser.parse_args()
 MODEL_PATH = args.model_path
+MODEL_CTX = args.ctx
 
 logging.basicConfig(
     format="%(asctime)s %(levelname)-8s %(message)s",
@@ -33,7 +35,7 @@ def create_start_app_handler(app: FastAPI):
     def start_app() -> None:
         from models import LLaMACPPBasedModel
 
-        LLaMACPPBasedModel.get_model(MODEL_PATH, PROMPT_TEMPLATE_STRING)
+        LLaMACPPBasedModel.get_model(MODEL_PATH, PROMPT_TEMPLATE_STRING, MODEL_CTX)
 
     return start_app
 
diff --git a/cht-llama-cpp/models.py b/cht-llama-cpp/models.py
index 2aec3cb..1ecb899 100644
--- a/cht-llama-cpp/models.py
+++ b/cht-llama-cpp/models.py
@@ -26,20 +26,6 @@ class LLaMACPPBasedModel(object):
     def tokenize(cls, prompt):
         return cls.model.tokenize(b" " + prompt.encode("utf-8"))
 
-    @classmethod
-    def reduce_number_of_messages(cls, messages, max_tokens):
-        buffer_tokens = 32
-        ctx_max_tokens = 4096
-        num_messages = len(messages)
-
-        tokens = [len(cls.tokenize(doc["content"])) for doc in messages]
-
-        token_count = sum(tokens[:num_messages])
-        while token_count + max_tokens + buffer_tokens > ctx_max_tokens:
-            num_messages -= 1
-            token_count -= tokens[num_messages]
-        return messages[:num_messages]
-
     @classmethod
     def generate(
         cls,
@@ -55,7 +41,6 @@
     ):
         if stop is None:
             stop = []
-        messages = cls.reduce_number_of_messages(messages[::-1], max_tokens)[::-1]
         cls.model.n_threads = n_threads
         cht_resp = cls.model.create_chat_completion(
             messages,
@@ -75,13 +60,13 @@
         return cht_resp
 
     @classmethod
-    def get_model(cls, model_path, prompt_template_jsonstr):
+    def get_model(cls, model_path, prompt_template_jsonstr, n_ctx):
         chat_format = "llama-2"
         if "mistral" in model_path:
             cls.PROMPT_TEMPLATE = json.loads(prompt_template_jsonstr)
             chat_format = cls.PROMPT_TEMPLATE.get("template_format", "chatml")
         if cls.model is None:
-            cls.model = Llama(model_path, chat_format=chat_format)
+            cls.model = Llama(model_path, chat_format=chat_format, n_ctx=n_ctx)
 
         return cls.model
 
diff --git a/cht-llama-cpp/requirements.txt b/cht-llama-cpp/requirements.txt
index df5894e..cf7f103 100644
--- a/cht-llama-cpp/requirements.txt
+++ b/cht-llama-cpp/requirements.txt
@@ -6,4 +6,4 @@ tqdm==4.65.0
 httpx==0.23.3
 python-dotenv==1.0.0
 tenacity==8.2.2
-llama-cpp-python==0.2.11
+llama-cpp-python==0.2.14
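
Note on the patch: with reduce_number_of_messages() removed, message trimming is no longer done in Python; the usable context is whatever n_ctx the Llama instance is created with (default 4096 via --ctx). A minimal sketch of the resulting load path follows, using illustrative values only (the service itself takes them from --model-path and --ctx in main.py):

    from llama_cpp import Llama

    # Illustrative values; not part of the patch.
    model_path = "./ml/models/yarn-mistral-7b-128k.Q4_K_M.gguf"
    n_ctx = 4096  # raise this to use more of the 128k YaRN context, memory permitting

    # Mirrors the patched LLaMACPPBasedModel.get_model(): Mistral GGUFs use the
    # ChatML chat format, and the context size is now fixed at load time.
    model = Llama(model_path, chat_format="chatml", n_ctx=n_ctx)
    resp = model.create_chat_completion(
        [{"role": "user", "content": "Hello!"}],
        max_tokens=64,
    )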