Skip to content

Commit

Permalink
add some cpn test notebooks and fix an issue with network_type attrib…
Browse files Browse the repository at this point in the history
…ute for MLPJax class
  • Loading branch information
AlexanderFengler committed Oct 23, 2023
1 parent f6472fb commit bfb5553
Show file tree
Hide file tree
Showing 37 changed files with 7,237 additions and 23 deletions.
2 changes: 1 addition & 1 deletion docs/overrides/main.html
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Navigate the site here!
</span>
<span class="right-margin">
v0.4.3 is out!
v0.4.4 is out!
</span>
<span>
<span class="twemoji">
Expand Down
2 changes: 1 addition & 1 deletion lanfactory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.3"
__version__ = "0.4.4"

from . import config
from . import trainers
Expand Down
18 changes: 12 additions & 6 deletions lanfactory/trainers/jax_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pickle
from functools import partial
from frozendict import frozendict
from typing import Sequence
from typing import Sequence, Optional

from lanfactory.utils import try_gen_folder
from time import time
Expand Down Expand Up @@ -71,7 +71,9 @@ class MLPJax(nn.Module):
train_output_type (str):
The output type of the model during training.
"""

network_type_dict: dict = frozendict({'logprob': 'lan',
'logits': 'cpn'}
)
layer_sizes: Sequence[int] = (100, 90, 80, 1)
activations: Sequence[str] = ("tanh", "tanh", "tanh", "linear")
train: bool = True
Expand All @@ -81,9 +83,10 @@ class MLPJax(nn.Module):
activations_dict = frozendict(
{"relu": nn.relu, "tanh": nn.tanh, "sigmoid": nn.sigmoid}
)
# network_type: Optional[str] = "none"

# Define network type
network_type = "lan" if train_output_type == "logprob" else "cpn"
# network_type = "lan" if train_output_type == "logprob" else "cpn"

def setup(self):
"""Setup function for the JaxMLP class.
Expand All @@ -99,7 +102,7 @@ def setup(self):
]

# Identification
# self.network_type = "lan" if self.train_output_type == "logprob" else "cpn"
self.network_type = self.network_type_dict[self.train_output_type]

def __call__(self, inputs):
"""Call function for the JaxMLP class.
Expand All @@ -126,11 +129,14 @@ def __call__(self, inputs):
else:
x = self.activation_funs[i](x)

if not self.train and self.train_output_type == "logprob":
if (not self.train) and (self.train_output_type == "logprob"):
print('passing through identity')
x = x # just for pedagogy
elif not self.train and self.train_output_type == "logits":
elif (not self.train) and (self.train_output_type == "logits"):
print('passing through transform')
x = -jnp.log((1 + jnp.exp(-x)))
elif not self.train:
print('passing through identity 2')
x = x # just for pedagogy

return x
Expand Down
15 changes: 1 addition & 14 deletions notebooks/test_notebooks/test_jax_network.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,7 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "ImportError",
"evalue": "cannot import name 'onnx' from partially initialized module 'lanfactory' (most likely due to a circular import) (/users/afengler/data/software/miniconda3/envs/lan_pipe/lib/python3.10/site-packages/lanfactory/__init__.py)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mssms\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mlanfactory\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mos\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mnumpy\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mnp\u001b[39;00m\n",
"File \u001b[0;32m~/data/software/miniconda3/envs/lan_pipe/lib/python3.10/site-packages/lanfactory/__init__.py:6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m \u001b[39mimport\u001b[39;00m trainers\n\u001b[1;32m 5\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m \u001b[39mimport\u001b[39;00m utils\n\u001b[0;32m----> 6\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m \u001b[39mimport\u001b[39;00m onnx\n\u001b[1;32m 8\u001b[0m __all__ \u001b[39m=\u001b[39m [\u001b[39m\"\u001b[39m\u001b[39mconfig\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mtrainers\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mutils\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39monnx\u001b[39m\u001b[39m\"\u001b[39m]\n",
"\u001b[0;31mImportError\u001b[0m: cannot import name 'onnx' from partially initialized module 'lanfactory' (most likely due to a circular import) (/users/afengler/data/software/miniconda3/envs/lan_pipe/lib/python3.10/site-packages/lanfactory/__init__.py)"
]
}
],
"outputs": [],
"source": [
"import ssms\n",
"import lanfactory\n",
Expand Down
Loading

0 comments on commit bfb5553

Please sign in to comment.