Skip to content

Commit

Permalink
Allow clean_up_tokenization_spaces parameters for mpt handler (#402)
Browse files · Browse the repository at this point in the history
* tokenizer update

* replit
  • Loading branch information
margaretqian authored Jun 30, 2023
1 parent d62a728 commit 63b6e00
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
13 changes: 10 additions & 3 deletions examples/inference-deployments/mpt/mpt_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ class MPTModelHandler():
INPUT_KEY = 'input'
PARAMETERS_KEY = 'parameters'

def __init__(self, model_name: str):
def __init__(self,
model_name: str,
attn_impl: str = 'torch',
clean_up_tokenization_spaces: bool = False):
self.device = torch.cuda.current_device()
self.model_name = model_name
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces

config = AutoConfig.from_pretrained(self.model_name,
trust_remote_code=True)
config.attn_config['attn_impl'] = 'torch'
config.attn_config['attn_impl'] = attn_impl

model = AutoModelForCausalLM.from_pretrained(self.model_name,
config=config,
Expand Down Expand Up @@ -87,7 +91,10 @@ def predict(self, model_requests: List[Dict]):

print('Logging input to generate: ', generate_inputs)
with torch.autocast('cuda', dtype=torch.bfloat16):
outputs = self.generator(generate_inputs, **generate_kwargs)
outputs = self.generator(
generate_inputs,
clean_up_tokenization_spaces=self.clean_up_tokenization_spaces,
**generate_kwargs)
return self._extract_output(outputs)

def predict_stream(self, **inputs: Dict):
Expand Down
24 changes: 24 additions & 0 deletions examples/inference-deployments/mpt/replit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Deployment spec that serves replit/replit-code-v1-3b through the shared MPT
# model handler (examples.inference-deployments.mpt.mpt_handler).
# NOTE(review): the original file's indentation was lost in extraction — the
# nesting below is reconstructed from the standard MosaicML inference YAML
# schema; verify against the repository copy.
name: replit
compute:
  gpus: 1
  instance: oci.vm.gpu.a10.1  # single-A10 OCI VM
image: mosaicml/inference:0.1.4
replicas: 1
command: |
  export PYTHONPATH=$PYTHONPATH:/code/examples
integrations:
  # Clone the examples repo so the model_handler module path below resolves.
  - integration_type: git_repo
    git_repo: mosaicml/examples
    ssh_clone: false
    git_branch: margaret/mpt-tokenizer-update  # feature branch for this change — repoint to main once merged
  - integration_type: pip_packages
    packages:
      - sentencepiece==0.1.99  # tokenizer backend required by the replit model
      - einops==0.6.1
model:
  download_parameters:
    hf_path: replit/replit-code-v1-3b  # Hugging Face repo to download weights from
  model_handler: examples.inference-deployments.mpt.mpt_handler.MPTModelHandler
  # Keyword arguments forwarded to MPTModelHandler.__init__.
  model_parameters:
    model_name: replit/replit-code-v1-3b
    # Forwarded by the handler into tokenizer decoding so generated code
    # keeps its exact spacing (handler stores it and passes it to the
    # generator call).
    clean_up_tokenization_spaces: true

0 comments on commit 63b6e00

Please sign in to comment.