Skip to content

Commit

Permalink
Merge pull request #39 from 01-ai/reed/add_eos_token
Browse files: browse the repository at this point in the history
feat: add an --eos-token option (default '<|endoftext|>') and pass '\n' as the eos_token in the demo commands
  • Loading branch information
ZhaoFancy authored Nov 7, 2023
2 parents e228d96 + bb2d975 commit 3ee39d4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 0 deletions.
3 changes: 3 additions & 0 deletions demo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ python text_generation.py \
--model 01-ai/Yi-6B \
--tokenizer 01-ai/Yi-6B \
--max-tokens 512 \
--eos-token $'\n' \
--streaming
```

Expand All @@ -21,5 +22,7 @@ torchrun --nproc_per_node 2 \
text_generation_tp.py \
--model 01-ai/Yi-6B \
--max-tokens 512 \
--eos-token $'\n' \
--streaming

```
7 changes: 7 additions & 0 deletions demo/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def parse_inputs():
default="Let me tell you an interesting story about cat Tom and mouse Jerry,",
help="The prompt to start with",
)
parser.add_argument(
"--eos-token",
type=str,
default="<|endoftext|>",
help="End of sentence token",
)
args = parser.parse_args()
return args

Expand All @@ -55,6 +61,7 @@ def main(args):
inputs.input_ids.cuda(),
max_new_tokens=args.max_tokens,
streamer=streamer,
eos_token_id=tokenizer.convert_tokens_to_ids(args.eos_token),
do_sample=True,
)
if streamer is None:
Expand Down
7 changes: 7 additions & 0 deletions demo/text_generation_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def parse_inputs():
default="Let me tell you an interesting story about cat Tom and mouse Jerry,",
help="The prompt to start with",
)
parser.add_argument(
"--eos-token",
type=str,
default="<|endoftext|>",
help="End of sentence token",
)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -93,6 +99,7 @@ def on_finalized_text(self, text: str, stream_end: bool = False):
inputs.input_ids.cuda(),
max_new_tokens=args.max_tokens,
streamer=streamer,
eos_token_id=tokenizer.convert_tokens_to_ids(args.eos_token),
do_sample=True,
)
if distributed.get_rank() == 0 and streamer is None:
Expand Down

0 comments on commit 3ee39d4

Please sign in to comment.