cross normalization #45

Open
nzhang258 opened this issue Sep 2, 2024 · 4 comments

Comments

@nzhang258 commented Sep 2, 2024

Thanks a lot for such amazing work! Here are some questions about the training code.

  1. Could you please tell me which part of the training code in SVD-v2 implements cross normalization?
  2. I have some extra attention parameters, similar to IP-Adapter, and training does not seem to work well if I only update the "to_out" weights in the UNet. Did you run into this problem?

Looking forward to your reply~

@AlonzoLeeeooo commented Sep 5, 2024

Hi @nzhang258,

Maybe I can help you with this. In ControlNeXt-SVD-v2-Training/models/unet_spatio_temporal_condition_controlnext.py, lines 456 to 464, you can see:

if idx == 0 and conditional_controls is not None:
    scale = conditional_controls['scale']
    conditional_controls = conditional_controls['output']
    mean_latents = torch.mean(sample, dim=(1, 2, 3), keepdim=True)
    std_latents = torch.std(sample, dim=(1, 2, 3), keepdim=True)
    mean_control = torch.mean(conditional_controls, dim=(1, 2, 3), keepdim=True)
    std_control = torch.std(conditional_controls, dim=(1, 2, 3), keepdim=True)
    conditional_controls = (conditional_controls - mean_control) * (std_latents / (std_control + 1e-5)) + mean_latents
    conditional_controls = F.adaptive_avg_pool2d(conditional_controls, sample.shape[-2:])
    # 0.2: This hyperparameter adjusts the control level; increasing it strengthens the control.
    sample = sample + conditional_controls * scale * 0.2

This is how cross normalization is computed. Hope this helps.
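
For anyone who wants to poke at this step in isolation, here is a minimal, self-contained sketch of the same logic (the cross_normalize function name and the dummy tensor shapes are mine for illustration; the repository applies this inline in the UNet forward pass):

import torch
import torch.nn.functional as F

def cross_normalize(sample, control, scale, eps=1e-5):
    # Re-normalize the control features so they share the mean and std
    # of the main-branch latents, then add them to the main branch.
    mean_latents = torch.mean(sample, dim=(1, 2, 3), keepdim=True)
    std_latents = torch.std(sample, dim=(1, 2, 3), keepdim=True)
    mean_control = torch.mean(control, dim=(1, 2, 3), keepdim=True)
    std_control = torch.std(control, dim=(1, 2, 3), keepdim=True)
    control = (control - mean_control) * (std_latents / (std_control + eps)) + mean_latents
    # Match the spatial resolution of the main branch before adding.
    control = F.adaptive_avg_pool2d(control, sample.shape[-2:])
    # 0.2 is the fixed hyperparameter from the repo that tunes control strength.
    return sample + control * scale * 0.2

# Example with dummy tensors that have very different statistics:
sample = torch.randn(2, 320, 64, 64)
control = torch.randn(2, 320, 64, 64) * 3.0 + 1.0
out = cross_normalize(sample, control, scale=1.0)
print(out.shape)  # torch.Size([2, 320, 64, 64])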

@jackyyang9

This doesn't seem to match the formulation presented in Eq. (10) of the paper?

@nighting0le01

@nzhang258 Did you test it out with IP-Adapter? Did it not work well with pre-trained IP-Adapters?

@AlonzoLeeeooo commented Sep 23, 2024


@nzhang258 @jackyyang9 Here is my assumption:
They want to maintain $\frac{x_c - \mu_c}{\sqrt{\sigma_c^2}} = \frac{x_m - \mu_m}{\sqrt{\sigma_m^2}}$.
If we follow this equation, we can see that the authors expect the lightweight module to output a feature $\widehat{x_m}$ that is as similar to $x_m$ as possible. Solving for that feature, the equation becomes $\widehat{x_m} = \frac{(x_c - \mu_c) \cdot \sqrt{\sigma_m^2}}{\sqrt{\sigma_c^2}} + \mu_m$, which exactly matches the released code.
But honestly, I am not sure whether it is appropriate to use Eq. (10) to denote this process.
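
If it helps, the rearrangement is easy to check numerically: after the transform, the normalized control features share the mean and standard deviation of the main-branch features, so the two standardized quantities coincide. A quick sanity check (variable names are mine):

import torch

x_m = torch.randn(1000)              # main-branch features
x_c = torch.randn(1000) * 4.0 + 2.0  # control features with different stats

# Rearranged form from above: x_hat = (x_c - mu_c) * sigma_m / sigma_c + mu_m
x_hat = (x_c - x_c.mean()) * (x_m.std() / x_c.std()) + x_m.mean()

# x_hat now has the same first and second moments as x_m,
# so standardizing both sides recovers the original equality.
print(torch.allclose(x_hat.mean(), x_m.mean(), atol=1e-4))  # True
print(torch.allclose(x_hat.std(), x_m.std(), atol=1e-4))    # True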
