-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Allegro inference configuration and video generation scripts
- Introduced `allegro_settings.py` for managing Allegro model settings with Pydantic. - Created `allegro_diffusers.py` to handle video generation using the Allegro inference pipeline. - Implemented logging for better traceability during model initialization and video generation. - Added error handling for robust pipeline setup and video export processes.
- Loading branch information
Showing
3 changed files
with
151 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from pydantic_settings import BaseSettings | ||
|
||
class AllegroSettings(BaseSettings): | ||
""" | ||
A Pydantic settings class for Allegro inference configuration. | ||
This class uses Pydantic to provide validation and easy environment-based configuration | ||
for Allegro inference pipeline settings. | ||
""" | ||
|
||
model_name:str = "rhymes-ai/Allegro" | ||
device: str = "cuda" | ||
seed: int = 42 | ||
guidance_scale: float = 7.5 | ||
max_sequence_length: int = 512 | ||
num_inference_steps: int = 100 | ||
fps: int = 15 | ||
|
||
class Config: | ||
""" | ||
Pydantic configuration class for environment variable support. | ||
""" | ||
env_prefix = "ALLEGRO_" # Prefix for environment variables | ||
validate_assignment = True | ||
|
||
def __repr__(self): | ||
""" | ||
Return a string representation of the settings for debugging purposes. | ||
:return: A string summarizing the settings. | ||
""" | ||
return (f"AllegroSettings(model_name={self.model_name}, device={self.device}, seed={self.seed}, " | ||
f"guidance_scale={self.guidance_scale}, max_sequence_length={self.max_sequence_length}, " | ||
f"num_inference_steps={self.num_inference_steps}, fps={self.fps})") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import torch | ||
from diffusers import AutoencoderKLAllegro, AllegroPipeline | ||
from diffusers.utils import export_to_video | ||
from loguru import logger | ||
from configs.allegro_settings import AllegroSettings | ||
|
||
class AllegroInference: | ||
""" | ||
A class for managing the Allegro inference pipeline for generating videos based on textual prompts. | ||
This class encapsulates the initialization, configuration, and video generation processes | ||
for the Allegro model pipeline. It provides a streamlined way to handle prompts, model setup, | ||
and output file management in a production-grade environment. | ||
""" | ||
|
||
def __init__(self, settings: AllegroSettings): | ||
""" | ||
Initialize the AllegroInference class with the given settings. | ||
:param settings: An instance of AllegroSettings containing model, device, and generation parameters. | ||
""" | ||
self.settings = settings | ||
self.pipe = None | ||
|
||
logger.info(f"Initializing {self.settings.model_name} inference pipeline") | ||
self._setup_pipeline() | ||
|
||
def _setup_pipeline(self): | ||
""" | ||
Set up the Allegro model pipeline by loading the VAE and the pipeline with specified configurations. | ||
This method loads the models, moves them to the specified device, and enables tiling for | ||
efficient memory usage during inference. | ||
:raises Exception: If there is an error during the model loading or configuration process. | ||
""" | ||
try: | ||
# Load VAE | ||
logger.info("Loading VAE model...") | ||
vae = AutoencoderKLAllegro.from_pretrained( | ||
self.settings.model_name, | ||
subfolder="vae", | ||
torch_dtype=torch.float32 | ||
) | ||
|
||
# Load Allegro pipeline | ||
logger.info("Loading Allegro pipeline...") | ||
self.pipe = AllegroPipeline.from_pretrained( | ||
self.settings.model_name, | ||
vae=vae, | ||
torch_dtype=torch.bfloat16 | ||
) | ||
|
||
# Move pipeline to the specified device | ||
self.pipe.to(self.settings.device) | ||
|
||
# Enable tiling for efficient memory usage | ||
self.pipe.vae.enable_tiling() | ||
|
||
logger.info("Pipeline successfully initialized") | ||
except Exception as e: | ||
logger.error(f"Error initializing pipeline: {e}") | ||
raise | ||
|
||
def generate_video(self, prompt: str, positive_prompt: str, negative_prompt: str, output_path: str): | ||
""" | ||
Generate a video based on the provided prompts and save it to the specified path. | ||
:param prompt: The main textual description of the video scene. | ||
:param positive_prompt: Additional positive prompts to enhance quality and style. | ||
:param negative_prompt: Prompts to avoid undesirable features in the generated video. | ||
:param output_path: File path to save the generated video. | ||
:raises Exception: If there is an error during video generation or export. | ||
""" | ||
try: | ||
logger.info("Preparing prompts...") | ||
prompt = positive_prompt.format(prompt.lower().strip()) | ||
|
||
logger.info("Generating video...") | ||
generator = torch.Generator(device=self.settings.device).manual_seed(self.settings.seed) | ||
video_frames = self.pipe( | ||
prompt, | ||
negative_prompt=negative_prompt, | ||
guidance_scale=self.settings.guidance_scale, | ||
max_sequence_length=self.settings.max_sequence_length, | ||
num_inference_steps=self.settings.num_inference_steps, | ||
generator=generator | ||
).frames[0] | ||
|
||
logger.info(f"Exporting video to {output_path}...") | ||
export_to_video(video_frames, output_path, fps=self.settings.fps) | ||
|
||
logger.info("Video generation completed successfully") | ||
except Exception as e: | ||
logger.error(f"Error during video generation: {e}") | ||
raise | ||
|
||
# Example usage (to be executed in a main script or testing environment) | ||
if __name__ == "__main__": | ||
settings = AllegroSettings() | ||
inference = AllegroInference(settings) | ||
|
||
prompt = "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats." | ||
positive_prompt = """ | ||
(masterpiece), (best quality), (ultra-detailed), (unwatermarked), | ||
{} | ||
emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, | ||
sharp focus, high budget, cinemascope, moody, epic, gorgeous | ||
""" | ||
|
||
negative_prompt = """ | ||
nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, | ||
low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. | ||
""" | ||
|
||
output_path = "output.mp4" | ||
inference.generate_video(prompt, positive_prompt, negative_prompt, output_path) |