From 9d1a90112e3388ce72e16faee9546314b5cbf496 Mon Sep 17 00:00:00 2001 From: april-yyt Date: Sun, 14 Jan 2024 10:46:17 +0000 Subject: [PATCH] adjustments on usecases and api entrypoints --- docs/source/rag.rst | 6 +- docs/source/serve_fastapi.rst | 36 +++++++----- inference/.gitignore | 2 + inference/python/entrypoint/fastapi_incr.py | 45 ++++++++------ .../python/entrypoint/fastapi_specinfer.py | 58 +++++++++---------- inference/python/spec_infer.py | 16 ++--- inference/python/usecases/gradio_incr.py | 7 ++- inference/python/usecases/gradio_specinfer.py | 19 ++---- .../python/usecases/prompt_template_incr.py | 21 ++++--- .../usecases/prompt_template_specinfer.py | 32 +++++----- inference/python/usecases/rag_incr.py | 15 +++-- inference/python/usecases/rag_specinfer.py | 25 ++++---- python/flexflow/serve/serve.py | 11 ---- 13 files changed, 154 insertions(+), 139 deletions(-) diff --git a/docs/source/rag.rst b/docs/source/rag.rst index 13d4c3a7f9..4b869c2352 100644 --- a/docs/source/rag.rst +++ b/docs/source/rag.rst @@ -84,5 +84,7 @@ Example Implementation: llm_chain_rag = LLMChain(llm=ff_llm_wrapper, prompt=prompt_rag) # Run - with ff_llm: - rag_result = llm_chain_rag(docs_text) \ No newline at end of file + rag_result = llm_chain_rag(docs_text) + + # Stop the server + ff_llm.stop_server() \ No newline at end of file diff --git a/docs/source/serve_fastapi.rst b/docs/source/serve_fastapi.rst index 9856debebc..0aa6634670 100644 --- a/docs/source/serve_fastapi.rst +++ b/docs/source/serve_fastapi.rst @@ -42,7 +42,7 @@ Example class PromptRequest(BaseModel): prompt: str - llm_model = None + llm = None Endpoint Creation ================= @@ -62,19 +62,20 @@ Example @app.on_event("startup") async def startup_event(): - global llm_model - # Initialize and compile the LLM model - # ... + global llm + # Initialize and compile the LLM model + llm.compile( + generation_config, + # ... other params as needed + ) + llm.start_server() @app.post("/generate/") async def generate(prompt_request: PromptRequest): - global llm_model - if llm_model is None: - raise HTTPException(status_code=503, detail="LLM model is not initialized.") - - with llm_model: - full_output = llm_model.generate([prompt_request.prompt])[0].output_text.decode('utf-8') - return {"prompt": prompt_request.prompt, "response": full_output} + # ... exception handling + full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8') + # ... split prompt and response text for returning results + return {"prompt": prompt_request.prompt, "response": full_output} Running and Testing =================== @@ -91,6 +92,15 @@ Example ------- .. code-block:: bash - # Running within the inference/python folder: - uvicorn entrypoint.fastapi_incr:app --reload --port 3000 + # Running within the inference/python folder: + uvicorn entrypoint.fastapi_incr:app --reload --port 3000 + +Full API Entrypoint Code +========================= + +A complete code example for a web-document Q&A using FlexFlow can be found here: + +1. `FastAPI Example with incremental decoding `__ + +2. 
`FastAPI Example with speculative inference `__ diff --git a/inference/.gitignore b/inference/.gitignore index 91ca0a55c9..e4f411d567 100644 --- a/inference/.gitignore +++ b/inference/.gitignore @@ -5,5 +5,7 @@ prompt output python/.chainlit python/chainlit.md +python/spec_infer.py +python/incr_decoding.py .env python/chain_testing.py \ No newline at end of file diff --git a/inference/python/entrypoint/fastapi_incr.py b/inference/python/entrypoint/fastapi_incr.py index d0dbef0ac9..34f61739fb 100644 --- a/inference/python/entrypoint/fastapi_incr.py +++ b/inference/python/entrypoint/fastapi_incr.py @@ -38,7 +38,7 @@ class PromptRequest(BaseModel): prompt: str # Global variable to store the LLM model -llm_model = None +llm = None def get_configs(): @@ -95,7 +95,7 @@ def get_configs(): # Initialize model on startup @app.on_event("startup") async def startup_event(): - global llm_model + global llm # Initialize your LLM model configuration here configs_dict = get_configs() @@ -103,7 +103,7 @@ async def startup_event(): ff.init(configs_dict) ff_data_type = ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF - llm_model = ff.LLM( + llm = ff.LLM( configs.llm_model, data_type=ff_data_type, cache_path=configs.cache_path, @@ -114,35 +114,42 @@ async def startup_event(): generation_config = ff.GenerationConfig( do_sample=False, temperature=0.9, topp=0.8, topk=1 ) - llm_model.compile( + llm.compile( generation_config, max_requests_per_batch=1, max_seq_length=256, max_tokens_per_batch=64, ) + llm.start_server() # API endpoint to generate response @app.post("/generate/") async def generate(prompt_request: PromptRequest): - if llm_model is None: + if llm is None: raise HTTPException(status_code=503, detail="LLM model is not initialized.") # Call the model to generate a response - with llm_model: - full_output = llm_model.generate([prompt_request.prompt])[0].output_text.decode('utf-8') + full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8') + + # Separate the prompt and response + split_output = full_output.split('\n', 1) + if len(split_output) > 1: + response_text = split_output[1] + else: + response_text = "" - # Separate the prompt and response - split_output = full_output.split('\n', 1) - if len(split_output) > 1: - response_text = split_output[1] - else: - response_text = "" - - # Return the prompt and the response in JSON format - return { - "prompt": prompt_request.prompt, - "response": response_text - } + # Return the prompt and the response in JSON format + return { + "prompt": prompt_request.prompt, + "response": response_text + } + +# Shutdown event to stop the model server +@app.on_event("shutdown") +async def shutdown_event(): + global llm + if llm is not None: + llm.stop_server() # Main function to run Uvicorn server if __name__ == "__main__": diff --git a/inference/python/entrypoint/fastapi_specinfer.py b/inference/python/entrypoint/fastapi_specinfer.py index f910f56854..416aee6dc5 100644 --- a/inference/python/entrypoint/fastapi_specinfer.py +++ b/inference/python/entrypoint/fastapi_specinfer.py @@ -38,7 +38,7 @@ class PromptRequest(BaseModel): prompt: str # Global variable to store the LLM model -llm_model = None +llm = None def get_configs(): # Fetch configuration file path from environment variable @@ -90,28 +90,19 @@ def get_configs(): "cache_path": "", "refresh_cache": False, "full_precision": False, - }, - { - # required ssm parameter - "ssm_model": "facebook/opt-125m", - # optional ssm parameters - "cache_path": "", - "refresh_cache": 
False, - "full_precision": False, - }, + } ], - # "prompt": "../prompt/test.json", + # "prompt": "", "output_file": "", } # Merge dictionaries ff_init_configs.update(llm_configs) return ff_init_configs - # Initialize model on startup @app.on_event("startup") async def startup_event(): - global llm_model + global llm # Initialize your LLM model configuration here configs_dict = get_configs() @@ -145,11 +136,12 @@ async def startup_event(): output_file=configs.output_file, ) ssms.append(ssm) - + # Create the sampling configs generation_config = ff.GenerationConfig( do_sample=False, temperature=0.9, topp=0.8, topk=1 ) + # Compile the SSMs for inference and load the weights into memory for ssm in ssms: ssm.compile( @@ -167,29 +159,37 @@ async def startup_event(): max_tokens_per_batch=64, ssms=ssms, ) + + llm.start_server() # API endpoint to generate response @app.post("/generate/") async def generate(prompt_request: PromptRequest): - if llm_model is None: + if llm is None: raise HTTPException(status_code=503, detail="LLM model is not initialized.") # Call the model to generate a response - with llm_model: - full_output = llm_model.generate([prompt_request.prompt])[0].output_text.decode('utf-8') + full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8') + + # Separate the prompt and response + split_output = full_output.split('\n', 1) + if len(split_output) > 1: + response_text = split_output[1] + else: + response_text = "" + + # Return the prompt and the response in JSON format + return { + "prompt": prompt_request.prompt, + "response": response_text + } - # Separate the prompt and response - split_output = full_output.split('\n', 1) - if len(split_output) > 1: - response_text = split_output[1] - else: - response_text = "" - - # Return the prompt and the response in JSON format - return { - "prompt": prompt_request.prompt, - "response": response_text - } +# Shutdown event to stop the model server +@app.on_event("shutdown") +async def shutdown_event(): + global llm + if llm is not None: + llm.stop_server() # Main function to run Uvicorn server if __name__ == "__main__": diff --git a/inference/python/spec_infer.py b/inference/python/spec_infer.py index fcb1b8f891..f5c5bc6a88 100644 --- a/inference/python/spec_infer.py +++ b/inference/python/spec_infer.py @@ -143,13 +143,15 @@ def main(): llm.start_server() - if len(configs.prompt) > 0: - prompts = [s for s in json.load(open(configs.prompt))] - results = llm.generate(prompts) - else: - result = llm.generate("Three tips for staying healthy are: ") - - llm.stop_server() + # if len(configs.prompt) > 0: + # prompts = [s for s in json.load(open(configs.prompt))] + # results = llm.generate(prompts) + # else: + # result = llm.generate("Three tips for staying healthy are: ") + + result = llm.generate("Three tips for staying healthy are: ") + + if __name__ == "__main__": print("flexflow inference example (speculative inference)") diff --git a/inference/python/usecases/gradio_incr.py b/inference/python/usecases/gradio_incr.py index 2c7b448819..2735b665bb 100644 --- a/inference/python/usecases/gradio_incr.py +++ b/inference/python/usecases/gradio_incr.py @@ -152,9 +152,10 @@ def main(): # ) # interface version 2 - with llm: - iface = gr.ChatInterface(fn=generate_response) - iface.launch() + iface = gr.ChatInterface(fn=generate_response) + llm.start_server() + iface.launch() + llm.stop_server() if __name__ == "__main__": print("flexflow inference example with gradio interface") diff --git 
a/inference/python/usecases/gradio_specinfer.py b/inference/python/usecases/gradio_specinfer.py index 799df65e98..08cde3f00b 100644 --- a/inference/python/usecases/gradio_specinfer.py +++ b/inference/python/usecases/gradio_specinfer.py @@ -103,17 +103,9 @@ def get_configs(): "cache_path": "", "refresh_cache": False, "full_precision": False, - }, - { - # required ssm parameter - "ssm_model": "facebook/opt-125m", - # optional ssm parameters - "cache_path": "", - "refresh_cache": False, - "full_precision": False, - }, + } ], - # "prompt": "../prompt/test.json", + # "prompt": "", "output_file": "", } # Merge dictionaries @@ -203,9 +195,10 @@ def main(): # ) # interface version 2 - with llm: - iface = gr.ChatInterface(fn=generate_response) - iface.launch() + iface = gr.ChatInterface(fn=generate_response) + llm.start_server() + iface.launch() + llm.stop_server() if __name__ == "__main__": print("flexflow inference example with gradio interface") diff --git a/inference/python/usecases/prompt_template_incr.py b/inference/python/usecases/prompt_template_incr.py index 596767ef86..8bffe9ddad 100644 --- a/inference/python/usecases/prompt_template_incr.py +++ b/inference/python/usecases/prompt_template_incr.py @@ -116,16 +116,19 @@ def create_llm(self): return llm def compile_and_start(self, generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch): - self.llm.compile(generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch) + self.llm.compile(generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch) + self.llm.start_server() def generate(self, prompt): - user_input = prompt - results = self.llm.generate(user_input) + results = self.llm.generate(prompt) if isinstance(results, list): result_txt = results[0].output_text.decode('utf-8') else: result_txt = results.output_text.decode('utf-8') return result_txt + + def stop_server(self): + self.llm.stop_server() def __enter__(self): return self.llm.__enter__() @@ -171,10 +174,14 @@ def _call( # USE CASE 1: Prompt Template template = """Question: {question} Answer: Let's think step by step.""" - prompt = PromptTemplate(template=template, input_variables=["question"]) + # Build prompt template and langchain + prompt = PromptTemplate(template=template, input_variables=["question"]) llm_chain = LLMChain(prompt=prompt, llm=ff_llm_wrapper) - with ff_llm: - question = "Who was the US president in the year the first Pokemon game was released?" - print(llm_chain.run(question)) + + question = "Who was the US president in the year the first Pokemon game was released?" 
+ print(llm_chain.run(question)) + + # stop the server + ff_llm.stop_server() diff --git a/inference/python/usecases/prompt_template_specinfer.py b/inference/python/usecases/prompt_template_specinfer.py index f7f1848782..dfc92e9ac2 100644 --- a/inference/python/usecases/prompt_template_specinfer.py +++ b/inference/python/usecases/prompt_template_specinfer.py @@ -106,17 +106,9 @@ def get_configs(self, config_file): "cache_path": "", "refresh_cache": False, "full_precision": False, - }, - { - # required ssm parameter - "ssm_model": "facebook/opt-125m", - # optional ssm parameters - "cache_path": "", - "refresh_cache": False, - "full_precision": False, - }, + } ], - # "prompt": "../prompt/test.json", + # "prompt": "", "output_file": "", } # Merge dictionaries @@ -173,15 +165,18 @@ def compile_and_start(self, generation_config, max_requests_per_batch, max_seq_l max_tokens_per_batch, ssms = self.ssms ) + self.llm.start_server() def generate(self, prompt): - user_input = prompt - results = self.llm.generate(user_input) + results = self.llm.generate(prompt) if isinstance(results, list): result_txt = results[0].output_text.decode('utf-8') else: result_txt = results.output_text.decode('utf-8') return result_txt + + def stop_server(self): + self.llm.stop_server() def __enter__(self): return self.llm.__enter__() @@ -227,10 +222,15 @@ def _call( # USE CASE 1: Prompt Template template = """Question: {question} Answer: Let's think step by step.""" + + # Build prompt template and langchain prompt = PromptTemplate(template=template, input_variables=["question"]) - llm_chain = LLMChain(prompt=prompt, llm=ff_llm_wrapper) - with ff_llm: - question = "Who was the US president in the year the first Pokemon game was released?" - print(llm_chain.run(question)) + + question = "Who was the US president in the year the first Pokemon game was released?" 
+ print(llm_chain.run(question)) + + # stop the server + ff_llm.stop_server() + diff --git a/inference/python/usecases/rag_incr.py b/inference/python/usecases/rag_incr.py index fea9af487d..15e7f3d092 100644 --- a/inference/python/usecases/rag_incr.py +++ b/inference/python/usecases/rag_incr.py @@ -122,16 +122,19 @@ def create_llm(self): def compile_and_start(self, generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch): self.llm.compile(generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch) + self.llm.start_server() def generate(self, prompt): - user_input = prompt - results = self.llm.generate(user_input) + results = self.llm.generate(prompt) if isinstance(results, list): result_txt = results[0].output_text.decode('utf-8') else: result_txt = results.output_text.decode('utf-8') return result_txt + def stop_server(self): + self.llm.stop_server() + def __enter__(self): return self.llm.__enter__() @@ -205,11 +208,13 @@ def _call( prompt_rag = PromptTemplate.from_template( "Summarize the main themes in these retrieved docs: {docs_text}" ) - + # Chain llm_chain_rag = LLMChain(llm=ff_llm_wrapper, prompt=prompt_rag) # Run - with ff_llm: - rag_result = llm_chain_rag(docs_text) + rag_result = llm_chain_rag(docs_text) + + # Stop the server + ff_llm.stop_server() diff --git a/inference/python/usecases/rag_specinfer.py b/inference/python/usecases/rag_specinfer.py index 4523111b08..512b973955 100644 --- a/inference/python/usecases/rag_specinfer.py +++ b/inference/python/usecases/rag_specinfer.py @@ -110,17 +110,9 @@ def get_configs(self, config_file): "cache_path": "", "refresh_cache": False, "full_precision": False, - }, - { - # required ssm parameter - "ssm_model": "facebook/opt-125m", - # optional ssm parameters - "cache_path": "", - "refresh_cache": False, - "full_precision": False, - }, + } ], - # "prompt": "../prompt/test.json", + # "prompt": "", "output_file": "", } # Merge dictionaries @@ -177,16 +169,20 @@ def compile_and_start(self, generation_config, max_requests_per_batch, max_seq_l max_tokens_per_batch, ssms = self.ssms ) + # start server + self.llm.start_server() def generate(self, prompt): - user_input = prompt - results = self.llm.generate(user_input) + results = self.llm.generate(prompt) if isinstance(results, list): result_txt = results[0].output_text.decode('utf-8') else: result_txt = results.output_text.decode('utf-8') return result_txt + def stop_server(self): + self.llm.stop_server() + def __enter__(self): return self.llm.__enter__() @@ -264,6 +260,7 @@ def _call( llm_chain_rag = LLMChain(llm=ff_llm_wrapper, prompt=prompt_rag) # Run - with ff_llm: - rag_result = llm_chain_rag(docs_text) + rag_result = llm_chain_rag(docs_text) + # stop the server + ff_llm.stop_server() diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py index 6ea95f7727..d1a935e5fc 100644 --- a/python/flexflow/serve/serve.py +++ b/python/flexflow/serve/serve.py @@ -426,12 +426,6 @@ def generate(self, prompts: Union[str, List[str]], max_length: int = 128): return self.model.ffmodel.generate(prompts, max_length) else: assert False, "Please pass a non-empty string or list of strings" -<<<<<<< HEAD - - def __enter__(self): - # Start the server when entering the context - self.rm.start_server(self.model.ffmodel) -======= def start_server(self): self.rm.start_server(self.model.ffmodel) @@ -444,16 +438,11 @@ def stop_server(self): def __enter__(self): # Start the server when entering the context #self.rm.start_server(self.model.ffmodel) ->>>>>>> 
origin/inference return self def __exit__(self, exc_type, exc_value, traceback): # Stop the server when exiting the context -<<<<<<< HEAD - self.rm.stop_server() -======= #self.rm.stop_server() ->>>>>>> origin/inference if exc_type: print(f"Exception occurred: {exc_value}")
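For reference, the explicit server lifecycle that this patch standardizes on, in place of the earlier ``with llm:`` context-manager usage. This is a minimal sketch only: the ``ff.LLM``, ``ff.GenerationConfig``, ``compile``, ``start_server``, ``generate``, and ``stop_server`` calls mirror the hunks above, while the import path, ``ff.init`` keys, model name, and compile parameters are illustrative assumptions rather than values taken from the patched files.

.. code-block:: python

    # Sketch of the explicit start_server()/stop_server() lifecycle used throughout
    # this patch. The ff.init keys, model name, and compile parameters below are
    # illustrative assumptions, not values copied from the patched files.
    import flexflow.serve as ff

    ff.init(
        {
            "num_gpus": 1,
            "memory_per_gpu": 14000,
            "zero_copy_memory_per_node": 10000,
            "tensor_parallelism_degree": 1,
            "pipeline_parallelism_degree": 1,
        }
    )

    llm = ff.LLM("facebook/opt-125m", data_type=ff.DataType.DT_HALF)
    generation_config = ff.GenerationConfig(
        do_sample=False, temperature=0.9, topp=0.8, topk=1
    )
    llm.compile(
        generation_config,
        max_requests_per_batch=1,
        max_seq_length=256,
        max_tokens_per_batch=64,
    )

    llm.start_server()   # replaces entering the `with llm:` block
    result = llm.generate("Three tips for staying healthy are: ")
    llm.stop_server()    # replaces exiting the `with llm:` block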
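The ``/generate/`` endpoint added by ``fastapi_incr.py`` and ``fastapi_specinfer.py`` can be smoke-tested once the app is running (for example via ``uvicorn entrypoint.fastapi_incr:app --reload --port 3000``, as shown in the docs above). A client-side sketch, assuming the third-party ``requests`` package and the port from the docs; any HTTP client works:

.. code-block:: python

    # Client-side sketch (not part of the patch) for the /generate/ endpoint.
    # Assumes the FastAPI app is already running on localhost:3000.
    import requests

    resp = requests.post(
        "http://localhost:3000/generate/",
        json={"prompt": "Three tips for staying healthy are: "},  # matches PromptRequest
        timeout=120,
    )
    resp.raise_for_status()
    payload = resp.json()
    print("prompt:  ", payload["prompt"])
    print("response:", payload["response"])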
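The use-case scripts (``prompt_template_*.py``, ``rag_*.py``) expose the FlexFlow LLM to LangChain through a wrapper class whose full definition sits outside the hunks shown above. A rough sketch of that wiring, reusing ``llm`` from the lifecycle sketch; the ``FF_LLM_wrapper`` class below is a stand-in rather than the patch's exact implementation, and the LangChain imports assume the 0.0.x-era API these scripts appear to target:

.. code-block:: python

    # Stand-in LangChain adapter; the real wrapper lives in the use-case files.
    from typing import Any, List, Optional

    from langchain.chains import LLMChain
    from langchain.llms.base import LLM
    from langchain.prompts import PromptTemplate

    class FF_LLM_wrapper(LLM):
        """Adapter around an already compiled and started FlexFlow LLM."""

        flexflow_llm: Any  # object returned by ff.LLM(...) after compile()/start_server()

        @property
        def _llm_type(self) -> str:
            return "flexflow"

        def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
            results = self.flexflow_llm.generate(prompt)
            result = results[0] if isinstance(results, list) else results
            return result.output_text.decode("utf-8")

    template = """Question: {question}
    Answer: Let's think step by step."""
    prompt = PromptTemplate(template=template, input_variables=["question"])

    ff_llm_wrapper = FF_LLM_wrapper(flexflow_llm=llm)  # `llm` from the lifecycle sketch
    llm_chain = LLMChain(prompt=prompt, llm=ff_llm_wrapper)
    print(llm_chain.run("Who was the US president in the year the first Pokemon game was released?"))

    llm.stop_server()  # stop the FlexFlow server once the chain is done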