From 4b2c2951e81bae106981b8a4fca5f94a2c19335b Mon Sep 17 00:00:00 2001
From: Filippo Pedrazzini
Date: Wed, 8 Nov 2023 19:03:40 +0100
Subject: [PATCH] Pass the prompt template as a CLI argument

---
 cht-llama-cpp/main.py   | 8 +++++++-
 cht-llama-cpp/models.py | 2 +-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/cht-llama-cpp/main.py b/cht-llama-cpp/main.py
index 6978eb5..6a5d41b 100644
--- a/cht-llama-cpp/main.py
+++ b/cht-llama-cpp/main.py
@@ -13,16 +13,22 @@
 MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'yarn-mistral-7b-128k.Q4_K_M')}.gguf"
 # Mistral gguf follows ChatML syntax
 # https://github.com/openai/openai-python/blob/main/chatml.md
-PROMPT_TEMPLATE_STRING = '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are an helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}'  # noqa

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model-path", help="Path to GGUF", default=MODEL_PATH)
     parser.add_argument("--port", help="Port to run model server on", type=int, default=8000)
     parser.add_argument("--ctx", help="Context dimension", type=int, default=4096)
+    parser.add_argument(
+        "--prompt_template",
+        help="Prompt template as a JSON string",
+        type=str,
+        default='{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are a helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}',  # noqa
+    )
     args = parser.parse_args()
     MODEL_PATH = args.model_path
     MODEL_CTX = args.ctx
+    PROMPT_TEMPLATE_STRING = args.prompt_template

     logging.basicConfig(
         format="%(asctime)s %(levelname)-8s %(message)s",

diff --git a/cht-llama-cpp/models.py b/cht-llama-cpp/models.py
index 1ecb899..af4b944 100644
--- a/cht-llama-cpp/models.py
+++ b/cht-llama-cpp/models.py
@@ -62,7 +62,7 @@ def generate(
     @classmethod
     def get_model(cls, model_path, prompt_template_jsonstr, n_ctx):
         chat_format = "llama-2"
-        if "mistral" in model_path:
+        if prompt_template_jsonstr != "" and "mistral" in model_path:
             cls.PROMPT_TEMPLATE = json.loads(prompt_template_jsonstr)
             chat_format = cls.PROMPT_TEMPLATE.get("template_format", "chatml")
         if cls.model is None:
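
As a sanity check of the template JSON introduced above, the sketch below parses the default --prompt_template value and renders an OpenAI-style message list into a single ChatML prompt string. The render_prompt helper is hypothetical (the code that actually consumes PROMPT_TEMPLATE in models.py is not part of this patch); it only illustrates, under that assumption, how the template fields fit together:

import json

# Default value of --prompt_template, copied from main.py above.
PROMPT_TEMPLATE_STRING = (
    '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", '
    '"default_system_text": "You are a helpful AI assistant.", '
    '"user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", '
    '"assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", '
    '"request_assistant_response_token": "<|im_start|>assistant\\n", '
    '"template_format": "chatml"}'
)


def render_prompt(template, messages):
    # Hypothetical helper: join system, user, and assistant turns into one
    # ChatML prompt using the {}-style templates parsed from the JSON.
    parts = [template["system_prompt_template"].format(template["default_system_text"])]
    for message in messages:
        if message["role"] == "user":
            parts.append(template["user_prompt_template"].format(message["content"]))
        elif message["role"] == "assistant":
            parts.append(template["assistant_prompt_template"].format(message["content"]))
    # End with the token that asks the model to start its reply.
    parts.append(template["request_assistant_response_token"])
    return "".join(parts)


template = json.loads(PROMPT_TEMPLATE_STRING)
print(render_prompt(template, [{"role": "user", "content": "Hello!"}]))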
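
Design note on the models.py change: the tightened guard parses the custom template only when a non-empty --prompt_template string is supplied and the model path contains "mistral"; in every other case chat_format keeps its "llama-2" default and json.loads is never reached, so an empty template string can no longer raise inside get_model. A hypothetical invocation (model path taken from the default above, <template-json> standing in for a JSON document shaped like the default) could look like:

python main.py --model-path ./ml/models/yarn-mistral-7b-128k.Q4_K_M.gguf --prompt_template '<template-json>'

Note that on the command line the newline escapes are written as \n (the JSON escape), whereas the default in main.py needs \\n because it sits inside a Python string literal.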