This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Fixed general errors about Mistral model #147

Merged 1 commit on Nov 8, 2023
cht-llama-cpp/main.py (6 changes: 4 additions & 2 deletions)
@@ -10,7 +10,7 @@

load_dotenv()

MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'mistral-7b-instruct-v0.1.Q5_0')}.gguf"
MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'yarn-mistral-7b-128k.Q4_K_M')}.gguf"
casperdcl marked this conversation as resolved.
Show resolved Hide resolved
# Mistral gguf follows ChatML syntax
# https://github.com/openai/openai-python/blob/main/chatml.md
PROMPT_TEMPLATE_STRING = '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are an helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}' # noqa
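For readers unfamiliar with ChatML: the JSON string above is just a set of `str.format` templates plus metadata. Below is a minimal sketch, not code from this repo, of how the template might expand into an actual prompt; the `render_chatml` helper is hypothetical.

```python
import json

# The same template JSON defined above in main.py, abridged to the fields used here.
PROMPT_TEMPLATE_STRING = (
    '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", '
    '"default_system_text": "You are an helpful AI assistant.", '
    '"user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", '
    '"request_assistant_response_token": "<|im_start|>assistant\\n"}'
)
template = json.loads(PROMPT_TEMPLATE_STRING)


def render_chatml(user_text, system_text=None):
    """Hypothetical helper: assemble a single-turn ChatML prompt from the template fields."""
    system = system_text if system_text is not None else template["default_system_text"]
    return (
        template["system_prompt_template"].format(system)
        + template["user_prompt_template"].format(user_text)
        + template["request_assistant_response_token"]
    )


print(render_chatml("What changed in this PR?"))
# <|im_start|>system
# You are an helpful AI assistant.
# <|im_end|>
# <|im_start|>user
# What changed in this PR?
# <|im_end|>
# <|im_start|>assistant
```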
@@ -19,8 +19,10 @@
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", help="Path to GGUF", default=MODEL_PATH)
parser.add_argument("--port", help="Port to run model server on", type=int, default=8000)
parser.add_argument("--ctx", help="Context dimension", type=int, default=4096)
args = parser.parse_args()
MODEL_PATH = args.model_path
+MODEL_CTX = args.ctx

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
@@ -33,7 +35,7 @@ def create_start_app_handler(app: FastAPI):
    def start_app() -> None:
        from models import LLaMACPPBasedModel

-        LLaMACPPBasedModel.get_model(MODEL_PATH, PROMPT_TEMPLATE_STRING)
+        LLaMACPPBasedModel.get_model(MODEL_PATH, PROMPT_TEMPLATE_STRING, MODEL_CTX)

    return start_app

cht-llama-cpp/models.py (19 changes: 2 additions & 17 deletions)
@@ -26,20 +26,6 @@ class LLaMACPPBasedModel(object):
    def tokenize(cls, prompt):
        return cls.model.tokenize(b" " + prompt.encode("utf-8"))

-    @classmethod
-    def reduce_number_of_messages(cls, messages, max_tokens):
-        buffer_tokens = 32
-        ctx_max_tokens = 4096
-        num_messages = len(messages)
-
-        tokens = [len(cls.tokenize(doc["content"])) for doc in messages]
-
-        token_count = sum(tokens[:num_messages])
-        while token_count + max_tokens + buffer_tokens > ctx_max_tokens:
-            num_messages -= 1
-            token_count -= tokens[num_messages]
-        return messages[:num_messages]

    @classmethod
    def generate(
        cls,
@@ -55,7 +41,6 @@ def generate(
    ):
        if stop is None:
            stop = []
-        messages = cls.reduce_number_of_messages(messages[::-1], max_tokens)[::-1]
        cls.model.n_threads = n_threads
        cht_resp = cls.model.create_chat_completion(
            messages,
@@ -75,13 +60,13 @@
        return cht_resp

    @classmethod
-    def get_model(cls, model_path, prompt_template_jsonstr):
+    def get_model(cls, model_path, prompt_template_jsonstr, n_ctx):
        chat_format = "llama-2"
        if "mistral" in model_path:
            cls.PROMPT_TEMPLATE = json.loads(prompt_template_jsonstr)
            chat_format = cls.PROMPT_TEMPLATE.get("template_format", "chatml")
        if cls.model is None:
-            cls.model = Llama(model_path, chat_format=chat_format)
+            cls.model = Llama(model_path, chat_format=chat_format, n_ctx=n_ctx)

        return cls.model
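To make the effect of the new `n_ctx` parameter concrete, here is a minimal sketch of roughly what the patched `get_model()` constructs for a Mistral-family path, assuming `llama-cpp-python >= 0.2.14` is installed and the GGUF file exists locally; the 16384 value is an arbitrary example:

```python
from llama_cpp import Llama

# Roughly what get_model() builds for a "mistral" model path after this PR.
llm = Llama(
    "./ml/models/yarn-mistral-7b-128k.Q4_K_M.gguf",
    chat_format="chatml",  # Mistral GGUFs follow ChatML, per the comment in main.py
    n_ctx=16384,           # previously fixed at 4096; now chosen by the caller via --ctx
)

resp = llm.create_chat_completion(
    [{"role": "user", "content": "Say hello in one short sentence."}],
    max_tokens=32,
)
print(resp["choices"][0]["message"]["content"])
```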

cht-llama-cpp/requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -6,4 +6,4 @@ tqdm==4.65.0
httpx==0.23.3
python-dotenv==1.0.0
tenacity==8.2.2
-llama-cpp-python==0.2.11
+llama-cpp-python==0.2.14