diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py
index 3a160d98c0..e2240f0b4f 100644
--- a/python/flexflow/core/flexflow_cffi.py
+++ b/python/flexflow/core/flexflow_cffi.py
@@ -4671,6 +4671,7 @@ def _estimate_max_num_chars(
             + 100
         )
 
+    # deprecated: superseded by generate(), which takes Request objects
     def generate_inf_only(
         self,
         prompt_list: List[str],
diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py
index d9a22dfcc0..c8540a6ed3 100644
--- a/python/flexflow/serve/serve.py
+++ b/python/flexflow/serve/serve.py
@@ -515,6 +515,39 @@ def compile(
 
         atexit.register(self.rm.stop_server)
 
+    def _generate(self, requests: List[Request]):
+        if len(requests) == 0:
+            return []
+        for req in requests:
+            if req.req_type == RequestType.REQ_INFERENCE:
+                # check max_length and max_new_tokens parameters
+                if req.max_length == -1 and req.max_new_tokens == -1:
+                    req.max_length = self.max_seq_length - 1
+                elif req.max_length != -1 and req.max_new_tokens != -1:
+                    warnings.warn(
+                        f"Both `max_new_tokens` (={req.max_new_tokens}) and `max_length` (={req.max_length}) seem to have been set. `max_new_tokens` will take precedence."
+                    )
+                    req.max_length = -1
+                if (
+                    req.max_length >= self.max_seq_length
+                    or req.max_new_tokens >= self.max_seq_length
+                ):
+                    raise ValueError(
+                        f"max_length ({req.max_length}) or max_new_tokens ({req.max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
+                    )
+            else:
+                if req.max_new_tokens != -1:
+                    raise ValueError(
+                        f"max_new_tokens ({req.max_new_tokens}) is not allowed for finetuning requests."
+                    )
+                if req.max_length == -1:
+                    req.max_length = self.max_seq_length - 1
+                if req.max_length >= self.max_seq_length:
+                    raise ValueError(
+                        f"max_length ({req.max_length}) exceeds the maximum sequence length ({self.max_seq_length})"
+                    )
+        return self.model.ffmodel.generate(requests)
+
     def generate(
         self,
         requests_or_prompts: Union[str, List[str], Request, List[Request]],
@@ -532,129 +565,37 @@ def generate(
         :return: the generation results
         :rtype: GenerationResult
         """
-
-        def inf_only() -> bool:
-            if type(requests_or_prompts) == str:
-                return True
-            if type(requests_or_prompts) == list:
-                if type(requests_or_prompts[0]) == str:
-                    return True
-            return False
-
-        # Inference only (type is str or List[str])
-        # - if max_length and max_new_tokens are both unset, set max_length to max_sequence_length
-        # - if both are set, give precedence to max_new_tokens
-        # - if only one of them is set, all good. just check that we don't exceed the max_sequence_length
-        # Inference + finetunining (type is Request or List[Request]):
-        # - inference requests: same conditions as above
-        # - finetuning requests: return error if max_new_tokens is set. If max_lenght is unset, set it to max_sequence_length. If it's set, check that it doesn't exceed max_sequence_length
-        if inf_only():
-            # single prompt (str) or list of prompts in str format
-            if max_length == -1 and max_new_tokens == -1:
-                max_length = self.max_seq_length -1
-            elif max_length != -1 and max_new_tokens != -1:
-                warnings.warn(
-                    f"Both `max_new_tokens` (={max_new_tokens}) and `max_length`(={max_length}) seem to have been set. `max_new_tokens` will take precedence."
-                )
-                max_length = -1
-            if max_length >= self.max_seq_length or max_new_tokens >= self.max_seq_length:
-                raise ValueError(
-                    f"max_length ({max_length}) or max_new_tokens ({max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
-                )
-        elif type(requests_or_prompts) == Request:
-            # single Request object (inference or finetuning)
-            if max_length != -1 or max_new_tokens != -1:
-                warnings.warn(
-                    f"max_length (={max_length}) and max_new_tokens (={max_new_tokens}) are not used for Request objects."
-                )
-            if requests_or_prompts.req_type == RequestType.REQ_INFERENCE:
-                # check max_length and max_new_tokens parameters
-                if (
-                    requests_or_prompts.max_length == -1
-                    and requests_or_prompts.max_new_tokens == -1
-                ):
-                    requests_or_prompts.max_length = self.max_seq_length -1
-                elif (
-                    requests_or_prompts.max_length != -1
-                    and requests_or_prompts.max_new_tokens != -1
-                ):
-                    warnings.warn(
-                        f"Both `max_new_tokens` (={requests_or_prompts.max_new_tokens}) and `max_length`(={requests_or_prompts.max_length}) seem to have been set. `max_new_tokens` will take precedence."
-                    )
-                    requests_or_prompts.max_length = -1
-                if (
-                    requests_or_prompts.max_length >= self.max_seq_length
-                    or requests_or_prompts.max_new_tokens >= self.max_seq_length
-                ):
-                    raise ValueError(
-                        f"max_length ({requests_or_prompts.max_length}) or max_new_tokens ({requests_or_prompts.max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
-                    )
-            else:
-                if requests_or_prompts.max_new_tokens != -1:
-                    raise ValueError(
-                        f"max_new_tokens ({requests_or_prompts.max_new_tokens}) is not allowed for finetuning requests."
-                    )
-                if requests_or_prompts.max_length == -1:
-                    requests_or_prompts.max_length = self.max_seq_length -1
-                if requests_or_prompts.max_length >= self.max_seq_length:
-                    raise ValueError(
-                        f"max_length ({requests_or_prompts.max_length}) exceeds the maximum sequence length ({self.max_seq_length})"
-                    )
-        else:
-            # list of Request objects (inference or finetuning)
-            if max_length != -1 or max_new_tokens != -1:
-                warnings.warn(
-                    f"max_length (={max_length}) and max_new_tokens (={max_new_tokens}) are not used for Request objects."
-                )
-            for req in requests_or_prompts:
-                if req.req_type == RequestType.REQ_INFERENCE:
-                    # check max_length and max_new_tokens parameters
-                    if req.max_length == -1 and req.max_new_tokens == -1:
-                        req.max_length = self.max_seq_length -1
-                    elif req.max_length != -1 and req.max_new_tokens != -1:
-                        warnings.warn(
-                            f"Both `max_new_tokens` (={req.max_new_tokens}) and `max_length`(={req.max_length}) seem to have been set. `max_new_tokens` will take precedence."
-                        )
-                        req.max_length = -1
-                    if (
-                        req.max_length >= self.max_seq_length
-                        or req.max_new_tokens >= self.max_seq_length
-                    ):
-                        raise ValueError(
-                            f"max_length ({req.max_length}) or max_new_tokens ({req.max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
-                        )
-                else:
-                    if req.max_new_tokens != -1:
-                        raise ValueError(
-                            f"max_new_tokens ({req.max_new_tokens}) is not allowed for finetuning requests."
-                        )
-                    if req.max_length == -1:
-                        req.max_length = self.max_seq_length -1
-                    if req.max_length >= self.max_seq_length:
-                        raise ValueError(
-                            f"max_length ({req.max_length}) exceeds the maximum sequence length ({self.max_seq_length})"
-                        )
-
         if type(requests_or_prompts) == str:
             if len(requests_or_prompts) == 0:
-                return None
-            return self.model.ffmodel.generate_inf_only(
-                [requests_or_prompts], max_length, max_new_tokens
+                return []
+            request = Request(
+                req_type=RequestType.REQ_INFERENCE,
+                prompt=requests_or_prompts,
+                max_length=max_length,
+                max_new_tokens=max_new_tokens,
             )
+            return self._generate([request])
         elif type(requests_or_prompts) == Request:
-            return self.model.ffmodel.generate(requests_or_prompts)
+            return self._generate([requests_or_prompts])
         elif type(requests_or_prompts) == list:
             if len(requests_or_prompts) == 0:
                 return []
             if type(requests_or_prompts[0]) == str:
-                return self.model.ffmodel.generate_inf_only(
-                    requests_or_prompts, max_length, max_new_tokens
-                )
+                requests = [
+                    Request(
+                        req_type=RequestType.REQ_INFERENCE,
+                        prompt=req,
+                        max_length=max_length,
+                        max_new_tokens=max_new_tokens,
+                    )
+                    for req in requests_or_prompts
+                ]
+                return self._generate(requests)
             else:
                 print(requests_or_prompts)
-                return self.model.ffmodel.generate(requests_or_prompts)
+                return self._generate(requests_or_prompts)
         else:
-            assert False, "Please pass a non-empty string or list of strings"
+            assert False, "Please pass a string, list of strings, Request, or list of Requests"
 
     def start_server(self):
         self.rm.start_server(self.model.ffmodel)
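
Usage sketch (reviewer note, not part of the patch): the snippet below shows how all three call shapes now funnel through the new _generate() length checks. It assumes a compiled and started instance of the serve.py class this diff touches (called `llm` here) and the Request/RequestType names used above; the import paths and model name are illustrative assumptions, not taken from this diff.

# Sketch only: every call shape now goes through _generate()'s validation.
from flexflow.core import Request, RequestType  # assumed import location
from flexflow.serve import LLM                  # assumed import location

llm = LLM("meta-llama/Llama-2-7b-hf")  # hypothetical model choice
# ... llm.compile(...) with the desired configs, then:
llm.start_server()

# 1) Plain prompts: generate() wraps each string in an inference Request,
#    so the max_length/max_new_tokens checks are applied uniformly.
results = llm.generate(
    ["Three tips for staying healthy are:"],
    max_new_tokens=128,  # would take precedence if max_length were also set
)

# 2) Explicit Request objects carry per-request limits; leaving both fields
#    at -1 makes _generate() default max_length to max_seq_length - 1.
req = Request(
    req_type=RequestType.REQ_INFERENCE,
    prompt="Summarize this change.",
    max_new_tokens=64,
)
results = llm.generate(req)

# 3) A finetuning Request that sets max_new_tokens now raises ValueError in
#    _generate() instead of being forwarded to the runtime unvalidated.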