Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

37: Improve Dolly in order to accept parameters as Falcon does #47

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions cht-dolly-v2/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import os
from abc import ABC, abstractmethod
from typing import List

import torch
from transformers import pipeline
from transformers import logging, pipeline

logging.set_verbosity_error()

class ChatModel:
@classmethod

class ChatModel(ABC):
@abstractmethod
def get_model(cls):
pass

@classmethod
@abstractmethod
def generate(
cls,
messages: list,
Expand All @@ -23,13 +27,14 @@ def generate(
):
pass

@classmethod
@abstractmethod
def embeddings(cls, text):
pass


class DollyBasedModel(ChatModel):
model = None
tokenizer = None

@classmethod
def generate(
Expand All @@ -42,9 +47,20 @@ def generate(
max_tokens: int = 128,
stop: str = "",
**kwargs,
):
) -> List:
message = messages[-1]["content"]
return [cls.model(message)[0]["generated_text"]]
return [
cls.model(
message,
max_length=max_tokens,
temperature=temperature,
top_p=top_p,
num_return_sequences=n,
return_full_text=kwargs.get("return_full_text", False),
do_sample=kwargs.get("do_sample", True),
stop_sequence=stop[0] if stop else None,
)[0]["generated_text"]
]

@classmethod
def get_model(cls):
Expand Down
2 changes: 1 addition & 1 deletion scripts/cht_dolly_v2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

set -e

export VERSION=1.0.2
export VERSION=1.0.3

docker buildx build --push \
--cache-from ghcr.io/premai-io/chat-dolly-v2-12b-gpu:latest \
Expand Down