My VASA hack project has running training code for stage 1 (MegaPortraits), with hot fixes.
Implementation of MegaPortraits using Claude Opus
All models / code is in
memory debug
mprof run
or just
- Save / restore checkpoint: specify the checkpoint to restore in the config ./configs/training/stage1-base.yaml
- auto crop video frames to sweet spot
- tensorboard losses
- additional imagepyramide from the one shot view code for the loss (this broke things)
warp / crop / spline / remove background / transforms
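A minimal sketch of what the save / restore checkpoint step can look like in PyTorch. The function names and the dictionary keys (`epoch`, `model`, `optimizer`) are hypothetical and not taken from this repo; the actual training code may store more state.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, model, optimizer, epoch):
    """Write model and optimizer state plus the finished epoch to one file."""
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore state in place and return the epoch to resume from.

    Returns 0 when no checkpoint file exists, so training starts fresh.
    """
    if not os.path.isfile(path):
        return 0
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1
```

Saving the optimizer state alongside the weights keeps the AdamW moment estimates, so a resumed run does not restart with cold statistics.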
- Total Videos: 35,000 facial videos
- Total Size: 40GB
For now, to simplify the problem, use the 4 videos in the junk folder. Once the models are validated, video_dir can point to the full dataset from the torrent.
# video_dir: '/Downloads/CelebV-HQ/celebvhq/35666'
video_dir: './junk'
Preprocessing takes 1-2 minutes per video, so I added saving to npz format for faster reloading.
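The npz caching idea can be sketched as follows. This is an illustration only: `load_frames_cached`, `cache_dir` and `decode_fn` are hypothetical names, and `decode_fn` stands in for whatever slow preprocessing the dataset class actually runs.

```python
import os
import tempfile

import numpy as np


def load_frames_cached(video_path, cache_dir, decode_fn):
    """Return preprocessed frames, reusing an .npz cache when one exists."""
    os.makedirs(cache_dir, exist_ok=True)
    name = os.path.splitext(os.path.basename(video_path))[0]
    cache_path = os.path.join(cache_dir, name + ".npz")

    # Fast path: the video was already preprocessed on an earlier run.
    if os.path.isfile(cache_path):
        with np.load(cache_path) as data:
            return data["frames"]

    # Slow path: run the expensive preprocessing once, then store the result.
    frames = np.asarray(decode_fn(video_path))
    np.savez_compressed(cache_path, frames=frames)
    return frames
```

The first call for a video pays the full preprocessing cost; later calls only read the compressed array from disk.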
You can download the dataset via the provided magnet link or by visiting Academic Torrents.
- Description: Responsible for creating the foundational neural head avatar at a medium resolution of (512 x 512). Uses volumetric features to encode appearance and latent descriptors to encode motion.
- Components:
- Appearance Encoder: Encodes the appearance of the source frame into volumetric features and a global descriptor.
class Eapp(nn.Module): # Architecture details omitted for brevity
- Motion Encoder: Encodes the motion from both source and driving images into head rotations, translations, and latent expression descriptors.
class Emtn(nn.Module): # Architecture details omitted for brevity
- Warping Generators: Removes motion from the source and imposes driver motion onto canonical features.
class WarpGenerator(nn.Module): # Architecture details omitted for brevity
- 3D Convolutional Network: Processes canonical volumetric features.
class G3D(nn.Module): # Architecture details omitted for brevity
- 2D Convolutional Network: Projects 3D features into 2D and generates the output image.
class G2D(nn.Module): # Architecture details omitted for brevity
- Description: Enhances the resolution of the base model output from (512 x 512) to (1024 x 1024) using a high-resolution dataset of photographs.
- Components:
- Encoder: Takes the base model output and produces a 3D feature tensor.
class EncoderHR(nn.Module): # Architecture details omitted for brevity
- Decoder: Converts the 3D feature tensor to a high-resolution image.
class DecoderHR(nn.Module): # Architecture details omitted for brevity
- Description: A distilled version of the high-resolution model for real-time applications. Trained to mimic the full model’s predictions but runs faster and is limited to a predefined number of avatars.
- Components:
- ResNet18 Encoder: Encodes the input image.
class ResNet18(nn.Module): # Architecture details omitted for brevity
- Generator with SPADE Normalization Layers: Generates the final output image. Each SPADE block uses tensors specific to an avatar.
class SPADEGenerator(nn.Module): # Architecture details omitted for brevity
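To illustrate what "tensors specific to an avatar" can mean for a SPADE block, here is a NumPy sketch. It assumes the usual SPADE form (parameter-free normalization followed by a spatial scale and shift) with the scale and shift looked up per avatar; `make_avatar_params` and `spade_modulate` are hypothetical names and the real SPADEGenerator may differ.

```python
import numpy as np


def make_avatar_params(num_avatars, channels, height, width):
    """One spatial scale (gamma) and shift (beta) tensor per avatar.

    Assumption: these would be learned parameters in the student model.
    """
    shape = (num_avatars, channels, height, width)
    return {"gamma": np.zeros(shape), "beta": np.zeros(shape)}


def spade_modulate(features, params, avatar_id, eps=1e-5):
    """Normalize a (C, H, W) feature map, then modulate it per avatar."""
    # Parameter-free normalization over the spatial dimensions of each channel.
    mean = features.mean(axis=(1, 2), keepdims=True)
    std = features.std(axis=(1, 2), keepdims=True)
    normalized = (features - mean) / (std + eps)

    # Avatar-specific spatial scale and shift.
    gamma = params["gamma"][avatar_id]
    beta = params["beta"][avatar_id]
    return normalized * (1.0 + gamma) + beta
```

Because the modulation tensors are indexed by avatar, the student can only render the avatars it has tensors for, which matches the "predefined number of avatars" limitation described above.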
- Description: Computes the gaze and blink loss using a pretrained face mesh from MediaPipe and a custom network. The gaze loss uses MAE and MSE, while the blink loss uses binary cross-entropy.
- Components:
- Backbone (VGG16): Extracts features from the eye images.
class VGG16Backbone(nn.Module): # Architecture details omitted for brevity
- Keypoint Network: Processes 2D keypoints.
class KeypointNet(nn.Module): # Architecture details omitted for brevity
- Gaze Head: Predicts gaze direction.
class GazeHead(nn.Module): # Architecture details omitted for brevity
- Blink Head: Predicts blink probability.
class BlinkHead(nn.Module): # Architecture details omitted for brevity
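The loss terms in the description can be written out as below. This is a sketch under assumptions: the gaze loss is taken as an unweighted sum of MAE and MSE, and `gaze_loss` / `blink_loss` are hypothetical names; the actual weighting in the training code may differ.

```python
import numpy as np


def gaze_loss(pred, target):
    """MAE plus MSE between predicted and reference gaze vectors."""
    diff = np.asarray(pred, dtype=float) - np.asarray(target, dtype=float)
    mae = np.abs(diff).mean()
    mse = (diff ** 2).mean()
    return mae + mse  # assumption: equal weights for both terms


def blink_loss(prob, label, eps=1e-7):
    """Binary cross-entropy between blink probabilities and 0/1 labels."""
    prob = np.clip(np.asarray(prob, dtype=float), eps, 1.0 - eps)
    label = np.asarray(label, dtype=float)
    return -(label * np.log(prob) + (1.0 - label) * np.log(1.0 - prob)).mean()
```

Clipping the probabilities keeps the logarithm finite when the blink head saturates at exactly 0 or 1.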
- train_base(cfg, Gbase, Dbase, dataloader): Trains the base model using perceptual, adversarial, and cycle consistency losses.
def train_base(cfg, Gbase, Dbase, dataloader): # Training code omitted for brevity
- train_hr(cfg, GHR, Dhr, dataloader): Trains the high-resolution model using super-resolution objectives and adversarial losses.
def train_hr(cfg, GHR, Dhr, dataloader): # Training code omitted for brevity
- train_student(cfg, Student, GHR, dataloader): Distills the high-resolution model into a student model for faster inference.
def train_student(cfg, Student, GHR, dataloader): # Training code omitted for brevity
- Data Augmentation: Applies random horizontal flips, color jitter, and other augmentations to the input images.
- Optimizers: Uses AdamW optimizer with cosine learning rate scheduling for both base and high-resolution models.
- Losses:
- Perceptual Loss: Matches the content and facial appearance between predicted and ground-truth images.
- Adversarial Loss: Ensures the realism of predicted images using a multi-scale patch discriminator.
- Cycle Consistency Loss: Prevents appearance leakage through the motion descriptor.
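For reference, cosine learning rate scheduling follows a simple closed form, sketched here in plain Python. `cosine_lr` and its parameters are illustrative names; in PyTorch the same curve is available as `torch.optim.lr_scheduler.CosineAnnealingLR`.

```python
import math


def cosine_lr(step, total_steps, base_lr, min_lr=0.0):
    """Learning rate at `step`, decaying from base_lr to min_lr on a half cosine."""
    # Clamp so that steps past the schedule end stay at min_lr.
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

The rate starts at `base_lr`, reaches the midpoint between `base_lr` and `min_lr` halfway through, and ends at `min_lr`.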
- Description: Sets up the dataset and data loaders, initializes the models, and calls the training functions for base, high-resolution, and student models.
- Implementation:
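      import torch
      from omegaconf import OmegaConf
      from torch.utils.data import DataLoader
      from torchvision import transforms

      # model, EMODataset, train_base, train_hr and train_student come from this repo.

      def main(cfg: OmegaConf) -> None:
          use_cuda = torch.cuda.is_available()
          device = torch.device("cuda" if use_cuda else "cpu")

          transform = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize([0.5], [0.5]),
              transforms.RandomHorizontalFlip(),
              transforms.ColorJitter(),
          ])

          dataset = EMODataset(
              use_gpu=use_cuda,
              # ... remaining arguments are elided in the source ...
              img_scale=(1.0, 1.0),
              transform=transform,
          )
          dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

          # Stage 1: base model.
          Gbase = model.Gbase()
          Dbase = model.Discriminator()
          train_base(cfg, Gbase, Dbase, dataloader)

          # Stage 2: high-resolution model, initialized from the trained base model.
          GHR = model.GHR()
          GHR.Gbase.load_state_dict(Gbase.state_dict())
          Dhr = model.Discriminator()
          train_hr(cfg, GHR, Dhr, dataloader)

          # Stage 3: student model distilled from the high-resolution model.
          Student = model.Student(num_avatars=100)
          train_student(cfg, Student, GHR, dataloader)

          # Assumed: save each model's weights (the save calls are truncated in the source).
          torch.save(Gbase.state_dict(), 'Gbase.pth')
          torch.save(GHR.state_dict(), 'GHR.pth')
          torch.save(Student.state_dict(), 'Student.pth')

      if __name__ == "__main__":
          config = OmegaConf.load("./configs/training/stage1-base.yaml")
          main(config)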
rome/losses - cherry picked from
wget '' extract to state_dicts
git clone
cd rt_gene/rt_gene
pip install .