
Possible one-line solution for Runtime error (variables modified in-place) #148

Open
williantrevizan opened this issue Apr 20, 2021 · 8 comments

Comments

@williantrevizan

Hi, thanks for the repository and this amazing work!

I opened this issue because it might provide a solution for the runtime error reported by cdjameson in another issue, which happens with newer versions of torch ('one of the variables needed for gradient computation has been modified by an inplace operation...'). It seems more straightforward than the solution that Clefspear99 is proposing as a pull request.

The problem happens in the function train_single_scale() in training.py.
This function is basically composed of two sequential loops: one for optimizing the discriminator D, and another for optimizing the generator G. At the end of the first loop, a fake image is generated by the generator. As soon as the second loop starts, this fake image is passed through the discriminator, which generates a patch discrimination map that is then used to calculate the loss errG. The call errG.backward() computes the gradients used to optimize the netG weights via optimizerG.step().

The first time we go through this second loop, everything runs smoothly and the optimizer changes the netG weights in place. However, the second time through the loop, the same fake image is used to calculate the loss (that is, the fake image that was generated with a previous set of netG weights). Therefore, once we call backward(), the computational graph points back to netG weights in their original version, from before the optimization step. Newer versions of torch catch this inconsistency, and that seems to be the reason the error occurs.
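For reference, here is a simplified sketch of the loop structure I am describing. This is only an illustrative outline, not the verbatim contents of train_single_scale(); names like opt.Dsteps and opt.Gsteps follow the repository's conventions.

# Illustrative outline of the two loops in train_single_scale() (not the exact code).
# netD / netG are the discriminator / generator at the current scale;
# noise and prev are built earlier in the function.

for j in range(opt.Dsteps):
    netD.zero_grad()
    # ... discriminator loss on the real image ...
    fake = netG(noise.detach(), prev)      # fake image generated here
    errD_fake = netD(fake.detach()).mean()
    errD_fake.backward()
    optimizerD.step()

for j in range(opt.Gsteps):
    netG.zero_grad()
    output = netD(fake)                    # reuses the fake produced in the D loop
    errG = -output.mean()
    errG.backward(retain_graph=True)       # graph was built with the OLD netG weights
    # ... reconstruction loss and its backward() ...
    optimizerG.step()                      # updates netG weights in place

# On the second pass through the G loop, backward() walks a graph whose saved
# tensors depend on netG weights that optimizerG.step() has already modified
# in place, which newer versions of torch detect and report as the runtime error.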

So, instead of downgrading torch, a simple solution would be to add the line,

fake = netG(noise.detach(), prev.detach())

right at the beginning of the second loop, so that the fake image is always recalculated with the current weights.
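In context, the G loop would then look something like this (again just a sketch of the idea, not the exact file):

for j in range(opt.Gsteps):
    netG.zero_grad()
    # regenerate the fake with the CURRENT netG weights, so backward() never sees
    # a graph built from weights that were later modified in place
    fake = netG(noise.detach(), prev.detach())
    output = netD(fake)
    errG = -output.mean()
    errG.backward(retain_graph=True)
    # ... reconstruction loss and optimizerG.step() as before ...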

@tamarott, I think this might solve the problem. If you agree, I will submit a pull request with this modification.

@tamarott
Owner

This is a possible solution, but pay attention that it changes the optimization process and therefore might affect performance.
So the results won't necessarily be identical to the original version.

@williantrevizan
Author

You are right, I'll pay attention to that! I ran a few tests with the application I'm working on, and it seems to be doing fine with this modification, but I didn't stress these tests too much.

About the optimization process: when I first thought about your paper and code, it made sense to me that conceptually the fake image should be recalculated at every step of that loop (for optimizing G). However, what actually seems to happen is that the adversarial loss is kept fixed (because you use the same fake image three times) and only the reconstruction loss is updated inside the loop. Is there a reason why that should work better?

@tamarott
Owner

We found it to work better empirically, but other solutions might also work.
Just be careful and make sure the performance stays the same.

@williantrevizan
Author

Nice, thanks a lot!!

@ariel415el

Thanks @williantrevizan, your fix worked for me.

@JasonBournePark

It works well for me too.
You saved me time!!
Thanks a lot!

@WZLHQ

WZLHQ commented Aug 25, 2022

Thanks, you really saved me time!

@jethrolam

Thank you @williantrevizan! Confirmed that this solution works on torch==1.12.0.
