Skip to content

Commit

Permalink
Moments layer (#42)
Browse files Browse the repository at this point in the history
* add moments layer

* add moments layer test

* fix typo in n_moments

* adjust test

* adjust moments forward

* adjust moments test

* adjust moments forward

* add proper broadcasting test

* extra moments test

* bump version
  • Loading branch information
jbrightuniverse authored Dec 31, 2021
1 parent 0a5263b commit de6b7b3
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 1 deletion.
14 changes: 14 additions & 0 deletions econ_layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@
from typing import Optional
from jsonargparse import lazy_instance


# produces the first m moments of a given input
class Moments(nn.Module):
def __init__(
self,
n_moments: int,
):
super().__init__()
self.n_moments = n_moments

def forward(self, input):
return torch.cat([input.pow(m) for m in torch.arange(1, self.n_moments + 1)], 1)


# rescaling by a specific element of a given input
class RescaleOutputsByInput(nn.Module):
def __init__(self, rescale_index: int = 0):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@
test_suite="tests",
tests_require=test_requirements,
url="https://github.com/HighDimensionalEconLab/econ_layers",
version="0.0.22",
version="0.0.23",
zip_safe=False,
)
70 changes: 70 additions & 0 deletions tests/test_moments_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python

"""Tests for Moments layer"""

import pytest
import torch
import numpy
from econ_layers.layers import Moments


def test_moments():
moments_layer = Moments(5)
x = torch.tensor([2.0, 1.5])
x_moments = moments_layer(x.reshape([2, 1]))
assert torch.all(
torch.isclose(
x_moments,
torch.tensor(
[[2.0, 4.0, 8.0, 16.0, 32.0], [1.5, 2.25, 3.375, 5.0625, 7.59375]]
),
)
)


def test_moments_broadcast():
num_moments = 5
test_data = torch.tensor([[1.0, 2, 3, 4], [5.0, 6, 7, 8]])
num_batches, N = test_data.shape
moments_layer = Moments(num_moments)
generated_data = torch.stack(
[
torch.mean(moments_layer(test_data[i, :].reshape([N, 1])), 0)
for i in range(num_batches)
]
)
expected_data = torch.stack(
[
torch.mean(
torch.stack(
[
torch.tensor(
[elt ** moment for moment in range(1, num_moments + 1)]
)
for elt in test_data[i, :]
]
),
0,
)
for i in range(num_batches)
]
)

assert torch.all(
torch.isclose(
generated_data,
expected_data,
)
)

assert torch.all(
torch.isclose(
generated_data,
torch.tensor(
[
[2.5e00, 7.5e00, 2.5e01, 8.85e01, 3.25e02],
[6.5e00, 4.35e01, 2.99e02, 2.1045e03, 1.5119e04],
]
),
)
)

0 comments on commit de6b7b3

Please sign in to comment.