
Commit

Document user inferable modules
vyzyv committed Mar 27, 2020
1 parent 87ae23a commit 5d7f4b4
Showing 10 changed files with 164 additions and 155 deletions.
17 changes: 10 additions & 7 deletions README.md
@@ -107,12 +107,14 @@ As you can see both modules "compiled" into original `pytorch` layers.
 
 ## Custom modules with shape inference capabilities
 
-User can define any module and make it shape inferable with `torchlayers.Infer`
-decorator class:
+User can define any module and make it shape inferable with `torchlayers.infer`
+function:
 
 ```python
-@torchlayers.Infer() # Remember to instantiate it
-class MyLinear(torch.nn.Module):
+# Class defined with in_features
+# It might be a good practice to use _ prefix and Impl as postfix
+# to differentiate from shape inferable version
+class _MyLinearImpl(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
@@ -121,15 +123,16 @@ class MyLinear(torch.nn.Module):
     def forward(self, inputs):
         return torch.nn.functional.linear(inputs, self.weight, self.bias)
 
+MyLinear = torchlayers.infer(_MyLinearImpl)
+
-layer = MyLinear(out_features=32)
-# [WIP] Currently custom layers are unbuildable, you can still use them without build though
+# Build and use just like any other layer in this library
+layer = torchlayers.build(MyLinear(out_features=32), torch.randn(1, 64))
 layer(torch.randn(1, 64))
 ```
 
 By default `inputs.shape[1]` will be used as `in_features` value
 during initial `forward` pass. If you wish to use different `index` (e.g. to infer using
-`inputs.shape[3]`) use `@torchlayers.Infer(index=3)` as a decorator.
+`inputs.shape[3]`) use `MyLayer = torchlayers.infer(_MyLayerImpl, index=3)` instead.
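
For illustration only (not part of this commit): a minimal sketch of how the `index` argument described above might be used when the inferable dimension is not `inputs.shape[1]`. The module name and tensor shapes are made up; `torchlayers.infer` and `torchlayers.build` are assumed to behave as shown in the README diff above.

```python
import torch
import torchlayers


# Hypothetical module whose in_features lives on the last axis of a
# channels-last input of shape (batch, height, width, features)
class _LastDimLinearImpl(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))

    def forward(self, inputs):
        return torch.nn.functional.linear(inputs, self.weight, self.bias)


# Infer in_features from inputs.shape[3] instead of the default inputs.shape[1]
LastDimLinear = torchlayers.infer(_LastDimLinearImpl, index=3)

layer = torchlayers.build(LastDimLinear(out_features=32), torch.randn(1, 8, 8, 64))
output = layer(torch.randn(1, 8, 8, 64))  # in_features inferred as 64
```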
128 changes: 62 additions & 66 deletions docs/_modules/torchlayers.html

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions docs/_modules/torchlayers/normalization.html
@@ -318,13 +318,15 @@ Source code for torchlayers.normalization
 
     def _module_not_found(self, inputs):
         if len(inputs.shape) == 2:
-            inner_class = getattr(torch.nn, f"{self._module_name}1d", None)
+            inner_class = getattr(torch.nn, "{}1d".format(self._module_name), None)
             if inner_class is not None:
                 return inner_class
 
         raise ValueError(
-            f"{self._module_name} could not be inferred from shape. "
-            f"Only 5, 4, 3 or 2 dimensional input allowed (including batch dimension), got {len(inputs.shape)}."
+            "{} could not be inferred from shape. ".format(self._module_name)
+            + "Only 5, 4, 3 or 2 dimensional input allowed (including batch dimension), got {}.".format(
+                len(inputs.shape)
+            )
         )
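
For readers following along, a small self-contained sketch of the shape-based dispatch that `_module_not_found` performs above; the function name is hypothetical and the 5/4/3-dimensional mapping is an assumption drawn from the error message (only the 2-dimensional case is visible in the hunk).

```python
import torch


def pick_dimensional_variant(module_name, inputs):
    # Map input dimensionality (batch dimension included) to a torch.nn suffix,
    # mirroring the fallback shown in the diff above.
    suffix = {5: "3d", 4: "2d", 3: "1d", 2: "1d"}.get(len(inputs.shape))
    if suffix is not None:
        inner_class = getattr(torch.nn, "{}{}".format(module_name, suffix), None)
        if inner_class is not None:
            return inner_class

    raise ValueError(
        "{} could not be inferred from shape. ".format(module_name)
        + "Only 5, 4, 3 or 2 dimensional input allowed (including batch dimension), got {}.".format(
            len(inputs.shape)
        )
    )


# A 4-dimensional input (N, C, H, W) resolves to torch.nn.BatchNorm2d
print(pick_dimensional_variant("BatchNorm", torch.randn(2, 3, 8, 8)))
```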
2 changes: 1 addition & 1 deletion docs/_modules/torchlayers/pooling.html
@@ -238,7 +238,7 @@ Source code for torchlayers.pooling
         return values
 
     def __repr__(self):
-        return f"{type(self).__name__}()"
+        return "{}()".format(type(self).__name__)
 
     def forward(self, inputs):
         while len(inputs.shape) > 2:
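
As an illustrative aside (not in the commit), a minimal global-pooling sketch built around the two methods visible in the hunk above; the mean reduction is an assumption, since the loop body is not shown.

```python
import torch


class GlobalMeanPoolSketch(torch.nn.Module):
    def __repr__(self):
        return "{}()".format(type(self).__name__)

    def forward(self, inputs):
        # Collapse trailing dimensions until only (batch, features) remain;
        # the loop condition comes from the diff, the reduction used is assumed.
        while len(inputs.shape) > 2:
            inputs = torch.mean(inputs, dim=-1)
        return inputs


pool = GlobalMeanPoolSketch()
print(pool)                                   # GlobalMeanPoolSketch()
print(pool(torch.randn(4, 16, 8, 8)).shape)   # torch.Size([4, 16])
```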
2 changes: 1 addition & 1 deletion docs/genindex.html
@@ -384,7 +384,7 @@ <h2 id="I">I</h2>
 <td style="width: 33%; vertical-align: top;"><ul>
   <li><a href="packages/torchlayers.upsample.html#torchlayers.upsample.ConvPixelShuffle.icnr_initialization">icnr_initialization() (torchlayers.upsample.ConvPixelShuffle method)</a>
   </li>
-  <li><a href="packages/torchlayers.html#torchlayers.Infer">Infer (class in torchlayers)</a>
+  <li><a href="packages/torchlayers.html#torchlayers.infer">infer() (in module torchlayers)</a>
   </li>
 </ul></td>
 <td style="width: 33%; vertical-align: top;"><ul>
Binary file modified docs/objects.inv
Binary file not shown.
20 changes: 12 additions & 8 deletions docs/packages/torchlayers.html
@@ -251,30 +251,34 @@
-class torchlayers.Infer(index: int = 1)    [source]
+torchlayers.infer(module_class, index: str = 1)    [source]
 
     Allows custom user modules to infer input shape.
 
     Input shape should be the first argument after `self`.
 
     Usually used as class decorator, e.g.:
 
-        # Remember it's a class, it has to be instantiated
-        @torchlayers.Infer()
-        class StrangeLinear(torch.nn.Linear):
+        class _StrangeLinearImpl(torch.nn.Linear):
             def __init__(self, in_features, out_features, bias: bool = True):
                 super().__init__(in_features, out_features, bias)
                 self.params = torch.nn.Parameter(torch.randn(out_features))
 
             def forward(self, inputs):
                 super().forward(inputs) + self.params
 
+        # Now you can use shape inference of in_features
+        StrangeLinear = torchlayers.infer(_StrangeLinearImpl)
 
         # in_features can be inferred
         layer = StrangeLinear(out_features=64)
 
     Parameters:
+        module_class (torch.nn.Module) – Class of module to be updated with shape inference capabilities.
         index (int, optional) – Index into `tensor.shape` input which should be inferred, e.g. tensor.shape[1].
             Default: `1` (`0` being batch dimension)
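
As a usage note that is not part of the generated documentation: a minimal sketch of how the documented `infer` signature might be exercised together with `torchlayers.build` (shown in the README diff above); note the `return` added to `forward`, which the docstring snippet omits.

```python
import torch
import torchlayers


class _StrangeLinearImpl(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        self.params = torch.nn.Parameter(torch.randn(out_features))

    def forward(self, inputs):
        # The docstring example drops the result; returning it makes the sketch usable
        return super().forward(inputs) + self.params


# Wrap the implementation so in_features is inferred from inputs.shape[1]
StrangeLinear = torchlayers.infer(_StrangeLinearImpl)

# in_features is filled in from the sample input during build
layer = torchlayers.build(StrangeLinear(out_features=64), torch.randn(2, 128))
print(layer(torch.randn(2, 128)).shape)  # torch.Size([2, 64])
```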

0 comments on commit 5d7f4b4
