This repository contains the official code for PluGeN4Faces, a method for explicitly disentangling face attributes from a person's identity. This work was presented at WACV 2024.
Conditional GANs are frequently used for manipulating the attributes of face images, such as expression, hairstyle, pose, or age. Even though the state-of-the-art models successfully modify the requested attributes, they simultaneously modify other important characteristics of the image, such as a person’s identity. In this paper, we focus on solving this problem by introducing PluGeN4Faces, a plugin to StyleGAN, which explicitly disentangles face attributes from a person’s identity. Our key idea is to perform training on images retrieved from movie frames, where a given person appears in various poses and with different attributes. By applying a type of contrastive loss, we encourage the model to group images of the same person in similar regions of latent space. Our experiments demonstrate that the modifications of face attributes performed by PluGeN4Faces are significantly less invasive on the remaining characteristics of the image than in the existing state-of-the-art models.
If you found this code useful, please cite:
@InProceedings{Suwala_2024_WACV,
author = {Suwa{\l}a, Adrian and W\'ojcik, Bartosz and Proszewska, Magdalena and Tabor, Jacek and Spurek, Przemys{\l}aw and \'Smieja, Marek},
title = {Face Identity-Aware Disentanglement in StyleGAN},
booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
month = {January},
year = {2024},
pages = {5222-5231}
}
chmod +x scripts/setup.sh
./scripts/setup.sh
All needed models are expected to be in pretrained_models.
To train the model, you first need to encode the training data into latents using e4e. Once the e4e checkpoint is saved, execute:
python scripts/encode.py --src DIR_WITH_IMAGES --target NEW_DIR_WITH_LATENTS
Training the model also requires a StyleGAN checkpoint compatible with rosinality/stylegan2-pytorch.
chmod +x scripts/train_flow.sbatch
./scripts/train_flow.sbatch
The options are documented in scripts/train.py, or run --help on the called Python script.
For evaluation, you'll also need an ArcFace checkpoint and an attribute classifier checkpoint.
Before you run the evaluation itself, you have to cache the face_recognition, ArcFace, and StyleGAN latent space representations of your evaluation data.
./scripts/cache_embs.sbatch
The main evaluation script searches for flow values to cross the given thresholds for edits in the given range of values. It can use either binary search based on the outputs from the attribute classifier, or a linear search, which uses equally spaced points across the given range. Linear search can be faster since it knows all the points upfront, and uses vectorization to process several at the same time.
chmod +x scripts/detector_threshold.sbatch
./scripts/detector_threshold.sbatch
The meaning of the options is documented in the CLI code in scripts/, or just run --help on the underlying Python script (not the bash script).
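The two search strategies described above can be illustrated with a small, self-contained sketch. It is not the repository's evaluation code: the function names are hypothetical, a toy sigmoid stands in for the attribute classifier, and the binary search assumes the classifier score does not decrease as the edit value grows.

```python
import math
from typing import Callable, List, Optional, Sequence


def binary_search_crossing(score: Callable[[float], float], lo: float, hi: float,
                           threshold: float, tol: float = 1e-3,
                           max_steps: int = 50) -> Optional[float]:
    """Find a value in [lo, hi] where `score` first reaches `threshold`.

    Assumes `score` is non-decreasing on [lo, hi]. Each step needs the
    previous result, so the classifier is queried one point at a time.
    """
    if score(hi) < threshold:
        return None  # the threshold is never crossed in this range
    if score(lo) >= threshold:
        return lo
    for _ in range(max_steps):
        if hi - lo <= tol:
            break
        mid = (lo + hi) / 2
        if score(mid) >= threshold:
            hi = mid  # crossing is at or below mid
        else:
            lo = mid  # crossing is above mid
    return hi


def linear_search_crossing(score_batch: Callable[[Sequence[float]], List[float]],
                           lo: float, hi: float, threshold: float,
                           num_points: int = 13) -> Optional[float]:
    """Score equally spaced points in one batched call and return the first hit.

    All points are known upfront, so a real scorer could process them
    together (for example as one batch on the GPU).
    """
    step = (hi - lo) / (num_points - 1)
    points = [lo + i * step for i in range(num_points)]
    scores = score_batch(points)
    for point, value in zip(points, scores):
        if value >= threshold:
            return point
    return None


def toy_score(v: float) -> float:
    """Stand-in for an attribute classifier: a sigmoid of the edit value."""
    return 1.0 / (1.0 + math.exp(-v))


def toy_score_batch(values: Sequence[float]) -> List[float]:
    return [toy_score(v) for v in values]


binary_hit = binary_search_crossing(toy_score, -3.0, 3.0, threshold=0.9)
linear_hit = linear_search_crossing(toy_score_batch, -3.0, 3.0, threshold=0.9)
print(binary_hit, linear_hit)
```

The binary search localizes the crossing to within `tol` but issues its queries sequentially, while the linear search is limited to the grid resolution but needs only one batched call.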
from plugen4faces.flow import load_plugen4faces
model = load_plugen4faces("model.pt")
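import torch
import torch.nn.functional as F
from plugen4faces.const import NUM_WS
# One one-hot index vector for each of the NUM_WS latent vectors
onehots = (
    F.one_hot(torch.arange(NUM_WS), num_classes=NUM_WS)
    .float()
    #.repeat(B, 1, 1)  # uncomment to tile to batch size B: [B, NUM_WS, NUM_WS]
    .to(model.device)
)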
# ws should have shape [B, NUM_WS, 512], and onehots [B, NUM_WS, NUM_WS]
z = model.transform_to_noise(ws, onehots)
# z should have shape [B, NUM_WS, 512], and onehots [B, NUM_WS, NUM_WS]
ws = model.inverse(z, onehots)
from plugen4faces.const import attr2idx
z[..., attr2idx["beard"]] = 2
z[..., attr2idx["gender"]] = -2
There is also attr2detidx, which converts to indices of the attribute classifier.
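The steps above (map latents to the flow space, overwrite attribute coordinates, map back) can be wrapped in one small helper. This is only a sketch: `edit_attributes` is a hypothetical name that is not part of the package, and it relies solely on the `transform_to_noise` and `inverse` calls and the name-to-index mapping shown above.

```python
from typing import Any, Mapping


def edit_attributes(model: Any, ws: Any, onehots: Any,
                    attr2idx: Mapping[str, int],
                    edits: Mapping[str, float]) -> Any:
    """Return edited latents after setting the requested attribute values.

    Hypothetical helper: `model` only needs the `transform_to_noise` and
    `inverse` methods, and `attr2idx` maps attribute names to positions
    in the last dimension of the flow-space tensor.
    """
    # Fail early on attribute names that the mapping does not know
    unknown = [name for name in edits if name not in attr2idx]
    if unknown:
        raise KeyError(f"unknown attributes: {unknown}")

    # StyleGAN latents -> flow space
    z = model.transform_to_noise(ws, onehots)

    # Overwrite the coordinate of each requested attribute
    for name, value in edits.items():
        z[..., attr2idx[name]] = value

    # Flow space -> edited StyleGAN latents
    return model.inverse(z, onehots)


# Usage with the objects from the snippets above (not executed here):
# edited_ws = edit_attributes(model, ws, onehots, attr2idx,
#                             {"beard": 2, "gender": -2})
```

Passing the mapping as an argument keeps the helper independent of the package import, so it can be exercised with a stub model before a trained checkpoint is available.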
from plugen4faces.utils import load_detector, load_stylegan, load_plugen, load_styleflow
All of them take a path, which defaults to the value from global_config defined in plugen4faces/const.py.
plugen4faces/
├── environment.yml
├── plugen4faces
│ ├── const.py # globally used defaults, paths, values
│ ├── encoder4editing # Image to StyleGAN latent inverter
│ ├── evaluation
│ │ ├── configs.py # wrappers around evaluated models
│ │ ├── evaluate.py # main evaluation code
│ │ ├── helpers.py
│ │ ├── __init__.py
│ │ └── search.py
│ ├── flow.py # implements the method
│ ├── __init__.py
│ ├── resnets.py # for attribute classifier
│ ├── StyleFlow # implementation of PluGeN and StyleFlow
│ ├── stylegan2
│ └── utils.py
├── README.md
└── scripts
  ├── cache_embs.sbatch # for caching original images, required to run evaluation
  ├── detector_threshold.sbatch # main evaluation script, detector is the attribute classifier
  ├── encode.py # for encoding directories of images with e4e
  ├── evaluate.py # CLI interface for evaluation
  ├── postprocess.py # for merging runs and calculating metrics from detector_threshold results
  ├── train_flow.sbatch # training script
  └── train.py # CLI interface for training