adjustments on usecases and api entrypoints
april-yyt committed Jan 14, 2024
1 parent f38165c commit 9d1a901
Showing 13 changed files with 154 additions and 139 deletions.
6 changes: 4 additions & 2 deletions docs/source/rag.rst
@@ -84,5 +84,7 @@ Example Implementation:
llm_chain_rag = LLMChain(llm=ff_llm_wrapper, prompt=prompt_rag)
# Run
with ff_llm:
rag_result = llm_chain_rag(docs_text)
# Stop the server
ff_llm.stop_server()
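For reference, the server lifecycle this hunk moves toward looks roughly like the sketch below. The names (ff_llm, ff_llm_wrapper, prompt_rag, docs_text) come from the rag.rst example above; the explicit start_server() call is an assumption based on the matching entrypoint changes later in this commit.

# Sketch only: explicit start/stop lifecycle around the RAG chain.
llm_chain_rag = LLMChain(llm=ff_llm_wrapper, prompt=prompt_rag)

ff_llm.start_server()                  # serve the compiled FlexFlow model
rag_result = llm_chain_rag(docs_text)  # run the RAG chain against it
ff_llm.stop_server()                   # stop the server when finished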
36 changes: 23 additions & 13 deletions docs/source/serve_fastapi.rst
@@ -42,7 +42,7 @@ Example
class PromptRequest(BaseModel):
prompt: str
llm_model = None
llm = None
Endpoint Creation
=================
@@ -62,19 +62,20 @@ Example
@app.on_event("startup")
async def startup_event():
global llm_model
# Initialize and compile the LLM model
# ...
global llm
# Initialize and compile the LLM model
llm.compile(
generation_config,
# ... other params as needed
)
llm.start_server()
@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
global llm_model
if llm_model is None:
raise HTTPException(status_code=503, detail="LLM model is not initialized.")
with llm_model:
full_output = llm_model.generate([prompt_request.prompt])[0].output_text.decode('utf-8')
return {"prompt": prompt_request.prompt, "response": full_output}
# ... exception handling
full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8')
# ... split prompt and response text for returning results
return {"prompt": prompt_request.prompt, "response": full_output}
Running and Testing
===================
@@ -91,6 +92,15 @@ Example
-------

.. code-block:: bash
# Running within the inference/python folder:
uvicorn entrypoint.fastapi_incr:app --reload --port 3000
Full API Entrypoint Code
=========================

Complete code examples for the FastAPI entrypoints can be found here:

1. `FastAPI Example with incremental decoding <https://github.com/flexflow/FlexFlow/blob/chatbot-2/inference/python/entrypoint/fastapi_incr.py>`__

2. `FastAPI Example with speculative inference <https://github.com/flexflow/FlexFlow/blob/chatbot-2/inference/python//entrypoint/fastapi_specinfer.py>`__
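Once the server is running via the uvicorn command above, the /generate/ route can be exercised from any HTTP client. The snippet below is a minimal sketch using the requests package (not part of this commit); the host, port, route, and prompt text follow the examples shown in this diff.

import requests

# Minimal client-side check of the /generate/ endpoint started with
# `uvicorn entrypoint.fastapi_incr:app --reload --port 3000`.
resp = requests.post(
    "http://localhost:3000/generate/",
    json={"prompt": "Three tips for staying healthy are: "},
)
resp.raise_for_status()
payload = resp.json()
print(payload["prompt"])    # the original prompt, echoed back
print(payload["response"])  # the generated continuation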
2 changes: 2 additions & 0 deletions inference/.gitignore
@@ -5,5 +5,7 @@ prompt
output
python/.chainlit
python/chainlit.md
python/spec_infer.py
python/incr_decoding.py
.env
python/chain_testing.py
45 changes: 26 additions & 19 deletions inference/python/entrypoint/fastapi_incr.py
@@ -38,7 +38,7 @@ class PromptRequest(BaseModel):
prompt: str

# Global variable to store the LLM model
llm_model = None
llm = None


def get_configs():
@@ -95,15 +95,15 @@ def get_configs():
# Initialize model on startup
@app.on_event("startup")
async def startup_event():
global llm_model
global llm

# Initialize your LLM model configuration here
configs_dict = get_configs()
configs = SimpleNamespace(**configs_dict)
ff.init(configs_dict)

ff_data_type = ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF
llm_model = ff.LLM(
llm = ff.LLM(
configs.llm_model,
data_type=ff_data_type,
cache_path=configs.cache_path,
@@ -114,35 +114,42 @@ async def startup_event():
generation_config = ff.GenerationConfig(
do_sample=False, temperature=0.9, topp=0.8, topk=1
)
llm_model.compile(
llm.compile(
generation_config,
max_requests_per_batch=1,
max_seq_length=256,
max_tokens_per_batch=64,
)
llm.start_server()

# API endpoint to generate response
@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
if llm_model is None:
if llm is None:
raise HTTPException(status_code=503, detail="LLM model is not initialized.")

# Call the model to generate a response
with llm_model:
full_output = llm_model.generate([prompt_request.prompt])[0].output_text.decode('utf-8')
full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8')

# Separate the prompt and response
split_output = full_output.split('\n', 1)
if len(split_output) > 1:
response_text = split_output[1]
else:
response_text = ""

# Return the prompt and the response in JSON format
return {
"prompt": prompt_request.prompt,
"response": response_text
}

# Shutdown event to stop the model server
@app.on_event("shutdown")
async def shutdown_event():
global llm
if llm is not None:
llm.stop_server()

# Main function to run Uvicorn server
if __name__ == "__main__":
58 changes: 29 additions & 29 deletions inference/python/entrypoint/fastapi_specinfer.py
@@ -38,7 +38,7 @@ class PromptRequest(BaseModel):
prompt: str

# Global variable to store the LLM model
llm_model = None
llm = None

def get_configs():
# Fetch configuration file path from environment variable
@@ -90,28 +90,19 @@ def get_configs():
"cache_path": "",
"refresh_cache": False,
"full_precision": False,
},
{
# required ssm parameter
"ssm_model": "facebook/opt-125m",
# optional ssm parameters
"cache_path": "",
"refresh_cache": False,
"full_precision": False,
},
}
],
# "prompt": "../prompt/test.json",
# "prompt": "",
"output_file": "",
}
# Merge dictionaries
ff_init_configs.update(llm_configs)
return ff_init_configs


# Initialize model on startup
@app.on_event("startup")
async def startup_event():
global llm_model
global llm

# Initialize your LLM model configuration here
configs_dict = get_configs()
@@ -145,11 +136,12 @@ async def startup_event():
output_file=configs.output_file,
)
ssms.append(ssm)

# Create the sampling configs
generation_config = ff.GenerationConfig(
do_sample=False, temperature=0.9, topp=0.8, topk=1
)

# Compile the SSMs for inference and load the weights into memory
for ssm in ssms:
ssm.compile(
@@ -167,29 +159,37 @@ async def startup_event():
max_tokens_per_batch=64,
ssms=ssms,
)

llm.start_server()

# API endpoint to generate response
@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
if llm_model is None:
if llm is None:
raise HTTPException(status_code=503, detail="LLM model is not initialized.")

# Call the model to generate a response
with llm_model:
full_output = llm_model.generate([prompt_request.prompt])[0].output_text.decode('utf-8')
full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8')

# Separate the prompt and response
split_output = full_output.split('\n', 1)
if len(split_output) > 1:
response_text = split_output[1]
else:
response_text = ""

# Return the prompt and the response in JSON format
return {
"prompt": prompt_request.prompt,
"response": response_text
}
# Shutdown event to stop the model server
@app.on_event("shutdown")
async def shutdown_event():
global llm
if llm is not None:
llm.stop_server()

# Main function to run Uvicorn server
if __name__ == "__main__":
16 changes: 9 additions & 7 deletions inference/python/spec_infer.py
@@ -143,13 +143,15 @@ def main():

llm.start_server()

if len(configs.prompt) > 0:
prompts = [s for s in json.load(open(configs.prompt))]
results = llm.generate(prompts)
else:
result = llm.generate("Three tips for staying healthy are: ")

llm.stop_server()
# if len(configs.prompt) > 0:
# prompts = [s for s in json.load(open(configs.prompt))]
# results = llm.generate(prompts)
# else:
# result = llm.generate("Three tips for staying healthy are: ")

result = llm.generate("Three tips for staying healthy are: ")



if __name__ == "__main__":
print("flexflow inference example (speculative inference)")
7 changes: 4 additions & 3 deletions inference/python/usecases/gradio_incr.py
@@ -152,9 +152,10 @@ def main():
# )

# interface version 2
with llm:
iface = gr.ChatInterface(fn=generate_response)
iface.launch()
iface = gr.ChatInterface(fn=generate_response)
llm.start_server()
iface.launch()
llm.stop_server()

if __name__ == "__main__":
print("flexflow inference example with gradio interface")
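The gr.ChatInterface(fn=generate_response) call above relies on a callback defined earlier in gradio_incr.py that is outside this hunk. A rough sketch of such a callback, assuming Gradio's standard (message, history) signature and the generate() API shown elsewhere in this commit:

# Assumed shape of the callback wired into gr.ChatInterface above; the real
# implementation lives earlier in gradio_incr.py.
def generate_response(message, history):
    # Gradio passes the latest user message plus the chat history; only the
    # message is forwarded to the FlexFlow LLM here.
    full_output = llm.generate([message])[0].output_text.decode("utf-8")
    return full_output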
19 changes: 6 additions & 13 deletions inference/python/usecases/gradio_specinfer.py
@@ -103,17 +103,9 @@ def get_configs():
"cache_path": "",
"refresh_cache": False,
"full_precision": False,
},
{
# required ssm parameter
"ssm_model": "facebook/opt-125m",
# optional ssm parameters
"cache_path": "",
"refresh_cache": False,
"full_precision": False,
},
}
],
# "prompt": "../prompt/test.json",
# "prompt": "",
"output_file": "",
}
# Merge dictionaries
@@ -203,9 +195,10 @@ def main():
# )

# interface version 2
with llm:
iface = gr.ChatInterface(fn=generate_response)
iface.launch()
iface = gr.ChatInterface(fn=generate_response)
llm.start_server()
iface.launch()
llm.stop_server()

if __name__ == "__main__":
print("flexflow inference example with gradio interface")
21 changes: 14 additions & 7 deletions inference/python/usecases/prompt_template_incr.py
@@ -116,16 +116,19 @@ def create_llm(self):
return llm

def compile_and_start(self, generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch):
self.llm.compile(generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch)
self.llm.start_server()

def generate(self, prompt):
user_input = prompt
results = self.llm.generate(user_input)
results = self.llm.generate(prompt)
if isinstance(results, list):
result_txt = results[0].output_text.decode('utf-8')
else:
result_txt = results.output_text.decode('utf-8')
return result_txt

def stop_server(self):
self.llm.stop_server()

def __enter__(self):
return self.llm.__enter__()
@@ -171,10 +174,14 @@ def _call(
# USE CASE 1: Prompt Template
template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])

# Build prompt template and langchain
prompt = PromptTemplate(template=template, input_variables=["question"])
llm_chain = LLMChain(prompt=prompt, llm=ff_llm_wrapper)
with ff_llm:
question = "Who was the US president in the year the first Pokemon game was released?"
print(llm_chain.run(question))

# stop the server
ff_llm.stop_server()
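The LLMChain above is driven by ff_llm_wrapper, a LangChain custom LLM whose _call method appears in the hunk header but not in this excerpt. The sketch below shows the assumed shape of such a wrapper, based on LangChain's custom LLM interface; the class name, field name, and body are assumptions, and the real implementation lives earlier in prompt_template_incr.py.

from typing import Any, List, Optional
from langchain.llms.base import LLM

class FF_LLM_wrapper(LLM):
    # The FlexFlow helper object whose generate() method is shown in the diff
    # above (compile_and_start / generate / stop_server).
    flexflow_llm: Any

    @property
    def _llm_type(self) -> str:
        return "flexflow"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        # Delegate generation to the FlexFlow helper; stop sequences are not
        # handled in this sketch.
        return self.flexflow_llm.generate(prompt)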
