simplify api
zhihao authored and goliaro committed Oct 15, 2024
1 parent cba1396 commit eeafdc7
Showing 2 changed files with 54 additions and 112 deletions.
1 change: 1 addition & 0 deletions python/flexflow/core/flexflow_cffi.py
@@ -4671,6 +4671,7 @@ def _estimate_max_num_chars(
             + 100
         )
 
+    # deprecated
     def generate_inf_only(
         self,
         prompt_list: List[str],
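With `generate_inf_only` now flagged `# deprecated`, the serving layer below stops calling it: plain prompts are wrapped in `Request` objects and routed through the regular `generate` path instead. A rough before/after sketch of the call flow (the `llm` handle is an assumed example; the internal calls are the ones named in this diff):

    # old flow (deprecated): plain prompts bypassed Request entirely
    llm.generate("What is a GPU?")
    #   -> self.model.ffmodel.generate_inf_only([prompt], max_length, max_new_tokens)

    # new flow: every input is normalized to Request objects first
    llm.generate("What is a GPU?")
    #   -> self._generate([Request(...)]) -> self.model.ffmodel.generate(requests)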
165 changes: 53 additions & 112 deletions python/flexflow/serve/serve.py
@@ -515,6 +515,39 @@ def compile(
 
         atexit.register(self.rm.stop_server)
 
+    def _generate(self, requests: List[Request]):
+        if len(requests) == 0:
+            return []
+        for req in requests:
+            if req.req_type == RequestType.REQ_INFERENCE:
+                # check max_length and max_new_tokens parameters
+                if req.max_length == -1 and req.max_new_tokens == -1:
+                    req.max_length = self.max_seq_length - 1
+                elif req.max_length != -1 and req.max_new_tokens != -1:
+                    warnings.warn(
+                        f"Both `max_new_tokens` (={req.max_new_tokens}) and `max_length` (={req.max_length}) seem to have been set. `max_new_tokens` will take precedence."
+                    )
+                    req.max_length = -1
+                if (
+                    req.max_length >= self.max_seq_length
+                    or req.max_new_tokens >= self.max_seq_length
+                ):
+                    raise ValueError(
+                        f"max_length ({req.max_length}) or max_new_tokens ({req.max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
+                    )
+            else:
+                if req.max_new_tokens != -1:
+                    raise ValueError(
+                        f"max_new_tokens ({req.max_new_tokens}) is not allowed for finetuning requests."
+                    )
+                if req.max_length == -1:
+                    req.max_length = self.max_seq_length - 1
+                if req.max_length >= self.max_seq_length:
+                    raise ValueError(
+                        f"max_length ({req.max_length}) exceeds the maximum sequence length ({self.max_seq_length})"
+                    )
+        return self.model.ffmodel.generate(requests)

     def generate(
         self,
         requests_or_prompts: Union[str, List[str], Request, List[Request]],
@@ -532,129 +565,37 @@ def generate(
         :return: the generation results
         :rtype: GenerationResult
         """

-        def inf_only() -> bool:
-            if type(requests_or_prompts) == str:
-                return True
-            if type(requests_or_prompts) == list:
-                if type(requests_or_prompts[0]) == str:
-                    return True
-            return False
-
-        # Inference only (type is str or List[str])
-        # - if max_length and max_new_tokens are both unset, set max_length to max_sequence_length
-        # - if both are set, give precedence to max_new_tokens
-        # - if only one of them is set, all good. just check that we don't exceed the max_sequence_length
-        # Inference + finetuning (type is Request or List[Request]):
-        # - inference requests: same conditions as above
-        # - finetuning requests: return error if max_new_tokens is set. If max_length is unset, set it to max_sequence_length. If it's set, check that it doesn't exceed max_sequence_length
-        if inf_only():
-            # single prompt (str) or list of prompts in str format
-            if max_length == -1 and max_new_tokens == -1:
-                max_length = self.max_seq_length - 1
-            elif max_length != -1 and max_new_tokens != -1:
-                warnings.warn(
-                    f"Both `max_new_tokens` (={max_new_tokens}) and `max_length` (={max_length}) seem to have been set. `max_new_tokens` will take precedence."
-                )
-                max_length = -1
-            if max_length >= self.max_seq_length or max_new_tokens >= self.max_seq_length:
-                raise ValueError(
-                    f"max_length ({max_length}) or max_new_tokens ({max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
-                )
-        elif type(requests_or_prompts) == Request:
-            # single Request object (inference or finetuning)
-            if max_length != -1 or max_new_tokens != -1:
-                warnings.warn(
-                    f"max_length (={max_length}) and max_new_tokens (={max_new_tokens}) are not used for Request objects."
-                )
-            if requests_or_prompts.req_type == RequestType.REQ_INFERENCE:
-                # check max_length and max_new_tokens parameters
-                if (
-                    requests_or_prompts.max_length == -1
-                    and requests_or_prompts.max_new_tokens == -1
-                ):
-                    requests_or_prompts.max_length = self.max_seq_length - 1
-                elif (
-                    requests_or_prompts.max_length != -1
-                    and requests_or_prompts.max_new_tokens != -1
-                ):
-                    warnings.warn(
-                        f"Both `max_new_tokens` (={requests_or_prompts.max_new_tokens}) and `max_length` (={requests_or_prompts.max_length}) seem to have been set. `max_new_tokens` will take precedence."
-                    )
-                    requests_or_prompts.max_length = -1
-                if (
-                    requests_or_prompts.max_length >= self.max_seq_length
-                    or requests_or_prompts.max_new_tokens >= self.max_seq_length
-                ):
-                    raise ValueError(
-                        f"max_length ({requests_or_prompts.max_length}) or max_new_tokens ({requests_or_prompts.max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
-                    )
-            else:
-                if requests_or_prompts.max_new_tokens != -1:
-                    raise ValueError(
-                        f"max_new_tokens ({requests_or_prompts.max_new_tokens}) is not allowed for finetuning requests."
-                    )
-                if requests_or_prompts.max_length == -1:
-                    requests_or_prompts.max_length = self.max_seq_length - 1
-                if requests_or_prompts.max_length >= self.max_seq_length:
-                    raise ValueError(
-                        f"max_length ({requests_or_prompts.max_length}) exceeds the maximum sequence length ({self.max_seq_length})"
-                    )
-        else:
-            # list of Request objects (inference or finetuning)
-            if max_length != -1 or max_new_tokens != -1:
-                warnings.warn(
-                    f"max_length (={max_length}) and max_new_tokens (={max_new_tokens}) are not used for Request objects."
-                )
-            for req in requests_or_prompts:
-                if req.req_type == RequestType.REQ_INFERENCE:
-                    # check max_length and max_new_tokens parameters
-                    if req.max_length == -1 and req.max_new_tokens == -1:
-                        req.max_length = self.max_seq_length - 1
-                    elif req.max_length != -1 and req.max_new_tokens != -1:
-                        warnings.warn(
-                            f"Both `max_new_tokens` (={req.max_new_tokens}) and `max_length` (={req.max_length}) seem to have been set. `max_new_tokens` will take precedence."
-                        )
-                        req.max_length = -1
-                    if (
-                        req.max_length >= self.max_seq_length
-                        or req.max_new_tokens >= self.max_seq_length
-                    ):
-                        raise ValueError(
-                            f"max_length ({req.max_length}) or max_new_tokens ({req.max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
-                        )
-                else:
-                    if req.max_new_tokens != -1:
-                        raise ValueError(
-                            f"max_new_tokens ({req.max_new_tokens}) is not allowed for finetuning requests."
-                        )
-                    if req.max_length == -1:
-                        req.max_length = self.max_seq_length - 1
-                    if req.max_length >= self.max_seq_length:
-                        raise ValueError(
-                            f"max_length ({req.max_length}) exceeds the maximum sequence length ({self.max_seq_length})"
-                        )

         if type(requests_or_prompts) == str:
             if len(requests_or_prompts) == 0:
-                return None
-            return self.model.ffmodel.generate_inf_only(
-                [requests_or_prompts], max_length, max_new_tokens
-            )
+                return []
+            request = Request(
+                req_type=RequestType.REQ_INFERENCE,
+                prompt=requests_or_prompts,
+                max_length=max_length,
+                max_new_tokens=max_new_tokens,
+            )
+            return self._generate([request])
         elif type(requests_or_prompts) == Request:
-            return self.model.ffmodel.generate(requests_or_prompts)
+            return self._generate([requests_or_prompts])
         elif type(requests_or_prompts) == list:
             if len(requests_or_prompts) == 0:
                 return []
             if type(requests_or_prompts[0]) == str:
-                return self.model.ffmodel.generate_inf_only(
-                    requests_or_prompts, max_length, max_new_tokens
-                )
+                requests = [
+                    Request(
+                        req_type=RequestType.REQ_INFERENCE,
+                        prompt=req,
+                        max_length=max_length,
+                        max_new_tokens=max_new_tokens,
+                    )
+                    for req in requests_or_prompts
+                ]
+                return self._generate(requests)
             else:
-                print(requests_or_prompts)
-                return self.model.ffmodel.generate(requests_or_prompts)
+                return self._generate(requests_or_prompts)
         else:
-            assert False, "Please pass a non-empty string or list of strings"
+            assert False, "Please pass a string, list of strings, Request, or list of Requests"

     def start_server(self):
         self.rm.start_server(self.model.ffmodel)
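For reference, here is how the validation rules in the new `_generate` helper play out at the call site. This is a minimal sketch assuming an `LLM` instance `llm` compiled with `max_seq_length=1024`; the handle and values are illustrative, not part of this commit:

    # neither limit set: max_length defaults to max_seq_length - 1 (1023 here)
    llm.generate("hello")

    # both set: a warning is issued, max_length is reset to -1,
    # and max_new_tokens takes precedence
    llm.generate("hello", max_length=512, max_new_tokens=128)

    # a limit at or above max_seq_length raises ValueError
    llm.generate("hello", max_new_tokens=4096)

    # finetuning requests may only use max_length; setting
    # max_new_tokens on them raises ValueError in _generate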

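And the four input shapes the simplified `generate` accepts, all funneled into `_generate`. A usage sketch: the import path for `Request` and `RequestType` is an assumption based on the names used in serve.py, and the model and config details are placeholders:

    from flexflow.serve import LLM
    from flexflow.serve.serve import Request, RequestType  # assumed path

    llm = LLM("meta-llama/Llama-2-7b-hf")  # placeholder model
    llm.compile()  # plus whatever generation/batch configs your setup needs
    llm.start_server()

    llm.generate("single prompt")                   # str
    llm.generate(["prompt one", "prompt two"])      # List[str]
    llm.generate(Request(req_type=RequestType.REQ_INFERENCE, prompt="hi"))
    llm.generate([
        Request(req_type=RequestType.REQ_INFERENCE, prompt="hi"),
        Request(req_type=RequestType.REQ_INFERENCE, prompt="bye"),
    ])                                              # List[Request]
    # finetuning Requests (the else branch in _generate) can be mixed
    # into the list as well

    llm.stop_server()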