diff --git a/cht-llama-v2/build.sh b/cht-llama-v2/build.sh
index 523c5f5..935850a 100755
--- a/cht-llama-v2/build.sh
+++ b/cht-llama-v2/build.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 set -e
-export VERSION=1.0.0
+export VERSION=1.0.1
 
 source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh"
 build_gpu ghcr.io/premai-io/chat-llama-2-7b-gpu llama-2-7b-hf ${@:1}
diff --git a/cht-llama-v2/models.py b/cht-llama-v2/models.py
index d7cc2b7..4fd6eca 100644
--- a/cht-llama-v2/models.py
+++ b/cht-llama-v2/models.py
@@ -31,6 +31,10 @@ def generate(
     def embeddings(cls, text) -> None:
         pass
 
+    @abstractmethod
+    def stitch_prompt(messages: list) -> str:
+        pass
+
 
 class LlamaBasedModel(ChatModel):
     model = None
@@ -49,10 +53,10 @@ def generate(
         stop: str = "",
         **kwargs,
     ) -> List:
-        message = messages[-1]["content"]
+        prompt = cls.stitch_prompt(messages)
         return [
             cls.model(
-                message,
+                prompt,
                 max_length=max_tokens,
                 max_new_tokens=max_new_tokens,
                 num_return_sequences=n,
@@ -62,8 +66,11 @@ def generate(
                 return_full_text=kwargs.get("return_full_text", False),
                 do_sample=kwargs.get("do_sample", True),
                 stop_sequence=stop[0] if stop else None,
-                stopping_criteria=cls.stopping_criteria(stop, message, cls.tokenizer),
-            )[0]["generated_text"].rstrip(stop[0] if stop else "")
+                stopping_criteria=cls.stopping_criteria(stop, prompt, cls.tokenizer),
+            )[0]["generated_text"]
+            .rstrip(stop[0] if stop else "")
+            .rsplit(".", 1)[0]
+            .strip()
         ]
 
     @classmethod
@@ -87,3 +94,26 @@ def get_model(cls) -> Pipeline:
             )
         cls.stopping_criteria = LlamaStoppingCriteria
         return cls.model
+
+    @staticmethod
+    def stitch_prompt(messages: list) -> str:
+        system_prompt_template = "[INST] <<SYS>>\n{}\n<</SYS>>\n\n"  # noqa
+        default_system_text = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."  # noqa
+        user_prompt_template = "{} [/INST] "  # noqa
+        assistant_prompt_template = "{} [INST] "  # noqa
+
+        system_prompt, chat_prompt = "", ""
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+            if role == "system":
+                system_prompt = system_prompt_template.format(content)
+            elif role == "user":
+                chat_prompt += user_prompt_template.format(content)
+            elif role == "assistant":
+                chat_prompt += assistant_prompt_template.format(content)
+
+        if not system_prompt:
+            system_prompt = system_prompt_template.format(default_system_text)
+
+        return system_prompt + chat_prompt
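
Note for reviewers: below is a minimal, standalone sketch of the prompt stitching this patch introduces. It mirrors the new stitch_prompt rather than importing cht-llama-v2/models.py, and the <<SYS>> markers follow the standard Llama-2 chat template; the sample messages are illustrative only.

# Standalone mirror of the stitch_prompt logic added in this patch.
SYSTEM_TEMPLATE = "[INST] <<SYS>>\n{}\n<</SYS>>\n\n"
USER_TEMPLATE = "{} [/INST] "
ASSISTANT_TEMPLATE = "{} [INST] "


def stitch_prompt(messages: list) -> str:
    # The system turn fills the <<SYS>> block; user and assistant turns
    # alternate around [INST] ... [/INST] markers, per the Llama-2 format.
    system_prompt, chat_prompt = "", ""
    for message in messages:
        role, content = message["role"], message["content"]
        if role == "system":
            system_prompt = SYSTEM_TEMPLATE.format(content)
        elif role == "user":
            chat_prompt += USER_TEMPLATE.format(content)
        elif role == "assistant":
            chat_prompt += ASSISTANT_TEMPLATE.format(content)
    return system_prompt + chat_prompt


print(stitch_prompt([
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello."},
    {"role": "user", "content": "What is Llama 2?"},
]))
# Output:
# [INST] <<SYS>>
# You are concise.
# <</SYS>>
#
# Hi! [/INST] Hello. [INST] What is Llama 2? [/INST]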