No grad for regularization loss #101

Open · matsen opened this issue Feb 20, 2021 · 6 comments · May be fixed by #104

Comments

matsen (Contributor) commented Feb 20, 2021

I think that our implementation of regularization loss is broken!

Here's how it looks now:

    def regularization_loss(self):
        """L1 penalize single mutant effects, and pre-latent interaction
        weights."""
        penalty = 0.0
        if self.beta_l1_coefficient > 0.0:
            penalty += self.beta_l1_coefficient * self.latent_layer.weight[
                :, : self.input_size
            ].norm(1)
        if self.interaction_l1_coefficient > 0.0:
            for interaction_layer in self.layers[: self.latent_idx]:
                penalty += self.interaction_l1_coefficient * getattr(
                    self, interaction_layer
                ).weight.norm(1)
        return penalty

The thing is, penalty starts out as a plain Python float here, and when neither coefficient is positive it stays one, so there's nothing to call backward() on!
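To illustrate the failure mode outside of torchdms, here's a minimal standalone sketch (not project code; the tensor shape and variable names are arbitrary) of how a float accumulator never enters the autograd graph when both branches are skipped:

    import torch

    weight = torch.randn(3, 3, requires_grad=True)

    penalty = 0.0            # plain Python float accumulator, as in regularization_loss
    coefficient = 0.0        # stands in for beta_l1_coefficient / interaction_l1_coefficient
    if coefficient > 0.0:    # skipped, so penalty never becomes a tensor
        penalty += coefficient * weight.norm(1)

    print(type(penalty))     # <class 'float'>
    # penalty.backward()     # AttributeError: 'float' object has no attribute 'backward'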

I can check this out with the following patch:

diff --git a/torchdms/analysis.py b/torchdms/analysis.py
index 767791a..044bb55 100644
--- a/torchdms/analysis.py
+++ b/torchdms/analysis.py
@@ -88,6 +88,9 @@ class Analysis:
                 range(targets.shape[1]), loss_decays
             )
         ]
+        qqq = sum(per_target_loss)
+        ppp = self.model.regularization_loss()
+        breakpoint()
         return sum(per_target_loss) + self.model.regularization_loss()
 
     def train(

If we print qqq, it's a tensor, but ppp is a float.
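One way to guard against this (just a sketch of the general pattern, not necessarily what the linked PR does) is to accumulate the penalty into a zero tensor from the start, so the method always returns a torch tensor rather than a plain Python float:

    def regularization_loss(self):
        """L1 penalize single mutant effects, and pre-latent interaction
        weights."""
        # Start from a zero tensor on the parameters' device so the return
        # value is always a tensor, even when both coefficients are zero.
        penalty = torch.zeros(1, device=self.latent_layer.weight.device)
        if self.beta_l1_coefficient > 0.0:
            penalty = penalty + self.beta_l1_coefficient * self.latent_layer.weight[
                :, : self.input_size
            ].norm(1)
        if self.interaction_l1_coefficient > 0.0:
            for interaction_layer in self.layers[: self.latent_idx]:
                penalty = penalty + self.interaction_l1_coefficient * getattr(
                    self, interaction_layer
                ).weight.norm(1)
        return penalty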

matsen (Contributor, Author) commented Feb 20, 2021

I pushed this version of regularization_loss to the 101-no-regularization-grad branch:

    def regularization_loss(self):
        """L1 penalize single mutant effects, and pre-latent interaction
        weights."""
        penalty = self.beta_l1_coefficient * self.latent_layer.weight[
            :, : self.input_size
        ].norm(1)
        if self.interaction_l1_coefficient > 0.0:
            for interaction_layer in self.layers[: self.latent_idx]:
                penalty += self.interaction_l1_coefficient * torch.sum(
                    [getattr(self, interaction_layer).weight.norm(1)]
                )
        return penalty

This version gives

(Pdb++) ppp
tensor(0., grad_fn=<MulBackward0>)

when I run make test.

wsdewitt linked a pull request on Feb 23, 2021 that will close this issue.
wsdewitt (Contributor) commented Feb 24, 2021

@matsen Hmm, I can't reproduce the float issue:

>>> from torchdms.model import FullyConnected
>>> model = FullyConnected(10, [2], [None], [None], None, beta_l1_coefficient=1e-3)
>>> loss = model.regularization_loss()
>>> print(loss)
tensor([14.5608], grad_fn=<AddBackward0>)

matsen (Contributor, Author) commented Feb 24, 2021

That's strange.

Did you try dropping into the debugger as in my original report?

wsdewitt (Contributor) commented

Yes, the issue surfaces in the debugger:

(Pdb++) print(ppp)
0.0
(Pdb++) ppp.backward()
*** AttributeError: 'float' object has no attribute 'backward'
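For what it's worth, the originally quoted implementation only returns a float when both coefficients are zero, since both if branches are then skipped and the initial 0.0 falls through. A hypothetical reproduction (assuming the configuration exercised by make test leaves both L1 coefficients at their presumed zero defaults) would be:

    from torchdms.model import FullyConnected

    # Same constructor call as above, but without beta_l1_coefficient; assuming
    # both L1 coefficients default to zero, neither branch of the original
    # regularization_loss runs and the float 0.0 is returned unchanged.
    model = FullyConnected(10, [2], [None], [None], None)
    print(model.regularization_loss())  # 0.0 -- a plain Python float, no backward()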

matsen (Contributor, Author) commented Feb 24, 2021

Fascinating. And sorry if I sent you on a goose chase.

How do you propose moving forward?

wsdewitt (Contributor) commented

I still don't understand the behavior, so no proposal yet. I'll keep poking! 👨‍🏭
