Refactor API structure and enhance Mochi model conversion
- Deleted the outdated client.py file to streamline the API.
- Introduced serve.py, a new combined API router for multiple LitServe-based models, allowing clients to specify the model via a single endpoint.
- Updated convert_mochi_to_diffusers.py with detailed docstrings and improved conversion functions for Mochi model checkpoints.
- Enhanced mochi_diffusers.py to support additional GPU models and optimize memory management during video generation.
- Removed the upscale-video.py script as it is no longer needed.

These changes aim to improve the API's usability and maintainability while enhancing the Mochi model's integration with the Diffusers framework.
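
For illustration, a minimal client call against the new combined endpoint might look like the sketch below (assuming the defaults in serve.py: port 8000 and the `/predict` route; the prompt text is only an example):

import requests

# Select the Mochi model by name and send a prompt to the combined endpoint
response = requests.post(
    "http://localhost:8000/predict",
    json={
        "model_name": "mochi",
        "prompt": "Generate a video about a futuristic city",
    },
)
print(response.json())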
VikramxD committed Dec 10, 2024
1 parent 67911a1 commit 9b9c09c
Showing 5 changed files with 418 additions and 50 deletions.
18 changes: 0 additions & 18 deletions api/client.py

This file was deleted.

257 changes: 257 additions & 0 deletions api/serve.py
@@ -0,0 +1,257 @@
"""
Combined API router for multiple LitServe-based models.
This script imports multiple model-specific LitAPI classes (e.g., LTXVideoAPI
and MochiVideoAPI) and integrates them into a single endpoint. Clients specify
which model to invoke by providing a `model_name` field in the request body.
Features:
- Single endpoint routing for multiple models
- Prometheus metrics for request duration tracking
- Comprehensive logging (stdout and file) with loguru
- Detailed docstrings and structured JSON responses
- Extensible: Just add new model APIs and register them in `model_apis`.
Usage:
1. Ensure `ltx_serve.py` and `mochi_serve.py` are in the same directory.
2. Run `python combined_serve.py`.
3. Send POST requests to `http://localhost:8000/predict` with JSON like:
{
"model_name": "ltx",
"prompt": "Generate a video about a sunny day at the beach"
}
or
{
"model_name": "mochi",
"prompt": "Generate a video about a futuristic city"
}
"""

import sys
import os
import time
from typing import Dict, Any, List, Union
from pydantic import BaseModel, Field
from loguru import logger

import torch
import litserve as ls
from prometheus_client import (
    CollectorRegistry,
    Histogram,
    make_asgi_app,
    multiprocess,
)

# Import the individual model APIs
from ltx_serve import LTXVideoAPI
from mochi_serve import MochiVideoAPI

# Setup Prometheus multiprocess mode
os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc_dir"
if not os.path.exists("/tmp/prometheus_multiproc_dir"):
    os.makedirs("/tmp/prometheus_multiproc_dir")

registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)

class PrometheusLogger(ls.Logger):
    """Custom logger for Prometheus metrics."""

    def __init__(self):
        super().__init__()
        self.function_duration = Histogram(
            "combined_request_processing_seconds",
            "Time spent processing combined API request",
            ["function_name"],
            registry=registry,
        )

    def process(self, key: str, value: float) -> None:
        """Record metric observations for function durations."""
        self.function_duration.labels(function_name=key).observe(value)

class CombinedRequest(BaseModel):
    """
    Pydantic model for incoming requests to the combined endpoint.

    The `model_name` field is used to select which model to route to.
    Other fields depend on the target model, so they are optional here.
    """

    model_name: str = Field(..., description="Name of the model to use (e.g., 'ltx' or 'mochi').")
    # Any additional fields will be passed through to the selected model's decode_request.
    # We keep this flexible by allowing extra attributes.
    # For stricter validation, define fields matching each model's requirements.

    class Config:
        extra = "allow"
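
# Illustrative behaviour of the schema above (the prompt value is hypothetical):
#   CombinedRequest(model_name="ltx", prompt="a beach at sunset").dict()
#   -> {"model_name": "ltx", "prompt": "a beach at sunset"}
# Because of `extra = "allow"`, the extra "prompt" field is retained and later
# forwarded to the selected model's decode_request.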

class CombinedAPI(ls.LitAPI):
    """
    A combined API class that delegates requests to multiple model-specific APIs
    based on the `model_name` field in the request.

    Adding a new model requires only:
    1. Importing its API class.
    2. Initializing and registering it in the `model_apis` dictionary.
    """

    def setup(self, device: str) -> None:
        """Set up all sub-model APIs, logging, and metrics."""
        logger.info(f"Initializing combined API with device={device}")

        # Initialize sub-model APIs
        self.ltx_api = LTXVideoAPI()
        self.mochi_api = MochiVideoAPI()

        # Set up each sub-model on the provided device
        self.ltx_api.setup(device=device)
        self.mochi_api.setup(device=device)

        # Register them in a dictionary for easy routing
        self.model_apis = {
            "ltx": self.ltx_api,
            "mochi": self.mochi_api,
        }

        logger.info("Combined API setup completed successfully.")

    def decode_request(
        self,
        request: Union[Dict[str, Any], List[Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """
        Decode the incoming request to determine which model to use.

        We expect `model_name` to route the request accordingly; the rest of
        the fields are passed to the chosen model's decode_request.
        """
        if isinstance(request, list):
            # Only single requests are handled, for simplicity
            request = request[0]

        validated = CombinedRequest(**request).dict()
        model_name = validated.pop("model_name").lower()

        if model_name not in self.model_apis:
            raise ValueError(
                f"Unknown model_name '{model_name}'. "
                f"Available: {list(self.model_apis.keys())}"
            )

        # Store the selected model name and the remaining request data
        return {
            "model_name": model_name,
            "request_data": validated,
        }

    def predict(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Perform prediction by routing to the chosen model API.

        Steps:
        1. Extract model_name and request_data.
        2. Pass request_data through the chosen model's decode_request -> predict pipeline.
        3. Return the predictions from the model.
        """
        model_name = inputs["model_name"]
        request_data = inputs["request_data"]
        model_api = self.model_apis[model_name]

        start_time = time.time()

        try:
            # The sub-model APIs handle their own decoding, so pass the
            # remaining request data straight through their pipeline.
            decoded = model_api.decode_request(request_data)
            # `decoded` is typically a list of requests for that model
            predictions = model_api.predict(decoded)
            # `predictions` is typically a list of results
            result = predictions[0] if predictions else {
                "status": "error",
                "error": "No result returned",
            }

            end_time = time.time()
            self.log("combined_inference_time", end_time - start_time)

            return {
                "model_name": model_name,
                "result": result,
            }

        except Exception as e:
            import traceback
            logger.error(f"Error in combined predict: {e}\n{traceback.format_exc()}")
            return {
                "model_name": model_name,
                "status": "error",
                "error": str(e),
                "traceback": traceback.format_exc(),
            }
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def encode_response(self, output: Dict[str, Any]) -> Dict[str, Any]:
        """
        Encode the final response.

        If the result came from a model inference, the chosen model's
        encode_response is used. If an error occurred at the combined level,
        a generic error response is returned instead.
        """
        model_name = output.get("model_name")
        if model_name and model_name in self.model_apis:
            # If there's a result from the model, encode it using the model's encoder
            result = output.get("result", {})
            if result.get("status") == "error":
                # Model-specific error case
                return {
                    "status": "error",
                    "error": result.get("error", "Unknown error"),
                    "traceback": result.get("traceback", None),
                }
            # Successful result
            encoded = self.model_apis[model_name].encode_response(result)
            # Add the model name to the final response for clarity
            encoded["model_name"] = model_name
            return encoded
        else:
            # If we got here, there's a top-level routing error
            return {
                "status": "error",
                "error": output.get("error", "Unknown top-level error"),
                "traceback": output.get("traceback", None),
            }
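
    # Illustrative shapes of the encoded response (values are hypothetical):
    #   success -> whatever the sub-model's encode_response returns, plus
    #              {"model_name": "mochi"}
    #   failure -> {"status": "error", "error": "...", "traceback": "..."}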


def main():
    """Main entry point to run the combined server."""
    # Set up Prometheus logger and expose the metrics endpoint
    prometheus_logger = PrometheusLogger()
    prometheus_logger.mount(
        path="/api/v1/metrics",
        app=make_asgi_app(registry=registry),
    )

    # Configure logging
    logger.remove()
    logger.add(
        sys.stdout,
        format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan> - <level>{message}</level>",
        level="INFO",
    )
    logger.add(
        "logs/combined_api.log",
        rotation="100 MB",
        retention="1 week",
        level="DEBUG",
    )

    logger.info("Starting Combined Video Generation Server on port 8000")

    # Initialize and run the combined server
    api = CombinedAPI()
    server = ls.LitServer(
        api,
        api_path="/predict",  # A single endpoint for all models
        accelerator="auto",
        devices="auto",
        max_batch_size=1,
        track_requests=True,
        loggers=prometheus_logger,
    )
    server.run(port=8000)


if __name__ == "__main__":
    main()
