From 9b9c09c52d8b3079f0c2ac2f199743783cadb98e Mon Sep 17 00:00:00 2001 From: Vikramjeet Date: Tue, 10 Dec 2024 20:09:47 +0530 Subject: [PATCH] Refactor API structure and enhance Mochi model conversion - Deleted the outdated client.py file to streamline the API. - Introduced serve.py, a new combined API router for multiple LitServe-based models, allowing clients to specify the model via a single endpoint. - Updated convert_mochi_to_diffusers.py with detailed docstrings and improved conversion functions for Mochi model checkpoints. - Enhanced mochi_diffusers.py to support additional GPU models and optimize memory management during video generation. - Removed the upscale-video.py script as it is no longer needed. These changes aim to improve the API's usability and maintainability while enhancing the Mochi model's integration with the Diffusers framework. --- api/client.py | 18 -- api/serve.py | 257 ++++++++++++++++++++++++++ scripts/convert_mochi_to_diffusers.py | 101 +++++++++- scripts/mochi_diffusers.py | 89 ++++++--- src/scripts/upscale-video.py | 3 - 5 files changed, 418 insertions(+), 50 deletions(-) delete mode 100644 api/client.py create mode 100644 api/serve.py delete mode 100644 src/scripts/upscale-video.py diff --git a/api/client.py b/api/client.py deleted file mode 100644 index 12b71e5..0000000 --- a/api/client.py +++ /dev/null @@ -1,18 +0,0 @@ - -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import requests - -response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0}) -print(f"Status: {response.status_code}\nResponse:\n {response.text}") diff --git a/api/serve.py b/api/serve.py new file mode 100644 index 0000000..cf6fb61 --- /dev/null +++ b/api/serve.py @@ -0,0 +1,257 @@ +""" +Combined API router for multiple LitServe-based models. + +This script imports multiple model-specific LitAPI classes (e.g., LTXVideoAPI +and MochiVideoAPI) and integrates them into a single endpoint. Clients specify +which model to invoke by providing a `model_name` field in the request body. + +Features: +- Single endpoint routing for multiple models +- Prometheus metrics for request duration tracking +- Comprehensive logging (stdout and file) with loguru +- Detailed docstrings and structured JSON responses +- Extensible: Just add new model APIs and register them in `model_apis`. + +Usage: +1. Ensure `ltx_serve.py` and `mochi_serve.py` are in the same directory. +2. Run `python combined_serve.py`. +3. Send POST requests to `http://localhost:8000/predict` with JSON like: + { + "model_name": "ltx", + "prompt": "Generate a video about a sunny day at the beach" + } + + or + + { + "model_name": "mochi", + "prompt": "Generate a video about a futuristic city" + } +""" + +import sys +import os +import time +from typing import Dict, Any, List, Union +from pydantic import BaseModel, Field +from loguru import logger + +import torch +import litserve as ls +from prometheus_client import ( + CollectorRegistry, + Histogram, + make_asgi_app, + multiprocess +) + +# Import the individual model APIs +from ltx_serve import LTXVideoAPI +from mochi_serve import MochiVideoAPI + +# Setup Prometheus multiprocess mode +os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc_dir" +if not os.path.exists("/tmp/prometheus_multiproc_dir"): + os.makedirs("/tmp/prometheus_multiproc_dir") + +registry = CollectorRegistry() +multiprocess.MultiProcessCollector(registry) + +class PrometheusLogger(ls.Logger): + """Custom logger for Prometheus metrics.""" + def __init__(self): + super().__init__() + self.function_duration = Histogram( + "combined_request_processing_seconds", + "Time spent processing combined API request", + ["function_name"], + registry=registry + ) + + def process(self, key: str, value: float) -> None: + """Record metric observations for function durations.""" + self.function_duration.labels(function_name=key).observe(value) + +class CombinedRequest(BaseModel): + """ + Pydantic model for incoming requests to the combined endpoint. + The `model_name` field is used to select which model to route to. + Other fields depend on the target model, so they are optional here. + """ + model_name: str = Field(..., description="Name of the model to use (e.g., 'ltx' or 'mochi').") + # Any additional fields will be passed through to the selected model's decode_request. + # We keep this flexible by using an extra allowed attributes pattern. + # For more strict validation, define fields matching each model's requirements. + class Config: + extra = "allow" + +class CombinedAPI(ls.LitAPI): + """ + A combined API class that delegates requests to multiple model-specific APIs + based on the `model_name` field in the request. + + This approach allows adding new models by: + 1. Importing their API class. + 2. Initializing and registering them in `model_apis` dictionary. + """ + def setup(self, device: str) -> None: + """Setup all sub-model APIs and logging/metrics.""" + + logger.info(f"Initializing combined API with device={device}") + + # Initialize sub-model APIs + self.ltx_api = LTXVideoAPI() + self.mochi_api = MochiVideoAPI() + + # Setup each sub-model on the provided device + self.ltx_api.setup(device=device) + self.mochi_api.setup(device=device) + + # Register them in a dictionary for easy routing + self.model_apis = { + "ltx": self.ltx_api, + "mochi": self.mochi_api + } + + logger.info("Combined API setup completed successfully.") + + def decode_request( + self, + request: Union[Dict[str, Any], List[Dict[str, Any]]] + ) -> Dict[str, Any]: + """ + Decode the incoming request to determine which model to use. + We expect `model_name` to route the request accordingly. + The rest of the fields will be passed to the chosen model's decode_request. + """ + if isinstance(request, list): + # We handle only single requests for simplicity + request = request[0] + + validated = CombinedRequest(**request).dict() + model_name = validated.pop("model_name").lower() + + if model_name not in self.model_apis: + raise ValueError(f"Unknown model_name '{model_name}'. Available: {list(self.model_apis.keys())}") + + # We'll store the selected model_name and request data + return { + "model_name": model_name, + "request_data": validated + } + + def predict(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform prediction by routing to the chosen model API. + + Steps: + 1. Extract model_name and request_data. + 2. Pass request_data to the chosen model's decode_request -> predict pipeline. + 3. Return the predictions from the model. + """ + model_name = inputs["model_name"] + request_data = inputs["request_data"] + model_api = self.model_apis[model_name] + + start_time = time.time() + + try: + # The sub-model APIs typically handle lists of requests. + # We'll wrap request_data in a list if needed. + decoded = model_api.decode_request(request_data) + # decoded is typically a list of requests for that model + predictions = model_api.predict(decoded) + # predictions is typically a list of results + result = predictions[0] if predictions else {"status": "error", "error": "No result returned"} + + end_time = time.time() + self.log("combined_inference_time", end_time - start_time) + + return { + "model_name": model_name, + "result": result + } + + except Exception as e: + import traceback + logger.error(f"Error in combined predict: {e}\n{traceback.format_exc()}") + return { + "model_name": model_name, + "status": "error", + "error": str(e), + "traceback": traceback.format_exc() + } + finally: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def encode_response(self, output: Dict[str, Any]) -> Dict[str, Any]: + """ + Encode the final response. We call the chosen model's encode_response if the result + is from a model inference. If there's an error at the combined level, we return a generic error response. + """ + model_name = output.get("model_name") + if model_name and model_name in self.model_apis: + # If there's a result from the model, encode it using the model's encoder + result = output.get("result", {}) + if result.get("status") == "error": + # Model-specific error case + return { + "status": "error", + "error": result.get("error", "Unknown error"), + "traceback": result.get("traceback", None) + } + # Successful result + encoded = self.model_apis[model_name].encode_response(result) + # Add the model name to the final response for clarity + encoded["model_name"] = model_name + return encoded + else: + # If we got here, there's a top-level routing error + return { + "status": "error", + "error": output.get("error", "Unknown top-level error"), + "traceback": output.get("traceback", None) + } + + +def main(): + """Main entry point to run the combined server.""" + # Set up Prometheus logger + prometheus_logger = PrometheusLogger() + prometheus_logger.mount( + path="/api/v1/metrics", + app=make_asgi_app(registry=registry) + ) + + # Configure logging + logger.remove() + logger.add( + sys.stdout, + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function} - {message}", + level="INFO" + ) + logger.add( + "logs/combined_api.log", + rotation="100 MB", + retention="1 week", + level="DEBUG" + ) + + logger.info("Starting Combined Video Generation Server on port 8000") + + # Initialize and run the combined server + api = CombinedAPI() + server = ls.LitServer( + api, + api_path="/predict", # A single endpoint for all models + accelerator="auto", + devices="auto", + max_batch_size=1, + track_requests=True, + loggers=prometheus_logger + ) + server.run(port=8000) + +if __name__ == "__main__": + main() diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py index f75791f..9399cab 100644 --- a/scripts/convert_mochi_to_diffusers.py +++ b/scripts/convert_mochi_to_diffusers.py @@ -1,3 +1,23 @@ +""" +Mochi Model Checkpoint Converter + +This script converts Mochi model checkpoints from their original format to the Diffusers format. +It handles three main components: +1. The transformer model +2. The VAE (encoder and decoder) +3. The text encoder (T5) + +The script provides utility functions to: +- Convert state dict keys between formats +- Handle weight transformations and reshaping +- Create and save a complete Diffusers pipeline + +Usage: + python convert_mochi_to_diffusers.py + +Configuration is handled via the MochiWeightsSettings class (see configs/mochi_weights.py) +""" + from contextlib import nullcontext import torch from accelerate import init_empty_weights @@ -10,24 +30,55 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext def swap_scale_shift(weight, dim): + """ + Swaps the order of scale and shift parameters in normalization layers. + + Args: + weight (torch.Tensor): Input tensor containing scale and shift parameters + dim (int): Dimension along which to split the parameters + + Returns: + torch.Tensor: Reordered tensor with scale parameters first, then shift parameters + """ shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) return new_weight def swap_proj_gate(weight): + """ + Swaps projection and gate weights in attention mechanisms. + + Args: + weight (torch.Tensor): Input tensor containing projection and gate weights + + Returns: + torch.Tensor: Reordered tensor with gate weights first, then projection weights + """ proj, gate = weight.chunk(2, dim=0) new_weight = torch.cat([gate, proj], dim=0) return new_weight def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path): + """ + Converts a Mochi transformer checkpoint to the Diffusers format. + + This function handles the conversion of: + - Embedding layers + - Transformer blocks + - Attention mechanisms + - Normalization layers + - Output projections + + Args: + ckpt_path (str): Path to the original Mochi checkpoint file + + Returns: + dict: State dictionary in Diffusers format + """ original_state_dict = load_file(ckpt_path, device="cpu") new_state_dict = {} - - # Convert patch_embed new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight") new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias") - - # Convert time_embed new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight") new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias") new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight") @@ -122,9 +173,24 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path): return new_state_dict - - def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_path): + """ + Converts Mochi VAE encoder and decoder checkpoints to the Diffusers format. + + This function handles: + - Input/output projections + - ResNet blocks + - Up/down sampling blocks + - Attention layers + - Normalization layers + + Args: + encoder_ckpt_path (str): Path to the VAE encoder checkpoint + decoder_ckpt_path (str): Path to the VAE decoder checkpoint + + Returns: + dict: Combined state dictionary for both encoder and decoder in Diffusers format + """ encoder_state_dict = load_file(encoder_ckpt_path, device="cpu") decoder_state_dict = load_file(decoder_ckpt_path, device="cpu") new_state_dict = {} @@ -362,6 +428,29 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa return new_state_dict def main(settings: MochiWeightsSettings = MochiWeightsSettings()): + """ + Main execution function that orchestrates the conversion process. + + This function: + 1. Configures the computation dtype (fp16, bf16, or fp32) + 2. Converts the transformer model if a checkpoint is provided + 3. Converts the VAE if encoder/decoder checkpoints are provided + 4. Loads the T5 text encoder and tokenizer + 5. Creates and saves a complete Diffusers pipeline + + Args: + settings (MochiWeightsSettings): Configuration object containing paths and options + - dtype: Computation precision ("fp16", "bf16", "fp32", or None) + - transformer_checkpoint_path: Path to transformer checkpoint + - vae_encoder_checkpoint_path: Path to VAE encoder checkpoint + - vae_decoder_checkpoint_path: Path to VAE decoder checkpoint + - text_encoder_cache_dir: Cache directory for T5 model + - output_path: Where to save the converted pipeline + - push_to_hub: Whether to push the converted model to HuggingFace Hub + + Raises: + ValueError: If an unsupported dtype is specified + """ if settings.dtype is None: dtype = None elif settings.dtype == "fp16": diff --git a/scripts/mochi_diffusers.py b/scripts/mochi_diffusers.py index d3bc232..a8d406e 100644 --- a/scripts/mochi_diffusers.py +++ b/scripts/mochi_diffusers.py @@ -1,6 +1,6 @@ """ Inference class for Mochi video generation model. -Tested on A6000 +Tested on A6000 /A40 / A100 Generates Video for 20 inference steps at 480 resolution in 10 minutes """ @@ -8,10 +8,11 @@ from typing import Optional, Union, List, Tuple import torch from loguru import logger -from diffusers import MochiPipeline, MochiTransformer3DModel +from diffusers import MochiPipeline, MochiTransformer3DModel, AutoencoderKLMochi from diffusers.utils import export_to_video - +from diffusers.video_processor import VideoProcessor from configs.mochi_settings import MochiSettings +import gc class MochiInference: """ @@ -112,39 +113,85 @@ def generate( Raises: RuntimeError: If video generation fails. """ + logger.info("Starting video generation for prompt: {}", prompt) try: # Set random seed if provided if seed is not None: logger.info("Setting random seed to {}", seed) torch.manual_seed(seed) - # Build generation parameters - params = { - "prompt": prompt, - "negative_prompt": negative_prompt, - "num_inference_steps": num_inference_steps or self.settings.num_inference_steps, - "guidance_scale": guidance_scale or self.settings.guidance_scale, - "height": height or self.settings.height, - "width": width or self.settings.width, - "num_frames": num_frames or self.settings.num_frames, - } + # Encode prompt first and free text encoder memory + with torch.no_grad(): + logger.debug("Encoding prompt") + prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( + self.pipe.encode_prompt(prompt=prompt) + ) - logger.info("Generating video with prompt: {}", prompt) - logger.debug("Generation parameters: {}", params) + logger.debug("Freeing text encoder memory") + del self.pipe.text_encoder + gc.collect() - frames = self.pipe(**params).frames[0] + # Generate latents + logger.debug("Generating latents") + with torch.autocast("cuda", dtype=torch.bfloat16): + frames = self.pipe( + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_attention_mask=negative_prompt_attention_mask, + guidance_scale=guidance_scale or self.settings.guidance_scale, + num_inference_steps=num_inference_steps or self.settings.num_inference_steps, + height=height or self.settings.height, + width=width or self.settings.width, + num_frames=num_frames or self.settings.num_frames, + output_type="latent", + return_dict=False, + )[0] + + logger.debug("Freeing transformer memory") + del self.pipe.transformer + gc.collect() + + # Setup VAE and process latents + logger.debug("Loading VAE model") + vae = AutoencoderKLMochi.from_pretrained( + self.settings.pipeline_path, + subfolder="vae" + ).to(self.settings.device) + vae._enable_framewise_decoding() + # Scale latents appropriately + logger.debug("Scaling latents") + if hasattr(vae.config, "latents_mean") and hasattr(vae.config, "latents_std"): + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype) + latents_std = torch.tensor(vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype) + frames = frames * latents_std / vae.config.scaling_factor + latents_mean + else: + frames = frames / vae.config.scaling_factor + + # Decode frames + logger.debug("Decoding frames") + with torch.no_grad(): + video = vae.decode(frames.to(vae.dtype), return_dict=False)[0] + + logger.debug("Post-processing video") + video_processor = VideoProcessor(vae_scale_factor=8) + video = video_processor.postprocess_video(video)[0] + + # Save or return video if output_path: + logger.info("Saving video to {}", output_path) output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) fps = fps or self.settings.fps - export_to_video(frames, str(output_path), fps=fps) + export_to_video(video, str(output_path), fps=fps) logger.success("Video saved to: {}", output_path) return str(output_path) - return export_to_video(frames[0]) - + logger.success("Video generation completed successfully") + return video + except Exception as e: logger.exception("Video generation failed") raise RuntimeError(f"Video generation failed: {str(e)}") from e @@ -200,10 +247,6 @@ def clear_memory(self) -> None: print(f"Video saved to: {video_path}") except RuntimeError as e: print(f"Failed to generate video: {e}") - - # Display GPU memory usage for debugging allocated, max_allocated = mochi_inference.get_memory_usage() print(f"Memory usage: {allocated:.2f}GB (peak: {max_allocated:.2f}GB)") - - # Clear memory cache after inference mochi_inference.clear_memory() diff --git a/src/scripts/upscale-video.py b/src/scripts/upscale-video.py deleted file mode 100644 index 7f9120e..0000000 --- a/src/scripts/upscale-video.py +++ /dev/null @@ -1,3 +0,0 @@ -from aura_sr import AuraSR -aura_sr = AuraSR.from_pretrained('fal-ai/AuraSR') -