Skip to content

Commit

Permalink
Bias for rescaling by input (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
jlperla authored Aug 23, 2022
1 parent de6b7b3 commit 518a190
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

- Exponential layer
- Flexible multi-layer neural network with optional nonlinear last layer
- Affine rescaling of output by an input


## Development
Expand Down
13 changes: 9 additions & 4 deletions econ_layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,20 @@ def forward(self, input):

# rescaling by a specific element of a given input
class RescaleOutputsByInput(nn.Module):
def __init__(self, rescale_index: int = 0):
def __init__(self, rescale_index: int = 0, bias=False):
super().__init__()
self.rescale_index = rescale_index

if bias:
self.bias = torch.nn.Parameter(torch.Tensor(1)) # only a scalar here
torch.nn.init.zeros_(self.bias)
else:
self.bias = 0.0 # register_parameter('bias', None) # necessary?

def forward(self, x, y):
if x.dim() == 1:
return x[self.rescale_index] * y
return x[self.rescale_index] * y + self.bias
else:
return x[:, [self.rescale_index]] * y
return x[:, [self.rescale_index]] * y + self.bias


# assuming 2D data
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@
test_suite="tests",
tests_require=test_requirements,
url="https://github.com/HighDimensionalEconLab/econ_layers",
version="0.0.23",
version="0.0.24",
zip_safe=False,
)
28 changes: 25 additions & 3 deletions tests/test_input_rescaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,19 @@
import pytest
import torch
import numpy
from econ_layers.layers import RescaleOutputsByInput

from econ_layers.layers import (
FlexibleSequential,
RescaleOutputsByInput,
ScalarExponentialRescaling,
RescaleOutputsByInput,
)
import numpy.testing
import torch.autograd.gradcheck
from torch.autograd import Variable
from tests.helpers import train, test

torch.set_printoptions(16) # to be able to see what is going on
torch.manual_seed(0)

def test_input_rescaling():

Expand All @@ -33,4 +44,15 @@ def test_input_rescaling():
assert torch.all(torch.isclose(input_mult_0, x[:, [0]] * x[:, [1]]))
assert torch.all(torch.isclose(input_mult_1, x[:, [0]] * x[:, [1]]))


def test_input_rescaling_bias():
n_in = 2
n_out = 1
mod = FlexibleSequential(
n_in,
n_out,
layers=2,
hidden_dim=128,
OutputRescalingLayer=RescaleOutputsByInput(rescale_index=0, bias = True),
).double()
input = (Variable(torch.randn(n_in).double(), requires_grad=True),)
assert torch.autograd.gradcheck(mod, input)

0 comments on commit 518a190

Please sign in to comment.