Skip to content

Commit

Permalink
Merge pull request #11 from kengz/criterion-optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
kengz authored Jun 24, 2022
2 parents e170433 + 4376e7b commit 1c4a444
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 12 deletions.
77 changes: 71 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# TorchArc ![CI](https://github.com/kengz/torcharc/workflows/CI/badge.svg)

Build PyTorch networks by specifying architectures.

## Installation
Expand All @@ -16,16 +17,15 @@ conda install pytorch -c pytorch
## Usage

Given just the architecture, `torcharc` can build generic DAG (directed acyclic graph) of nn modules, which consists of:

- single-input-output modules: `Conv1d, Conv2d, Conv3d, Linear, Perceiver` or any other valid nn.Module
- fork modules: `ReuseFork, SplitFork`
- merge modules: `ConcatMerge, FiLMMerge`

The custom modules are defined in [`torcharc/module`](https://github.com/kengz/torcharc/tree/master/torcharc/module), registered in [`torcharc/module_builder.py`](https://github.com/kengz/torcharc/blob/master/torcharc/module_builder.py).


The full examples of architecture references are in [`torcharc/arc_ref.py`](https://github.com/kengz/torcharc/blob/master/torcharc/arc_ref.py), and full functional examples are in [`test/module/`](https://github.com/kengz/torcharc/tree/master/test/module). Below we walk through some main examples.


### ConvNet

```python
Expand Down Expand Up @@ -70,7 +70,6 @@ Sequential(
</p>
</details>


### MLP

```python
Expand Down Expand Up @@ -112,10 +111,9 @@ Sequential(
</p>
</details>


### Perceiver

>See [`torcharc/arc_ref.py`](https://github.com/kengz/torcharc/blob/master/torcharc/arc_ref.py) for multimodal Perceiver.
> See [`torcharc/arc_ref.py`](https://github.com/kengz/torcharc/blob/master/torcharc/arc_ref.py) for multimodal Perceiver.
```python
arc = {
Expand Down Expand Up @@ -1452,7 +1450,6 @@ Perceiver(
</p>
</details>


### DAG: Hydra

Ultimately, we can build a generic DAG network using the modules linked by the fork and merge modules. The example below shows HydraNet - a network with multiple inputs and multiple outputs.
Expand Down Expand Up @@ -1550,6 +1547,74 @@ DAG module accepts a `dict` (example below) as input, and the module selects its

For example, the input `xs` with keys `image, vector` passes through the first `image` module, and the output becomes `{'image': image_module(xs.image), 'vector': xs.vector}`. This is then passed through the remainder of the modules in the arc as declared.

### Criterion (Loss)

TorchArc provides convenience method to construct criterion modules (loss function) in the same config-driven manner.

```python
import torch
import torcharc


loss_spec = {
'type': 'BCEWithLogitsLoss',
'reduction': 'mean',
'pos_weight': 10.0,
}
criterion = torcharc.build_criterion(loss_spec)

pred = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
loss = criterion(pred, target)
# tensor(11.6296, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
```

### Optimizer

TorchArc also provides convenience method to construct optimizer in the same config-driven manner.

```python
import torcharc


arc = {
'type': 'Linear',
'in_features': 8,
'layers': [64, 32],
'batch_norm': True,
'activation': 'ReLU',
'dropout': 0.2,
'init': {
'type': 'normal_',
'std': 0.01,
},
}
optim_spec = {
'type': 'Adam',
'lr': 0.001,
}

model = torcharc.build(arc)
optimizer = torcharc.build_optimizer(optim_spec, model)
```

<details><summary>optimizer</summary>
<p>

```
Adam (
Parameter Group 0
amsgrad: False
betas: (0.9, 0.999)
eps: 1e-08
lr: 0.001
weight_decay: 0
)
```

</p>
</details>

## Development

### Setup
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def run_tests(self):

setup(
name='torcharc',
version='1.1.3',
version='1.2.0',
description='Build PyTorch networks by specifying architectures.',
long_description='https://github.com/kengz/torcharc',
keywords='torcharc',
Expand Down
23 changes: 21 additions & 2 deletions test/test_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torch import nn
from torcharc import arc_ref, build
from torch import nn, optim
from torcharc import arc_ref, build, build_criterion, build_optimizer
import pytest


Expand All @@ -8,3 +8,22 @@ def test_builder(name, arc):
print('building', name)
model = build(arc)
assert isinstance(model, nn.Module)


@pytest.mark.parametrize('loss_spec', [
{'type': 'MSELoss'},
{'type': 'BCEWithLogitsLoss', 'reduction': 'mean', 'pos_weight': 10.0}, # with numeric arg to be converted to tensor
])
def test_build_criterion(loss_spec):
criterion = build_criterion(loss_spec)
assert isinstance(criterion, nn.Module)


@pytest.mark.parametrize('optim_spec', [
{'type': 'SGD', 'lr': 0.1},
{'type': 'Adam', 'lr': 0.001},
])
def test_build_optimizer(optim_spec):
model = build(arc_ref.REF_ARCS['Linear'])
criterion = build_optimizer(optim_spec, model)
assert isinstance(criterion, optim.Optimizer)
27 changes: 24 additions & 3 deletions torcharc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
from contextlib import suppress
from torcharc import module_builder
from torcharc.module import dag
from torch import nn
import torch


def build(arc: dict) -> nn.Module:
'''Interface method to build a DAGNet or a simple nn module. See arf_ref.py for arc references.'''
def build(arc: dict) -> torch.nn.Module:
'''Interface method to build a DAGNet or a simple torch.nn module. See arf_ref.py for arc references.'''
if 'dag_in_shape' in arc:
return dag.DAGNet(arc)
else:
return module_builder.build_module(arc)


# additional convenience methods to build criterion and optimizer

def build_criterion(loss_spec: dict) -> torch.nn.Module:
'''Build criterion (loss function) from loss spec'''
criterion_cls = getattr(torch.nn, loss_spec.pop('type'))
# any numeric arg has to be tensor; scan and try-cast
for k, v in loss_spec.items():
with suppress(Exception):
loss_spec[k] = torch.tensor(v)
criterion = criterion_cls(**loss_spec)
return criterion


def build_optimizer(optim_spec: dict, model: torch.nn.Module) -> torch.optim.Optimizer:
'''Build optimizer from optimizer spec'''
optim_cls = getattr(torch.optim, optim_spec.pop('type'))
optimizer = optim_cls(model.parameters(), **optim_spec)
return optimizer

0 comments on commit 1c4a444

Please sign in to comment.