
Commit

only stop server if rm is initialized
goliaro committed Jan 14, 2024
1 parent 756128f commit 621c20c
Showing 1 changed file with 17 additions and 12 deletions.
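Before this change, the request manager handle self.rm was only created later (during compile(), as the hunks below show), so garbage-collecting an LLM that had never been compiled made __del__ fail when it tried to stop a server that did not exist. The commit initializes self.rm to None in __init__ and guards the destructor. A minimal, self-contained sketch of the pattern with stand-in classes (not FlexFlow's real ones):

# Illustrative sketch of the guard added by commit 621c20c; _RequestManager
# is a stand-in for FlexFlow's real request manager.
class _RequestManager:
    def stop_server(self):
        print("Background server stopped.")

class LLM:
    def __init__(self):
        self.rm = None  # not created until compile() runs

    def compile(self):
        self.rm = _RequestManager()

    def __del__(self):
        # Without the None check, collecting an LLM that was never compiled
        # would try to stop a request manager that does not exist.
        if type(self) == LLM and self.rm is not None:
            self.rm.stop_server()

uncompiled = LLM()
del uncompiled       # no-op: rm was never initialized
llm = LLM()
llm.compile()
del llm              # prints "Background server stopped."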
python/flexflow/serve/serve.py (29 changes: 17 additions & 12 deletions)
@@ -117,10 +117,11 @@ def __init__(
self.cache_path = cache_path if len(cache_path) > 0 else "~/.cache/flexflow"
self.refresh_cache = refresh_cache
self.output_file = output_file
self.rm = None

def __del__(self):
# Stop the background server before deleting the object
if type(self) == LLM:
if type(self) == LLM and self.rm is not None:
self.rm.stop_server()

def __get_ff_model_type(self):
@@ -320,9 +321,9 @@ def compile(
:param ssms: The SSMs to use when operating in speculative inference mode, defaults to []
:type ssms: list, optional
"""
#self.max_requests_per_batch = max_requests_per_batch
#self.max_seq_length = max_seq_length
#self.max_tokens_per_batch = max_tokens_per_batch
# self.max_requests_per_batch = max_requests_per_batch
# self.max_seq_length = max_seq_length
# self.max_tokens_per_batch = max_tokens_per_batch
self.ssms = ssms
self.generation_config = GenerationConfig()
self.ffconfig = FFConfig()
@@ -362,7 +363,7 @@ def compile(
self.ffconfig,
self.hf_config,
self.data_type,
max_tokens_per_batch
max_tokens_per_batch,
)

# Download the weights from huggingface (if needed)
@@ -378,7 +379,7 @@
model_configs.hidden_size,
model_configs.hidden_size // model_configs.num_attention_heads,
self.ffconfig.tensor_parallelism_degree,
self.data_type == DataType.DT_FLOAT
self.data_type == DataType.DT_FLOAT,
)

# Register weights file loader
@@ -404,8 +405,11 @@
self.rm.register_ssm_model(ssm.model.ffmodel)

# start background server
if (mode == InferenceMode.TREE_VERIFY_MODE) or (mode == InferenceMode.INC_DECODING_MODE):
if (mode == InferenceMode.TREE_VERIFY_MODE) or (
mode == InferenceMode.INC_DECODING_MODE
):
import atexit

atexit.register(self.rm.stop_server)

def generate(self, prompts: Union[str, List[str]], max_length: int = 128):
@@ -426,26 +430,27 @@ def generate(self, prompts: Union[str, List[str]], max_length: int = 128):
return self.model.ffmodel.generate(prompts, max_length)
else:
assert False, "Please pass a non-empty string or list of strings"

def start_server(self):
self.rm.start_server(self.model.ffmodel)
print("Background server started.")

def stop_server(self):
self.rm.stop_server()
print("Background server stoped.")

def __enter__(self):
# Start the server when entering the context
#self.rm.start_server(self.model.ffmodel)
# self.rm.start_server(self.model.ffmodel)
return self

def __exit__(self, exc_type, exc_value, traceback):
# Stop the server when exiting the context
#self.rm.stop_server()
# self.rm.stop_server()
if exc_type:
print(f"Exception occurred: {exc_value}")


class SSM(LLM):
"""This class creates a SSM (Small-Speculative Model) object based on a model from HuggingFace"""

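For context, a hedged sketch of how the server lifecycle touched by this diff might be driven from user code; the import path, model identifier, and any additional initialization FlexFlow requires (e.g. flexflow.serve.init(...)) are assumptions, not taken from this commit:

# Hypothetical driver code; the model id and the constructor/compile()
# arguments are assumptions for illustration only.
from flexflow.serve import LLM

llm = LLM("meta-llama/Llama-2-7b-hf")   # hypothetical model identifier
llm.compile()                           # creates self.rm and registers the atexit cleanup
llm.start_server()                      # prints "Background server started."
output = llm.generate("A story begins", max_length=128)
llm.stop_server()                       # safe: rm is initialized at this point

Note that generate() accepts either a single string or a list of strings, per the signature shown in the hunk above.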
