NetworkGlow has no field logdet #80

Closed

flo-he opened this issue Apr 21, 2023 · 4 comments

flo-he commented Apr 21, 2023

Hi, I stumbled across this error message (see title) when trying to train a Glow network (the same error also occurs with the HINT network).

MWE:

using InvertibleNetworks, Flux

# Glow Network
model = NetworkGlow(2, 32, 2, 5)

# dummy input & target
X = randn(Float32, 1, 1, 2, 1)
Y = 2 .* X .+ 1

# loss fn (the forward pass returns (output, logdet); take the output)
loss(model, X, Y) = Flux.mse(Y, model(X)[1])

θ = Flux.params(model)
opt = ADAM(0.001f0)

for i = 1:5
    l, grads = Flux.withgradient(θ) do
        loss(model, X, Y)
    end
    @show l
    Flux.update!(opt, θ, grads)
end
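
Running this loop fails with the error in the title (NetworkGlow has no field logdet).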
flo-he (Author) commented Sep 29, 2023

Hi, is there any news on this? It would be really useful if one could train the INNs as simply as any other Flux model.

rafaelorozco (Collaborator) commented

Hello,

Sorry! I missed this discussion, or probably forgot about it. There is an easy fix: we give NetworkGlow an optional logdet keyword, and with logdet=false you can train it as you describe above. Would that be helpful?

If so, I can make that PR in a couple of hours, no problem.
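
For concreteness, usage would then look something like this (a sketch of the proposed keyword, not merged at the time of this comment; the keyword name follows the suggestion above):

# hypothetical constructor call once the optional keyword exists
model = NetworkGlow(2, 32, 2, 5; logdet=false)

# with logdet tracking disabled, the forward pass would return
# only the network output, with no logdet term to index away
Y_pred = model(X)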

flo-he (Author) commented Sep 29, 2023

Yes, this would be fabulous, thank you!

rafaelorozco (Collaborator) commented

All right, I pushed that quick fix. I want to be clear again that this will only work for logdet=false. Currently, tracking/differentiating the logdet is a bit difficult to do with Julia AD. I think it is possible; it just needs some time, which I will have later.

I added the MWE that you suggested here:
https://github.com/slimgroup/InvertibleNetworks.jl/blob/master/examples/chainrules/train_with_flux.jl

I just had to increase the dimensionality of the input, because the ActNorm layer was blowing up when computing the variance over a single element.
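
For reference, the adapted loop looks roughly like this (a sketch along the lines of the linked example; the exact sizes used in that file may differ):

using InvertibleNetworks, Flux

# Glow network with logdet tracking disabled
model = NetworkGlow(2, 32, 2, 5; logdet=false)

# dummy input & target; multiple spatial elements per channel
# so the ActNorm variance estimate is well defined
X = randn(Float32, 16, 16, 2, 4)
Y = 2 .* X .+ 1

# with logdet=false the forward pass returns only the output
loss(model, X, Y) = Flux.mse(Y, model(X))

θ = Flux.params(model)
opt = ADAM(0.001f0)

for i = 1:5
    l, grads = Flux.withgradient(θ) do
        loss(model, X, Y)
    end
    @show l
    Flux.update!(opt, θ, grads)
end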

I hope this helps. Thank you for the input!
