This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

fix: stable beluga generation issues (stopping midway)
biswaroop1547 committed Nov 6, 2023
1 parent 4ea6a43 commit 2449ff8
Showing 2 changed files with 12 additions and 12 deletions.
22 changes: 11 additions & 11 deletions cht-petals/models.py
@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 from petals import AutoDistributedModelForCausalLM
@@ -42,12 +42,12 @@ class PetalsBasedModel(ChatModel):
     def generate(
         cls,
         messages: list,
+        stop: Optional[Union[str, List[str]]],
         temperature: float = 0.9,
         top_p: float = 0.9,
         n: int = 1,
         stream: bool = False,
         max_tokens: int = 128,
-        stop: str = "/s>",
         **kwargs,
     ) -> List:
         prompt = cls.stitch_prompt(messages, cls.PROMPT_TEMPLATE)
@@ -61,19 +61,19 @@ def generate(
             top_p=top_p,
             max_new_tokens=max_tokens,
         )
-        outputs = cls.safe_decode(cls.tokenizer, outputs[0, n_input_tokens:], streaming=stream, stop=stop)
+        outputs = cls.safe_decode(cls.tokenizer, outputs[0, n_input_tokens:], stop_tokens=stop)
         return [outputs]
 
     @classmethod
     def generate_streaming(
         cls,
         messages: list,
+        stop: Optional[Union[str, List[str]]],
         temperature: float = 0.9,
         top_p: float = 0.9,
         n: int = 1,
         stream: bool = False,
         max_tokens: int = 128,
-        stop: str = "/s>",
         session=None,
         inputs=None,
         **kwargs,
@@ -95,7 +95,7 @@ def generate_streaming(
         )
         delta = outputs[0, n_input_tokens:].tolist()
         token_count = len(delta)  # noqa
-        outputs = cls.safe_decode(cls.tokenizer, delta, streaming=stream, stop=stop)
+        outputs = cls.safe_decode(cls.tokenizer, delta, stop_tokens=stop)
         if not outputs:
             return None  # end
         outputs = outputs.lstrip() if inputs is not None else outputs
@@ -153,13 +153,13 @@ def stitch_prompt(messages: list, prompt_template: Dict[str, str]) -> str:
     def safe_decode(
         tokenizer: PreTrainedTokenizer,
         outputs: Union[torch.Tensor, List[int]],
-        streaming: bool = False,
-        stop: str = "/s>",
+        stop_tokens: Optional[Union[str, List[str]]],
     ) -> str:
         # Workaround to make SentencePiece .decode() keep leading spaces in a token
         fake_token = tokenizer("^")["input_ids"][0]
         outputs = outputs.tolist() if isinstance(outputs, torch.Tensor) else outputs
-        result = tokenizer.decode([fake_token] + outputs)
-        if streaming:
-            return result.lstrip("<s>").lstrip(stop)
-        return result.lstrip("<s>").rsplit("</s>", 1)[0].rsplit(stop, 1)[0].strip()
+        result = tokenizer.decode([fake_token] + outputs).replace("<s>", "")
+
+        for stop_token in stop_tokens:
+            result = result.split(stop_token)[0]
+        return result
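
For context on the change above: safe_decode now simply truncates the decoded text at the first occurrence of any stop token, replacing the earlier lstrip/rsplit logic (note that Python's str.lstrip(stop) strips a set of characters rather than a literal prefix, so the old path could also remove genuine leading output). Below is a minimal, self-contained sketch of the new truncation behaviour; the helper name truncate_at_stop is illustrative only and does not exist in the repository.

from typing import List

def truncate_at_stop(decoded: str, stop_tokens: List[str]) -> str:
    """Cut decoded model output at the first occurrence of any stop token."""
    result = decoded.replace("<s>", "")
    for stop_token in stop_tokens:
        result = result.split(stop_token)[0]
    return result

# With the new default stop list, a well-formed end-of-sequence marker is removed
# and everything after it is dropped, while text without a marker passes through.
assert truncate_at_stop("<s> Stable Beluga says hi</s> junk", ["</s>", "/s>"]) == " Stable Beluga says hi"
assert truncate_at_stop("No stop marker here", ["</s>", "/s>"]) == "No stop marker here"

Listing "</s>" before the bare "/s>" fallback matters: splitting on "/s>" first would leave a dangling "<" at the end of otherwise clean output.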
2 changes: 1 addition & 1 deletion cht-petals/routes.py
@@ -13,7 +13,7 @@
 class ChatCompletionInput(BaseModel):
     model: str
     messages: List[dict]
-    stop: Optional[Union[str, List[str]]] = "/s>"
+    stop: Optional[Union[str, List[str]]] = ["</s>", "/s>"]
     temperature: float = 1.0
     top_p: float = 1.0
     n: int = 1
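
The routes.py change gives every request both end-of-sequence spellings by default while still letting clients pass a single string or their own list. A small sketch of the resulting Pydantic behaviour, trimmed to the fields visible in the diff above (the model string and message content are illustrative):

from typing import List, Optional, Union

from pydantic import BaseModel

class ChatCompletionInput(BaseModel):
    # Trimmed to the fields shown in the diff; the real model defines more fields.
    model: str
    messages: List[dict]
    stop: Optional[Union[str, List[str]]] = ["</s>", "/s>"]
    temperature: float = 1.0
    top_p: float = 1.0
    n: int = 1

# A request that omits "stop" now picks up both markers ...
default_body = ChatCompletionInput(
    model="stabilityai/StableBeluga2",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(default_body.stop)  # ['</s>', '/s>']

# ... and callers can still override it with a single string or their own list.
custom_body = ChatCompletionInput(
    model="stabilityai/StableBeluga2",
    messages=[{"role": "user", "content": "Hello!"}],
    stop="###",
)
print(custom_body.stop)  # '###'

The generate methods in models.py then forward this value to safe_decode as the stop_tokens argument shown earlier.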
