diff --git a/api/client.py b/api/client.py
deleted file mode 100644
index 12b71e5..0000000
--- a/api/client.py
+++ /dev/null
@@ -1,18 +0,0 @@
-
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import requests
-
-response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0})
-print(f"Status: {response.status_code}\nResponse:\n {response.text}")
diff --git a/api/serve.py b/api/serve.py
new file mode 100644
index 0000000..cf6fb61
--- /dev/null
+++ b/api/serve.py
@@ -0,0 +1,257 @@
+"""
+Combined API router for multiple LitServe-based models.
+
+This script imports multiple model-specific LitAPI classes (e.g., LTXVideoAPI
+and MochiVideoAPI) and integrates them into a single endpoint. Clients specify
+which model to invoke by providing a `model_name` field in the request body.
+
+Features:
+- Single endpoint routing for multiple models
+- Prometheus metrics for request duration tracking
+- Comprehensive logging (stdout and file) with loguru
+- Detailed docstrings and structured JSON responses
+- Extensible: Just add new model APIs and register them in `model_apis`.
+
+Usage:
+1. Ensure `ltx_serve.py` and `mochi_serve.py` are in the same directory.
+2. Run `python serve.py`.
+3. Send POST requests to `http://localhost:8000/predict` with JSON like:
+ {
+ "model_name": "ltx",
+ "prompt": "Generate a video about a sunny day at the beach"
+ }
+
+ or
+
+ {
+ "model_name": "mochi",
+ "prompt": "Generate a video about a futuristic city"
+ }
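+
+4. Example client call (a minimal sketch using `requests`, mirroring the
+   payloads above; adjust host/port if you change the server settings):
+
+   import requests
+
+   response = requests.post(
+       "http://localhost:8000/predict",
+       json={"model_name": "ltx", "prompt": "A sunny day at the beach"},
+   )
+   print(response.status_code, response.json())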
+"""
+
+import sys
+import os
+import time
+from typing import Dict, Any, List, Union
+from pydantic import BaseModel, Field
+from loguru import logger
+
+import torch
+import litserve as ls
+from prometheus_client import (
+ CollectorRegistry,
+ Histogram,
+ make_asgi_app,
+ multiprocess
+)
+
+# Import the individual model APIs
+from ltx_serve import LTXVideoAPI
+from mochi_serve import MochiVideoAPI
+
+# Setup Prometheus multiprocess mode
+os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc_dir"
+os.makedirs("/tmp/prometheus_multiproc_dir", exist_ok=True)
+
+registry = CollectorRegistry()
+multiprocess.MultiProcessCollector(registry)
+
+class PrometheusLogger(ls.Logger):
+ """Custom logger for Prometheus metrics."""
+ def __init__(self):
+ super().__init__()
+ self.function_duration = Histogram(
+ "combined_request_processing_seconds",
+ "Time spent processing combined API request",
+ ["function_name"],
+ registry=registry
+ )
+
+ def process(self, key: str, value: float) -> None:
+ """Record metric observations for function durations."""
+ self.function_duration.labels(function_name=key).observe(value)
+
+class CombinedRequest(BaseModel):
+ """
+ Pydantic model for incoming requests to the combined endpoint.
+ The `model_name` field is used to select which model to route to.
+ Other fields depend on the target model, so they are optional here.
+ """
+ model_name: str = Field(..., description="Name of the model to use (e.g., 'ltx' or 'mochi').")
+    # Any additional fields are passed through to the selected model's decode_request.
+    # We keep this flexible by allowing extra attributes (Config.extra = "allow").
+    # For stricter validation, define fields matching each model's requirements,
+    # as sketched in the comment block after this class.
+ class Config:
+ extra = "allow"
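+
+# A minimal sketch of a stricter, model-specific request schema (the field
+# names below are hypothetical; match them to the real schemas expected by
+# ltx_serve.py / mochi_serve.py):
+#
+# class LTXRequest(CombinedRequest):
+#     prompt: str
+#     negative_prompt: str = ""
+#     num_inference_steps: int = 50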
+
+class CombinedAPI(ls.LitAPI):
+ """
+ A combined API class that delegates requests to multiple model-specific APIs
+ based on the `model_name` field in the request.
+
+ This approach allows adding new models by:
+ 1. Importing their API class.
+ 2. Initializing and registering them in `model_apis` dictionary.
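+
+    For example, wiring in a hypothetical third model (the name and module are
+    placeholders) means importing its API at module level and registering it
+    inside `setup`:
+
+        from hunyuan_serve import HunyuanVideoAPI
+
+        self.hunyuan_api = HunyuanVideoAPI()
+        self.hunyuan_api.setup(device=device)
+        self.model_apis["hunyuan"] = self.hunyuan_api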
+ """
+ def setup(self, device: str) -> None:
+ """Setup all sub-model APIs and logging/metrics."""
+
+ logger.info(f"Initializing combined API with device={device}")
+
+ # Initialize sub-model APIs
+ self.ltx_api = LTXVideoAPI()
+ self.mochi_api = MochiVideoAPI()
+
+ # Setup each sub-model on the provided device
+ self.ltx_api.setup(device=device)
+ self.mochi_api.setup(device=device)
+
+ # Register them in a dictionary for easy routing
+ self.model_apis = {
+ "ltx": self.ltx_api,
+ "mochi": self.mochi_api
+ }
+
+ logger.info("Combined API setup completed successfully.")
+
+ def decode_request(
+ self,
+ request: Union[Dict[str, Any], List[Dict[str, Any]]]
+ ) -> Dict[str, Any]:
+ """
+        Decode the incoming request and determine which model to route to.
+        The `model_name` field selects the target model; the remaining fields
+        are passed to that model's decode_request.
+ """
+ if isinstance(request, list):
+ # We handle only single requests for simplicity
+ request = request[0]
+
+ validated = CombinedRequest(**request).dict()
+ model_name = validated.pop("model_name").lower()
+
+ if model_name not in self.model_apis:
+ raise ValueError(f"Unknown model_name '{model_name}'. Available: {list(self.model_apis.keys())}")
+
+ # We'll store the selected model_name and request data
+ return {
+ "model_name": model_name,
+ "request_data": validated
+ }
+
+ def predict(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Perform prediction by routing to the chosen model API.
+
+ Steps:
+ 1. Extract model_name and request_data.
+ 2. Pass request_data to the chosen model's decode_request -> predict pipeline.
+ 3. Return the predictions from the model.
+ """
+ model_name = inputs["model_name"]
+ request_data = inputs["request_data"]
+ model_api = self.model_apis[model_name]
+
+ start_time = time.time()
+
+ try:
+                # Each sub-model API validates and normalizes its own payload;
+                # pass request_data through its decode_request -> predict pipeline.
+ decoded = model_api.decode_request(request_data)
+ # decoded is typically a list of requests for that model
+ predictions = model_api.predict(decoded)
+ # predictions is typically a list of results
+ result = predictions[0] if predictions else {"status": "error", "error": "No result returned"}
+
+ end_time = time.time()
+ self.log("combined_inference_time", end_time - start_time)
+
+ return {
+ "model_name": model_name,
+ "result": result
+ }
+
+ except Exception as e:
+ import traceback
+ logger.error(f"Error in combined predict: {e}\n{traceback.format_exc()}")
+ return {
+ "model_name": model_name,
+ "status": "error",
+ "error": str(e),
+ "traceback": traceback.format_exc()
+ }
+ finally:
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ def encode_response(self, output: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Encode the final response. We call the chosen model's encode_response if the result
+ is from a model inference. If there's an error at the combined level, we return a generic error response.
+ """
+ model_name = output.get("model_name")
+ if model_name and model_name in self.model_apis:
+ # If there's a result from the model, encode it using the model's encoder
+ result = output.get("result", {})
+ if result.get("status") == "error":
+ # Model-specific error case
+ return {
+ "status": "error",
+ "error": result.get("error", "Unknown error"),
+ "traceback": result.get("traceback", None)
+ }
+ # Successful result
+ encoded = self.model_apis[model_name].encode_response(result)
+ # Add the model name to the final response for clarity
+ encoded["model_name"] = model_name
+ return encoded
+ else:
+ # If we got here, there's a top-level routing error
+ return {
+ "status": "error",
+ "error": output.get("error", "Unknown top-level error"),
+ "traceback": output.get("traceback", None)
+ }
+
+
+def main():
+ """Main entry point to run the combined server."""
+ # Set up Prometheus logger
+ prometheus_logger = PrometheusLogger()
+ prometheus_logger.mount(
+ path="/api/v1/metrics",
+ app=make_asgi_app(registry=registry)
+ )
+
+ # Configure logging
+ logger.remove()
+ logger.add(
+ sys.stdout,
+ format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function} - {message}",
+ level="INFO"
+ )
+ logger.add(
+ "logs/combined_api.log",
+ rotation="100 MB",
+ retention="1 week",
+ level="DEBUG"
+ )
+
+ logger.info("Starting Combined Video Generation Server on port 8000")
+
+ # Initialize and run the combined server
+ api = CombinedAPI()
+ server = ls.LitServer(
+ api,
+ api_path="/predict", # A single endpoint for all models
+ accelerator="auto",
+ devices="auto",
+ max_batch_size=1,
+ track_requests=True,
+ loggers=prometheus_logger
+ )
+ server.run(port=8000)
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py
index f75791f..9399cab 100644
--- a/scripts/convert_mochi_to_diffusers.py
+++ b/scripts/convert_mochi_to_diffusers.py
@@ -1,3 +1,23 @@
+"""
+Mochi Model Checkpoint Converter
+
+This script converts Mochi model checkpoints from their original format to the Diffusers format.
+It handles three main components:
+1. The transformer model
+2. The VAE (encoder and decoder)
+3. The text encoder (T5)
+
+The script provides utility functions to:
+- Convert state dict keys between formats
+- Handle weight transformations and reshaping
+- Create and save a complete Diffusers pipeline
+
+Usage:
+ python convert_mochi_to_diffusers.py
+
+Configuration is handled via the MochiWeightsSettings class (see configs/mochi_weights.py).
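+
+Example (a sketch; the field names come from MochiWeightsSettings and the
+paths are placeholders for your local checkpoint files):
+
+    from configs.mochi_weights import MochiWeightsSettings
+
+    main(MochiWeightsSettings(
+        dtype="bf16",
+        transformer_checkpoint_path="weights/dit.safetensors",
+        vae_encoder_checkpoint_path="weights/encoder.safetensors",
+        vae_decoder_checkpoint_path="weights/decoder.safetensors",
+        output_path="mochi-diffusers",
+    ))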
+"""
+
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
@@ -10,24 +30,55 @@
CTX = init_empty_weights if is_accelerate_available else nullcontext
def swap_scale_shift(weight, dim):
+ """
+ Swaps the order of scale and shift parameters in normalization layers.
+
+ Args:
+ weight (torch.Tensor): Input tensor containing scale and shift parameters
+ dim (int): Dimension along which to split the parameters
+
+ Returns:
+ torch.Tensor: Reordered tensor with scale parameters first, then shift parameters
+ """
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def swap_proj_gate(weight):
+ """
+ Swaps projection and gate weights in attention mechanisms.
+
+ Args:
+ weight (torch.Tensor): Input tensor containing projection and gate weights
+
+ Returns:
+ torch.Tensor: Reordered tensor with gate weights first, then projection weights
+ """
proj, gate = weight.chunk(2, dim=0)
new_weight = torch.cat([gate, proj], dim=0)
return new_weight
def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
+ """
+ Converts a Mochi transformer checkpoint to the Diffusers format.
+
+ This function handles the conversion of:
+ - Embedding layers
+ - Transformer blocks
+ - Attention mechanisms
+ - Normalization layers
+ - Output projections
+
+ Args:
+ ckpt_path (str): Path to the original Mochi checkpoint file
+
+ Returns:
+ dict: State dictionary in Diffusers format
+ """
original_state_dict = load_file(ckpt_path, device="cpu")
new_state_dict = {}
-
- # Convert patch_embed
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
-
- # Convert time_embed
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
@@ -122,9 +173,24 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
return new_state_dict
-
-
def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_path):
+ """
+ Converts Mochi VAE encoder and decoder checkpoints to the Diffusers format.
+
+ This function handles:
+ - Input/output projections
+ - ResNet blocks
+ - Up/down sampling blocks
+ - Attention layers
+ - Normalization layers
+
+ Args:
+ encoder_ckpt_path (str): Path to the VAE encoder checkpoint
+ decoder_ckpt_path (str): Path to the VAE decoder checkpoint
+
+ Returns:
+ dict: Combined state dictionary for both encoder and decoder in Diffusers format
+ """
encoder_state_dict = load_file(encoder_ckpt_path, device="cpu")
decoder_state_dict = load_file(decoder_ckpt_path, device="cpu")
new_state_dict = {}
@@ -362,6 +428,29 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
return new_state_dict
def main(settings: MochiWeightsSettings = MochiWeightsSettings()):
+ """
+ Main execution function that orchestrates the conversion process.
+
+ This function:
+ 1. Configures the computation dtype (fp16, bf16, or fp32)
+ 2. Converts the transformer model if a checkpoint is provided
+ 3. Converts the VAE if encoder/decoder checkpoints are provided
+ 4. Loads the T5 text encoder and tokenizer
+ 5. Creates and saves a complete Diffusers pipeline
+
+ Args:
+ settings (MochiWeightsSettings): Configuration object containing paths and options
+ - dtype: Computation precision ("fp16", "bf16", "fp32", or None)
+ - transformer_checkpoint_path: Path to transformer checkpoint
+ - vae_encoder_checkpoint_path: Path to VAE encoder checkpoint
+ - vae_decoder_checkpoint_path: Path to VAE decoder checkpoint
+ - text_encoder_cache_dir: Cache directory for T5 model
+ - output_path: Where to save the converted pipeline
+ - push_to_hub: Whether to push the converted model to HuggingFace Hub
+
+ Raises:
+ ValueError: If an unsupported dtype is specified
+ """
if settings.dtype is None:
dtype = None
elif settings.dtype == "fp16":
diff --git a/scripts/mochi_diffusers.py b/scripts/mochi_diffusers.py
index d3bc232..a8d406e 100644
--- a/scripts/mochi_diffusers.py
+++ b/scripts/mochi_diffusers.py
@@ -1,6 +1,6 @@
"""
Inference class for Mochi video generation model.
-Tested on A6000
+Tested on A6000 / A40 / A100
Generates Video for 20 inference steps at 480 resolution in 10 minutes
"""
@@ -8,10 +8,11 @@
from typing import Optional, Union, List, Tuple
import torch
from loguru import logger
-from diffusers import MochiPipeline, MochiTransformer3DModel
+from diffusers import MochiPipeline, MochiTransformer3DModel, AutoencoderKLMochi
from diffusers.utils import export_to_video
-
+from diffusers.video_processor import VideoProcessor
from configs.mochi_settings import MochiSettings
+import gc
class MochiInference:
"""
@@ -112,39 +113,85 @@ def generate(
Raises:
RuntimeError: If video generation fails.
"""
+ logger.info("Starting video generation for prompt: {}", prompt)
try:
# Set random seed if provided
if seed is not None:
logger.info("Setting random seed to {}", seed)
torch.manual_seed(seed)
- # Build generation parameters
- params = {
- "prompt": prompt,
- "negative_prompt": negative_prompt,
- "num_inference_steps": num_inference_steps or self.settings.num_inference_steps,
- "guidance_scale": guidance_scale or self.settings.guidance_scale,
- "height": height or self.settings.height,
- "width": width or self.settings.width,
- "num_frames": num_frames or self.settings.num_frames,
- }
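+            # Memory-saving strategy: run the text encoder, transformer, and
+            # VAE one at a time, freeing each module once its output has been
+            # captured, so peak VRAM stays close to the largest single component.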
+ # Encode prompt first and free text encoder memory
+ with torch.no_grad():
+ logger.debug("Encoding prompt")
+ prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
+                    self.pipe.encode_prompt(prompt=prompt, negative_prompt=negative_prompt)
+ )
- logger.info("Generating video with prompt: {}", prompt)
- logger.debug("Generation parameters: {}", params)
+ logger.debug("Freeing text encoder memory")
+ del self.pipe.text_encoder
+ gc.collect()
- frames = self.pipe(**params).frames[0]
+ # Generate latents
+ logger.debug("Generating latents")
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ frames = self.pipe(
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_embeds=negative_prompt_embeds,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ guidance_scale=guidance_scale or self.settings.guidance_scale,
+ num_inference_steps=num_inference_steps or self.settings.num_inference_steps,
+ height=height or self.settings.height,
+ width=width or self.settings.width,
+ num_frames=num_frames or self.settings.num_frames,
+ output_type="latent",
+ return_dict=False,
+ )[0]
+
+ logger.debug("Freeing transformer memory")
+ del self.pipe.transformer
+ gc.collect()
+
+ # Setup VAE and process latents
+ logger.debug("Loading VAE model")
+ vae = AutoencoderKLMochi.from_pretrained(
+ self.settings.pipeline_path,
+ subfolder="vae"
+ ).to(self.settings.device)
+ vae._enable_framewise_decoding()
+            # De-normalize latents before decoding: undo the per-channel
+            # mean/std normalization (when the VAE config provides it) and the
+            # scaling factor applied by the pipeline.
+ logger.debug("Scaling latents")
+ if hasattr(vae.config, "latents_mean") and hasattr(vae.config, "latents_std"):
+ latents_mean = torch.tensor(vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
+ latents_std = torch.tensor(vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
+ frames = frames * latents_std / vae.config.scaling_factor + latents_mean
+ else:
+ frames = frames / vae.config.scaling_factor
+
+ # Decode frames
+ logger.debug("Decoding frames")
+ with torch.no_grad():
+ video = vae.decode(frames.to(vae.dtype), return_dict=False)[0]
+
+ logger.debug("Post-processing video")
+            video_processor = VideoProcessor(vae_scale_factor=8)  # Mochi VAE downsamples 8x spatially
+ video = video_processor.postprocess_video(video)[0]
+
+ # Save or return video
if output_path:
+ logger.info("Saving video to {}", output_path)
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
fps = fps or self.settings.fps
- export_to_video(frames, str(output_path), fps=fps)
+ export_to_video(video, str(output_path), fps=fps)
logger.success("Video saved to: {}", output_path)
return str(output_path)
- return export_to_video(frames[0])
-
+ logger.success("Video generation completed successfully")
+ return video
+
except Exception as e:
logger.exception("Video generation failed")
raise RuntimeError(f"Video generation failed: {str(e)}") from e
@@ -200,10 +247,6 @@ def clear_memory(self) -> None:
print(f"Video saved to: {video_path}")
except RuntimeError as e:
print(f"Failed to generate video: {e}")
-
- # Display GPU memory usage for debugging
allocated, max_allocated = mochi_inference.get_memory_usage()
print(f"Memory usage: {allocated:.2f}GB (peak: {max_allocated:.2f}GB)")
-
- # Clear memory cache after inference
mochi_inference.clear_memory()
diff --git a/src/scripts/upscale-video.py b/src/scripts/upscale-video.py
deleted file mode 100644
index 7f9120e..0000000
--- a/src/scripts/upscale-video.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from aura_sr import AuraSR
-aura_sr = AuraSR.from_pretrained('fal-ai/AuraSR')
-