From 5d7f4b4c76fb28e5f74c170b3ac6d17391827304 Mon Sep 17 00:00:00 2001 From: vyzyv Date: Fri, 27 Mar 2020 22:13:44 +0100 Subject: [PATCH] Document user inferable modules --- README.md | 17 ++- docs/_modules/torchlayers.html | 128 +++++++++---------- docs/_modules/torchlayers/normalization.html | 8 +- docs/_modules/torchlayers/pooling.html | 2 +- docs/genindex.html | 2 +- docs/objects.inv | Bin 1189 -> 1191 bytes docs/packages/torchlayers.html | 20 +-- docs/searchindex.js | 2 +- tests/general_test.py | 12 +- torchlayers/__init__.py | 128 +++++++++---------- 10 files changed, 164 insertions(+), 155 deletions(-) diff --git a/README.md b/README.md index 6fbd014..e06eb94 100644 --- a/README.md +++ b/README.md @@ -107,12 +107,14 @@ As you can see both modules "compiled" into original `pytorch` layers. ## Custom modules with shape inference capabilities -User can define any module and make it shape inferable with `torchlayers.Infer` -decorator class: +User can define any module and make it shape inferable with `torchlayers.infer` +function: ```python -@torchlayers.Infer() # Remember to instantiate it -class MyLinear(torch.nn.Module): + # Class defined with in_features + # It might be a good practice to use _ prefix and Impl as postfix + # to differentiate from shape inferable version +class _MyLinearImpl(torch.nn.Module): def __init__(self, in_features: int, out_features: int): super().__init__() self.weight = torch.nn.Parameter(torch.randn(out_features, in_features)) @@ -121,15 +123,16 @@ class MyLinear(torch.nn.Module): def forward(self, inputs): return torch.nn.functional.linear(inputs, self.weight, self.bias) +MyLinear = torchlayers.infer(_MyLinearImpl) -layer = MyLinear(out_features=32) -# [WIP] Currently custom layers are unbuildable, you can still use them without build though +# Build and use just like any other layer in this library +layer =torchlayers.build(MyLinear(out_features=32), torch.randn(1, 64)) layer(torch.randn(1, 64)) ``` By default `inputs.shape[1]` will be used as `in_features` value during initial `forward` pass. If you wish to use different `index` (e.g. to infer using -`inputs.shape[3]`) use `@torchlayers.Infer(index=3)` as a decorator. +`inputs.shape[3]`) use `MyLayer = torchlayers.infer(_MyLayerImpl, index=3)` as a decorator. ## Autoencoder with inverted residual bottleneck and pixel shuffle diff --git a/docs/_modules/torchlayers.html b/docs/_modules/torchlayers.html index 26fb90d..9a60efd 100644 --- a/docs/_modules/torchlayers.html +++ b/docs/_modules/torchlayers.html @@ -226,7 +226,7 @@

Source code for torchlayers

                pooling, regularization, upsample)
 from ._version import __version__
 
-__all__ = ["build", "Infer", "Lambda", "Reshape", "Concatenate"]
+__all__ = ["build", "infer", "Lambda", "Reshape", "Concatenate"]
 
 
 
[docs]def build(module, *args, **kwargs): @@ -288,16 +288,14 @@

Source code for torchlayers

     return module
-
[docs]class Infer: +
[docs]def infer(module_class, index: str = 1): """Allows custom user modules to infer input shape. Input shape should be the first argument after `self`. Usually used as class decorator, e.g.:: - # Remember it's a class, it has to be instantiated - @torchlayers.Infer() - class StrangeLinear(torch.nn.Linear): + class _StrangeLinearImpl(torch.nn.Linear): def __init__(self, in_features, out_features, bias: bool = True): super().__init__(in_features, out_features, bias) self.params = torch.nn.Parameter(torch.randn(out_features)) @@ -305,77 +303,75 @@

Source code for torchlayers

             def forward(self, inputs):
                 super().forward(inputs) + self.params
 
+        # Now you can use shape inference of in_features
+        StrangeLinear = torchlayers.infer(_StrangeLinearImpl)
+
         # in_features can be inferred
         layer = StrangeLinear(out_features=64)
 
 
     Parameters
     ----------
+    module_class: torch.nn.Module
+        Class of module to be updated with shape inference capabilities.
+
     index: int, optional
         Index into `tensor.shape` input which should be inferred, e.g. tensor.shape[1].
         Default: `1` (`0` being batch dimension)
 
     """
 
-    def __init__(self, index: int = 1):
-        self.index: int = index
-
-    def __call__(self, module_class):
-        init_arguments = [
-            str(argument)
-            for argument in inspect.signature(module_class.__init__).parameters.values()
-        ]
-
-        # Other argument than self
-        if len(init_arguments) > 1:
-            name = module_class.__name__
-            infered_module = type(
-                name, (torch.nn.Module,), {_dev_utils.infer.MODULE_CLASS: module_class},
-            )
-            parsed_arguments, uninferable_arguments = _dev_utils.infer.parse_arguments(
-                init_arguments, infered_module
-            )
-
-            setattr(
-                infered_module,
-                "__init__",
-                _dev_utils.infer.create_init(parsed_arguments),
-            )
-
-            setattr(
-                infered_module,
-                "forward",
-                _dev_utils.infer.create_forward(
-                    _dev_utils.infer.MODULE,
-                    _dev_utils.infer.MODULE_CLASS,
-                    parsed_arguments,
-                    self.index,
-                ),
-            )
-            setattr(
-                infered_module,
-                "__repr__",
-                _dev_utils.infer.create_repr(
-                    _dev_utils.infer.MODULE, **uninferable_arguments
-                ),
-            )
-            setattr(
-                infered_module,
-                "__getattr__",
-                _dev_utils.infer.create_getattr(_dev_utils.infer.MODULE),
-            )
-
-            setattr(
-                infered_module,
-                "__reduce__",
-                _dev_utils.infer.create_reduce(
-                    _dev_utils.infer.MODULE, parsed_arguments
-                ),
-            )
-
-            return infered_module
-
-        return module_class
+ init_arguments = [ + str(argument) + for argument in inspect.signature(module_class.__init__).parameters.values() + ] + + # Other argument than self + if len(init_arguments) > 1: + name = module_class.__name__ + infered_module = type( + name, (torch.nn.Module,), {_dev_utils.infer.MODULE_CLASS: module_class}, + ) + parsed_arguments, uninferable_arguments = _dev_utils.infer.parse_arguments( + init_arguments, infered_module + ) + + setattr( + infered_module, "__init__", _dev_utils.infer.create_init(parsed_arguments), + ) + + setattr( + infered_module, + "forward", + _dev_utils.infer.create_forward( + _dev_utils.infer.MODULE, + _dev_utils.infer.MODULE_CLASS, + parsed_arguments, + index, + ), + ) + setattr( + infered_module, + "__repr__", + _dev_utils.infer.create_repr( + _dev_utils.infer.MODULE, **uninferable_arguments + ), + ) + setattr( + infered_module, + "__getattr__", + _dev_utils.infer.create_getattr(_dev_utils.infer.MODULE), + ) + + setattr( + infered_module, + "__reduce__", + _dev_utils.infer.create_reduce(_dev_utils.infer.MODULE, parsed_arguments), + ) + + return infered_module + + return module_class
[docs]class Lambda(torch.nn.Module): @@ -517,8 +513,8 @@

Source code for torchlayers

 
     module_class = _getattr(name)
     if name in _inferable.torch.all() + _inferable.custom.all():
-        return Infer(_dev_utils.helpers.get_per_module_index(module_class))(
-            module_class
+        return infer(
+            module_class, _dev_utils.helpers.get_per_module_index(module_class)
         )
     return module_class
 
diff --git a/docs/_modules/torchlayers/normalization.html b/docs/_modules/torchlayers/normalization.html index 5a89210..437fe5c 100644 --- a/docs/_modules/torchlayers/normalization.html +++ b/docs/_modules/torchlayers/normalization.html @@ -318,13 +318,15 @@

Source code for torchlayers.normalization

 
     def _module_not_found(self, inputs):
         if len(inputs.shape) == 2:
-            inner_class = getattr(torch.nn, f"{self._module_name}1d", None)
+            inner_class = getattr(torch.nn, "{}1d".format(self._module_name), None)
             if inner_class is not None:
                 return inner_class
 
         raise ValueError(
-            f"{self._module_name} could not be inferred from shape. "
-            f"Only 5, 4, 3 or 2 dimensional input allowed (including batch dimension), got {len(inputs.shape)}."
+            "{} could not be inferred from shape. ".format(self._module_name)
+            + "Only 5, 4, 3 or 2 dimensional input allowed (including batch dimension), got {}.".format(
+                len(inputs.shape)
+            )
         )
diff --git a/docs/_modules/torchlayers/pooling.html b/docs/_modules/torchlayers/pooling.html index edbcf5a..9c86aad 100644 --- a/docs/_modules/torchlayers/pooling.html +++ b/docs/_modules/torchlayers/pooling.html @@ -238,7 +238,7 @@

Source code for torchlayers.pooling

         return values
 
     def __repr__(self):
-        return f"{type(self).__name__}()"
+        return "{}()".format(type(self).__name__)
 
     def forward(self, inputs):
         while len(inputs.shape) > 2:
diff --git a/docs/genindex.html b/docs/genindex.html
index fd1a8a0..e45b8be 100644
--- a/docs/genindex.html
+++ b/docs/genindex.html
@@ -384,7 +384,7 @@ 

I

    diff --git a/docs/objects.inv b/docs/objects.inv index 825c0473a41b624d927720a3c7d48591e9b89746..3ee6e51b1b3fb4b589331a0d4078d7f90c1050dc 100644 GIT binary patch delta 1079 zcmV-71jzfP38x8=e}9=>Z{sKwhVT0;kUF<%c5bcJN;~arr_#2HCcE0pl-R^(i-85) zN&4#-e8d7hBN0pFZiBd{~BDM;RlvQ1miU} zlI;kVW7;h=awAVGWLMF&9tjBV3lbC^j=HTd@tXQ*bjs*EBUKaUyn_v~^$s?5_XVUB zk#$k+cEtU#x_?OF&-PzovyyOvRUG>@IhJ{S9N4rz{!2hAIg^1e=L4cqkdbQpfVe93 zZH&Sf`422mcw(eFJkfU~6v}i zQ*jAP-fnX3tEyL8*V9o{zqD6)@4ty2;=Q^4)tH+Ar&Me(#R?DKK2ehC4XWB4zm8m+ zBS#d|4DEqbSwi3z5wQ3S`m>D=Jk^IcCt1eQT7P6Hj2x{ldDWsfBogpzp%j(IlpJ-d`2hmz3}%8xf>`u^>8vl$vZ<-ToW!9$y`PtP$oa`OYn zcTD6g}N$q-k^gikY~cg9`8}1sgJW$;5mke`vzBy zE_-7z4f9T$xELO_n%>Y%!{=G=n}1iW%zvg8i~&57B-3#oH5uD|i~+ATch+&c3Ng}X z0SnGQbK<0U6v?t3u;6FYBSo@jv}vhn8<27`W~PD;$cLr zkB8&%kvFJsmdi)4%Yuh{o)ylcW_ACSB@wE7c|(X-`P71g%s6fiIQ x!diRTx0%^)k2-LoA2W(GD;7@o$Koixn1)19Cm15ybM3s>?f9y$(f?{z-@ypY7H|Ln delta 1077 zcmV-51j_rT38e{;e}7rgZsRBveeYKwbzjr$yjrW3cG}r)q-_;VW;Bl}v5CzV0}Hs5 z^y>%5c7hWdhzk!@9dqtEhwJNbA&QB#AaIX_1bm;AEUO3tzXJ}_SC}I?S0}(-p-OVw zKxV&z*^N1N&u9u2qWZ5UN(qz_e6sLZ2=2B_{DdMqI6|sm*?(0`3%W&Onu&j5xyj&1 zVn68pE0P7|6I0j*f6xL<6@G&XGwM+i2H092Chi~YZH4oa;mkSD3w}>x{~TPe@Ds}- zg7KO*lI;kVW7;h=u&;5$(aoNI3Ey=LK!Ky z4~VNmzs4whlK;j6g=dVEhXs8@LZM8z6gaCO`OK6eh@$j0>=edzB2DaR$;dwJ91fI` z?1p!HFpG9dz_mbJMDQCa6QOC0Xe}&LH`m#&0f?Yom za0=)2k*4tj?!&)GEonf@l!+1&{0{v(w;tTuyY-Qup-PKqeao4T0R=ntqhJ+}L%pC< zL7Jja?3^)zY3@#Q3$!M zL@~|K4oH5rl&~th zYi{eMyPWH8nMtfe_w(lNT)yv1MoTE)Uyb%B&+#1- z`J4;(A4-);KDV-K9#C_&yeY3nq4n7P@_UQSEy#N8@rcCrLQF_gCSxHz(8Bo8k~A-u z9Dj}P#-E!1ZIqFQ(MS_xK4;B1P&s@*>js+AXaw6D8Y)V&q(+bAOyv%B<2nj-GM^tC zf-#~ddrrZ`9)FgpYmWcNh^H7fPcvN^I&ZPT7?G!7VvqO8(Ae?n81WRt{e6QgN0+@Z z7$@_VOq7y#Un|U?SO@TqW@z*>B?EJdjW$KK01jd(jzy6P3<7%wGI?PLr~NQHLNBHv5!3>GisoK3?^QFttaJ1WbY0$2eFGE; diff --git a/docs/packages/torchlayers.html b/docs/packages/torchlayers.html index 894dd44..17d784a 100644 --- a/docs/packages/torchlayers.html +++ b/docs/packages/torchlayers.html @@ -251,15 +251,13 @@ -
    -
    -class torchlayers.Infer(index: int = 1)[source]
    +
    +
    +torchlayers.infer(module_class, index: str = 1)[source]

    Allows custom user modules to infer input shape.

    Input shape should be the first argument after self.

    Usually used as class decorator, e.g.:

    -
    # Remember it's a class, it has to be instantiated
    -@torchlayers.Infer()
    -class StrangeLinear(torch.nn.Linear):
    +
    class _StrangeLinearImpl(torch.nn.Linear):
         def __init__(self, in_features, out_features, bias: bool = True):
             super().__init__(in_features, out_features, bias)
             self.params = torch.nn.Parameter(torch.randn(out_features))
    @@ -267,14 +265,20 @@
         def forward(self, inputs):
             super().forward(inputs) + self.params
     
    +# Now you can use shape inference of in_features
    +StrangeLinear = torchlayers.infer(_StrangeLinearImpl)
    +
     # in_features can be inferred
     layer = StrangeLinear(out_features=64)
     
    Parameters
    -

    index (int, optional) – Index into tensor.shape input which should be inferred, e.g. tensor.shape[1]. -Default: 1 (0 being batch dimension)

    +
      +
    • module_class (torch.nn.Module) – Class of module to be updated with shape inference capabilities.

    • +
    • index (int, optional) – Index into tensor.shape input which should be inferred, e.g. tensor.shape[1]. +Default: 1 (0 being batch dimension)

    • +
    diff --git a/docs/searchindex.js b/docs/searchindex.js index 31a0746..242b3ac 100644 --- a/docs/searchindex.js +++ b/docs/searchindex.js @@ -1 +1 @@ -Search.setIndex({docnames:["index","packages/torchlayers","packages/torchlayers.activations","packages/torchlayers.convolution","packages/torchlayers.normalization","packages/torchlayers.pooling","packages/torchlayers.regularization","packages/torchlayers.upsample","related"],envversion:{"sphinx.domains.c":1,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":1,"sphinx.domains.index":1,"sphinx.domains.javascript":1,"sphinx.domains.math":2,"sphinx.domains.python":1,"sphinx.domains.rst":1,"sphinx.domains.std":1,"sphinx.ext.intersphinx":1,"sphinx.ext.todo":2,"sphinx.ext.viewcode":1,sphinx:56},filenames:["index.rst","packages/torchlayers.rst","packages/torchlayers.activations.rst","packages/torchlayers.convolution.rst","packages/torchlayers.normalization.rst","packages/torchlayers.pooling.rst","packages/torchlayers.regularization.rst","packages/torchlayers.upsample.rst","related.rst"],objects:{"":{torchlayers:[1,0,0,"-"]},"torchlayers.Concatenate":{forward:[1,2,1,""]},"torchlayers.Lambda":{forward:[1,2,1,""]},"torchlayers.Reshape":{forward:[1,2,1,""]},"torchlayers.activations":{HardSigmoid:[2,1,1,""],HardSwish:[2,1,1,""],Swish:[2,1,1,""],hard_sigmoid:[2,3,1,""],hard_swish:[2,3,1,""],swish:[2,3,1,""]},"torchlayers.activations.HardSigmoid":{forward:[2,2,1,""]},"torchlayers.activations.HardSwish":{forward:[2,2,1,""]},"torchlayers.activations.Swish":{forward:[2,2,1,""]},"torchlayers.convolution":{ChannelShuffle:[3,1,1,""],ChannelSplit:[3,1,1,""],Conv:[3,1,1,""],ConvTranspose:[3,1,1,""],Dense:[3,1,1,""],DepthwiseConv:[3,1,1,""],Fire:[3,1,1,""],InvertedResidualBottleneck:[3,1,1,""],MPoly:[3,1,1,""],Poly:[3,1,1,""],Residual:[3,1,1,""],SeparableConv:[3,1,1,""],SqueezeExcitation:[3,1,1,""],WayPoly:[3,1,1,""]},"torchlayers.convolution.ChannelShuffle":{forward:[3,2,1,""]},"torchlayers.convolution.ChannelSplit":{forward:[3,2,1,""]},"torchlayers.convolution.Dense":{forward:[3,2,1,""]},"torchlayers.convolution.Fire":{forward:[3,2,1,""]},"torchlayers.convolution.InvertedResidualBottleneck":{forward:[3,2,1,""]},"torchlayers.convolution.MPoly":{forward:[3,2,1,""]},"torchlayers.convolution.Poly":{extra_repr:[3,2,1,""],forward:[3,2,1,""]},"torchlayers.convolution.Residual":{forward:[3,2,1,""]},"torchlayers.convolution.SeparableConv":{forward:[3,2,1,""]},"torchlayers.convolution.SqueezeExcitation":{forward:[3,2,1,""]},"torchlayers.convolution.WayPoly":{forward:[3,2,1,""]},"torchlayers.normalization":{BatchNorm:[4,1,1,""],GroupNorm:[4,1,1,""],InstanceNorm:[4,1,1,""]},"torchlayers.pooling":{AvgPool:[5,1,1,""],GlobalAvgPool:[5,1,1,""],GlobalMaxPool:[5,1,1,""],MaxPool:[5,1,1,""]},"torchlayers.regularization":{Dropout:[6,1,1,""],StandardNormalNoise:[6,1,1,""],StochasticDepth:[6,1,1,""]},"torchlayers.regularization.StandardNormalNoise":{forward:[6,2,1,""]},"torchlayers.regularization.StochasticDepth":{forward:[6,2,1,""]},"torchlayers.upsample":{ConvPixelShuffle:[7,1,1,""]},"torchlayers.upsample.ConvPixelShuffle":{forward:[7,2,1,""],icnr_initialization:[7,2,1,""],post_build:[7,2,1,""]},torchlayers:{Concatenate:[1,1,1,""],Infer:[1,1,1,""],Lambda:[1,1,1,""],Reshape:[1,1,1,""],activations:[2,0,0,"-"],build:[1,3,1,""],convolution:[3,0,0,"-"],normalization:[4,0,0,"-"],pooling:[5,0,0,"-"],regularization:[6,0,0,"-"],upsample:[7,0,0,"-"]}},objnames:{"0":["py","module","Python module"],"1":["py","class","Python class"],"2":["py","method","Python method"],"3":["py","function","Python function"]},objtypes:{"0":"py:module","1":"py:class","2":"py:method","3":"py:function"},terms:{"1x1":3,"50x":3,"5mb":3,"class":[0,1,2,3,4,5,6,7],"default":[0,1,2,3,4,5,6,7],"final":3,"float":[2,3,4,6],"function":[0,1,2,3,6,7],"import":[0,3],"int":[1,3,4,5,7],"return":[1,2,3,5,7],"super":1,"true":[1,3,4,5,6,7],"while":[1,2,3,6,7],And:0,For:[0,3,6],One:3,The:[3,5],Use:1,Useful:[0,5],Uses:2,__init__:[1,7],about:8,abov:[0,2],accept:[3,7],accord:7,accordingli:3,accuraci:3,across:5,act:[3,7],action:1,activ:[0,1,3,8],actual:6,add:[3,6,7],addabl:3,added:[3,4,5,7],adding:3,addit:[0,3],advis:3,affin:4,after:[1,3,5,7],afterward:[1,2,3,6,7],agnost:7,alexnet:3,all:[0,1,2,3,6,7],allow:[1,7],along:[1,3,5],also:1,although:[1,2,3,6,7],alwai:[1,4],analys:8,andrew:2,ani:[1,3,6],anyth:1,appli:[2,3,4],architectur:[0,3],arg:[1,3],argument:[1,3,4],around:8,artifact:7,asymmetr:3,attent:0,author:8,automat:0,avail:0,averag:[3,4,5],avg:5,avgpool1d:5,avgpool2d:5,avgpool3d:5,avgpool:5,awar:3,base:[0,3,4,6,8],batch:[1,3,4,5,6,7],batchnorm:[3,4],befor:3,beforehand:1,being:1,below:[0,1,8],beta:2,between:[3,7],bia:[1,3,7],block:[0,3,7],bool:[1,2,3,4,5,6,7],both:[0,3,4,5,7],bottleneck:[3,6],build:[0,1],built:7,call:[0,1,2,3,6,7],callabl:[1,3,7],can:[0,1,3,4,7,8],capabl:0,care:[1,2,3,6,7],ceil:5,ceil_mod:5,channel:[3,4,5,7],channelshuffl:3,channelsplit:3,checkerboard:7,circular:[3,7],closer:6,cloud:0,com:7,compat:3,competit:0,compil:1,comput:[1,2,3,4,5,6,7],concaten:[1,3],connect:[3,7],consecut:3,consid:[3,8],control:5,conv1d:0,conv2d:3,conv3d:3,conv:[0,3],conveni:3,convolut:[0,1,5,7],convolv:[3,7],convpixelshuffl:7,convtranspos:3,copi:1,correct:3,correspond:0,could:[1,3],count_include_pad:5,counterpart:4,cpu:2,creat:[1,3,4,6],creation:8,cuda:[0,8],cudnn7:0,cumul:4,current:[0,3,7],custom:[0,1,3],customiz:3,dai:8,data:1,dead:6,decor:1,deep:[3,6],def:1,defin:[1,2,3,6,7],denomin:4,dens:3,depend:5,depth:6,depthwis:3,depthwiseconv:3,detail:2,devic:[3,8],devis:6,differ:3,differenti:6,dilat:[3,5,7],dim:[1,3],dimens:[1,3,4,5,6,7],dimension:[0,7],directli:8,distribut:6,divers:3,divis:3,document:[3,8],doe:[4,6],done:3,dropout:[0,6],due:[0,3,7],dure:[3,4,6],duti:8,each:[0,5],earli:6,easili:3,effect:2,effici:[0,2,3],either:[3,4,5,6],element:[2,3,5,6,7],emploi:6,empti:3,enabl:0,environ:8,eps:4,equal:[3,6],equat:3,error:3,especi:1,etc:[0,8],eval:4,even:3,everi:[1,2,3,6,7],exactli:[3,4],exampl:[0,1,7],except:[0,1,3],excit:[0,3],exclud:1,exp:2,expand:3,expect:4,explicitli:[1,3],express:3,extens:3,extra:3,extra_repr:3,extract:6,extrem:3,eye_:1,f_0:3,f_1:3,f_2:3,f_n:3,fact:6,factor:[3,7],fals:[2,4,5,6],featur:[0,1,5,6],fewer:[3,6],file:7,find:[0,8],fire:3,first:[1,3,4,5,6],floor:5,follow:[0,7],foo:1,form:2,former:[1,2,3,6,7],formula:2,forrest:3,forward:[1,2,3,4,6,7],free:7,from:[0,3,4,5,6,7],fulli:0,gao:[3,6],get:[0,5,7,8],github:[0,7,8],globalavgpool:5,globalmaxpool:5,goal:8,got:[3,7],greater:3,group:[3,4,7],groupnorm:4,half:3,hard_sigmoid:2,hard_swish:2,hardsigmoid:[2,3],hardswish:[2,3],hardtanh:2,has:[0,1,3,4,6,7],have:1,height:7,help:[6,8],here:1,hidden:3,hidden_channel:3,higher:6,hook:[1,2,3,6,7],host:8,howard:2,http:[7,8],huang:[3,6],iandola:3,icnr:7,icnr_initi:7,ident:3,ignor:[1,2,3,6,7],imag:[0,3,7],imagenet:0,implement:[3,7],implicit:5,improv:[3,8],in_channel:[3,7],in_featur:1,includ:[3,4,5,6],increas:7,index:1,indic:5,infer:[0,1,4,6],inferr:0,inform:[3,8],init:[1,7],initi:[1,3,4,7],inplac:[2,6],input:[0,1,3,4,5,6,7],instanc:[1,2,3,4,6,7],instancenorm:4,instanti:1,instead:[0,1,2,3,5,6,7],intern:2,intervent:0,invert:3,invertedresidualbottleneck:3,just:[0,3,4],kaiming_normal_:7,keep:3,kept:6,kera:[0,1],kernel:[3,7],kernel_s:[0,3,5,7],keyword:1,knowledg:3,kwarg:1,lack:0,lambda:1,last:5,later:[5,6],latest:0,latter:[1,2,3,6,7],layer:[0,1,3,4,5,6,7],learn:3,learnabl:[3,4,7],least:3,leav:6,length:[1,3],less:6,level:[3,6],librari:[0,8],lighter:0,like:[0,3,4,7],line:3,linear:[0,1,3,4],list:[0,1],low:6,main:0,mainli:1,mani:3,map:3,mark:3,max:[2,5],maximum:5,maxpool1d:5,maxpool2d:5,maxpool3d:5,maxpool:5,maxunpool:5,mean:[4,5],measur:8,method:[0,1,3],might:6,min:2,mingx:3,mini:4,mix:3,mnasnet:3,mobil:3,mobilenetv2:3,mobilenetv3:[2,3],mode:[3,4],model:[1,3,8],modul:[1,8],momentum:4,more:[2,3,6,8],most:0,mostli:0,move:4,mpoli:3,multi:3,multipl:[1,3],multipli:[2,3],must:1,mymodul:1,name:0,natur:3,nearest:7,necessari:1,need:[1,2,3,6,7],neighbour:7,net:[0,3,6],network:[3,6,8],neural:[3,8],neuron:[6,8],next:3,nightli:0,nightly_10:0,nightly_18:0,nightly_:0,nois:6,non:3,none:[3,4,5,7],nonetheless:6,normal:[0,1,3,6],note:7,num_channel:4,num_featur:4,num_group:4,number:[3,4,7],numer:4,nvidia:0,odd:3,offici:0,one:[0,1,2,3,6,7],ones:3,onli:[3,7],oper:[2,3,5,6,8],opriont:5,oprtion:5,option:[1,2,3,4,5,6,7],order:3,orient:8,origin:[2,3,6,7],other:[1,3,6,8],otherwis:[3,4],out:6,out_channel:[3,7],out_featur:1,output:[3,4,5,6,7],output_pad:3,over:[4,5,8],overhead:[0,1],overridden:[1,2,3,6,7],own:[1,3],packag:0,pad:[0,3,5,7],padding_mod:[3,7],paper:3,parallel:3,param:1,paramet:[1,2,3,4,5,6,7],part:3,partial:3,pass:[1,2,3,4,5,6,7],per:3,percentag:3,perform:[1,2,3,5,6,7,8],pixel:[5,7],pixelshuffl:7,place:[2,6],platform:3,plot:8,point:7,poli:3,poly_modul:3,polyincept:3,polynet:[0,3],pool:[0,1,3],possibl:1,post_build:[1,7],prajit:2,prefix:0,present:7,preserv:[1,3],previou:[1,3,4],primit:1,print:3,prior:7,probabl:6,problem:6,process:8,produc:[3,7],project:3,pronounc:6,propos:[2,3,6],provid:[0,1,3],proxi:1,pull:[0,7],pursuit:3,pytorch:[0,1,3,4,7,8],qualifi:0,ramachandran:2,randn:1,randomli:6,rang:2,rate:0,ratio:3,read:8,readm:0,realli:1,recip:[1,2,3,6,7],record:8,recurr:0,reduc:3,regist:[1,2,3,6,7],regular:[0,1],reimplement:3,rel:6,releas:0,relu6:3,relu:3,remaind:3,rememb:1,remov:1,represent:3,requir:[0,1],rescal:3,research:[3,6],reshap:[0,1,3],reshuffl:3,residu:3,resiz:7,resnet:3,resolut:7,respect:[3,4,5,6],respons:3,result:3,return_indic:5,run:[1,2,3,4,6,7],running_mean:4,running_var:4,runtim:0,same:[0,1,3,5,7,8],sandler:3,save:1,scale:7,search:[2,3],see:[0,2,3,7],seed:8,seen:0,self:1,separ:[3,4],separableconv:3,sequenti:1,set:[3,4],shape:[0,1,3,4,5,6],should:[1,2,3,6,7,8],shuffl:3,shufflenet:3,side:[3,5,7],sigmoid:[2,3],significantli:0,silent:[1,2,3,6,7],similar:[2,7],similarli:[0,1],simpl:[1,3,4],sinc:[1,2,3,6,7],singl:[0,3],size:[3,5,7,8],skip:[3,6],smaller:3,some:[1,6],sota:0,sourc:[1,2,3,4,5,6,7],space:[3,7],spatial:7,specif:8,specifi:[1,3,6],split:3,squash:3,squeez:[0,3],squeeze_excit:3,squeeze_excitation_activ:3,squeeze_excitation_hidden:3,squeeze_excitation_sigmoid:3,squeezeexcit:[0,3],squeezenet:3,stabil:4,standard:[1,3,4,6],standardnormalnois:[0,6],start:[7,8],statist:4,stochast:6,stochasticdepth:[0,6],str:[3,7],strangelinear:1,stride:[3,5,7],string:[3,7],structur:3,sub:7,subclass:[1,2,3,6,7],suggest:[3,6],sum:3,support:[0,1,7],surviv:6,swish:2,system:8,szymonmaszk:[0,8],tag:0,tailor:8,take:[1,2,3,5,6,7],taken:3,tan:3,target:8,task:8,techniqu:6,tensor:[1,2,3,5,6,7],than:[1,3],them:[0,1,2,3,6,7],thi:[0,1,2,3,4,6,7,8],those:[0,3],though:6,three:1,through:3,thrown:3,time:[1,3],torch:[0,1,2,3,5,6,7],torchscript:0,total:3,track:4,track_running_stat:4,train:[4,6],transfer:3,transform:[0,3],transpos:3,tune:6,tupl:[1,3,7],two:[3,7],type:[1,2,3,5,7],ubuntu18:0,ubuntu:0,union:[3,7],untouch:6,upsampl:[0,1],upscale_factor:7,usag:3,use:[0,1,3,5],used:[1,2,3,4,5,6],useful:[1,6],user:[0,1,3],uses:4,using:[0,1,3,7],usual:[1,3,5],valu:[1,3,4,5,7],variabl:[1,3],varianc:4,variou:0,veri:3,version:[0,6,7],via:3,view:1,visual:8,waypoli:3,weight:[1,7],well:[0,1,4,8],when:[1,3,4,5],where:3,whether:[2,3],which:[1,3],whose:6,width:7,window:5,wise:[2,3],wish:0,within:[1,2,3,6,7],without:[0,3],work:[0,1,3,4],would:3,xiangyu:3,xingcheng:3,you:[0,1,3,8],your:[1,3,8],zero:[0,3,5,6,7],zhang:3},titles:["torchlayers","torchlayers package","torchlayers.activations module","torchlayers.convolution module","torchlayers.normalization module","torchlayers.pooling module","torchlayers.regularization module","torchlayers.upsample module","Related projects"],titleterms:{activ:2,convolut:3,cpu:0,docker:0,gpu:0,instal:0,modul:[0,2,3,4,5,6,7],normal:4,packag:1,pip:0,pool:5,project:8,regular:6,relat:8,submodul:1,torchfunc:8,torchlay:[0,1,2,3,4,5,6,7],upsampl:7}}) \ No newline at end of file +Search.setIndex({docnames:["index","packages/torchlayers","packages/torchlayers.activations","packages/torchlayers.convolution","packages/torchlayers.normalization","packages/torchlayers.pooling","packages/torchlayers.regularization","packages/torchlayers.upsample","related"],envversion:{"sphinx.domains.c":1,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":1,"sphinx.domains.index":1,"sphinx.domains.javascript":1,"sphinx.domains.math":2,"sphinx.domains.python":1,"sphinx.domains.rst":1,"sphinx.domains.std":1,"sphinx.ext.intersphinx":1,"sphinx.ext.todo":2,"sphinx.ext.viewcode":1,sphinx:56},filenames:["index.rst","packages/torchlayers.rst","packages/torchlayers.activations.rst","packages/torchlayers.convolution.rst","packages/torchlayers.normalization.rst","packages/torchlayers.pooling.rst","packages/torchlayers.regularization.rst","packages/torchlayers.upsample.rst","related.rst"],objects:{"":{torchlayers:[1,0,0,"-"]},"torchlayers.Concatenate":{forward:[1,2,1,""]},"torchlayers.Lambda":{forward:[1,2,1,""]},"torchlayers.Reshape":{forward:[1,2,1,""]},"torchlayers.activations":{HardSigmoid:[2,1,1,""],HardSwish:[2,1,1,""],Swish:[2,1,1,""],hard_sigmoid:[2,3,1,""],hard_swish:[2,3,1,""],swish:[2,3,1,""]},"torchlayers.activations.HardSigmoid":{forward:[2,2,1,""]},"torchlayers.activations.HardSwish":{forward:[2,2,1,""]},"torchlayers.activations.Swish":{forward:[2,2,1,""]},"torchlayers.convolution":{ChannelShuffle:[3,1,1,""],ChannelSplit:[3,1,1,""],Conv:[3,1,1,""],ConvTranspose:[3,1,1,""],Dense:[3,1,1,""],DepthwiseConv:[3,1,1,""],Fire:[3,1,1,""],InvertedResidualBottleneck:[3,1,1,""],MPoly:[3,1,1,""],Poly:[3,1,1,""],Residual:[3,1,1,""],SeparableConv:[3,1,1,""],SqueezeExcitation:[3,1,1,""],WayPoly:[3,1,1,""]},"torchlayers.convolution.ChannelShuffle":{forward:[3,2,1,""]},"torchlayers.convolution.ChannelSplit":{forward:[3,2,1,""]},"torchlayers.convolution.Dense":{forward:[3,2,1,""]},"torchlayers.convolution.Fire":{forward:[3,2,1,""]},"torchlayers.convolution.InvertedResidualBottleneck":{forward:[3,2,1,""]},"torchlayers.convolution.MPoly":{forward:[3,2,1,""]},"torchlayers.convolution.Poly":{extra_repr:[3,2,1,""],forward:[3,2,1,""]},"torchlayers.convolution.Residual":{forward:[3,2,1,""]},"torchlayers.convolution.SeparableConv":{forward:[3,2,1,""]},"torchlayers.convolution.SqueezeExcitation":{forward:[3,2,1,""]},"torchlayers.convolution.WayPoly":{forward:[3,2,1,""]},"torchlayers.normalization":{BatchNorm:[4,1,1,""],GroupNorm:[4,1,1,""],InstanceNorm:[4,1,1,""]},"torchlayers.pooling":{AvgPool:[5,1,1,""],GlobalAvgPool:[5,1,1,""],GlobalMaxPool:[5,1,1,""],MaxPool:[5,1,1,""]},"torchlayers.regularization":{Dropout:[6,1,1,""],StandardNormalNoise:[6,1,1,""],StochasticDepth:[6,1,1,""]},"torchlayers.regularization.StandardNormalNoise":{forward:[6,2,1,""]},"torchlayers.regularization.StochasticDepth":{forward:[6,2,1,""]},"torchlayers.upsample":{ConvPixelShuffle:[7,1,1,""]},"torchlayers.upsample.ConvPixelShuffle":{forward:[7,2,1,""],icnr_initialization:[7,2,1,""],post_build:[7,2,1,""]},torchlayers:{Concatenate:[1,1,1,""],Lambda:[1,1,1,""],Reshape:[1,1,1,""],activations:[2,0,0,"-"],build:[1,3,1,""],convolution:[3,0,0,"-"],infer:[1,3,1,""],normalization:[4,0,0,"-"],pooling:[5,0,0,"-"],regularization:[6,0,0,"-"],upsample:[7,0,0,"-"]}},objnames:{"0":["py","module","Python module"],"1":["py","class","Python class"],"2":["py","method","Python method"],"3":["py","function","Python function"]},objtypes:{"0":"py:module","1":"py:class","2":"py:method","3":"py:function"},terms:{"1x1":3,"50x":3,"5mb":3,"class":[0,1,2,3,4,5,6,7],"default":[0,1,2,3,4,5,6,7],"final":3,"float":[2,3,4,6],"function":[0,1,2,3,6,7],"import":[0,3],"int":[1,3,4,5,7],"return":[1,2,3,5,7],"super":1,"true":[1,3,4,5,6,7],"while":[1,2,3,6,7],And:0,For:[0,3,6],One:3,The:[3,5],Use:1,Useful:[0,5],Uses:2,__init__:[1,7],_strangelinearimpl:1,about:8,abov:[0,2],accept:[3,7],accord:7,accordingli:3,accuraci:3,across:5,act:[3,7],action:1,activ:[0,1,3,8],actual:6,add:[3,6,7],addabl:3,added:[3,4,5,7],adding:3,addit:[0,3],advis:3,affin:4,after:[1,3,5,7],afterward:[1,2,3,6,7],agnost:7,alexnet:3,all:[0,1,2,3,6,7],allow:[1,7],along:[1,3,5],also:1,although:[1,2,3,6,7],alwai:[1,4],analys:8,andrew:2,ani:[1,3,6],anyth:1,appli:[2,3,4],architectur:[0,3],arg:[1,3],argument:[1,3,4],around:8,artifact:7,asymmetr:3,attent:0,author:8,automat:0,avail:0,averag:[3,4,5],avg:5,avgpool1d:5,avgpool2d:5,avgpool3d:5,avgpool:5,awar:3,base:[0,3,4,6,8],batch:[1,3,4,5,6,7],batchnorm:[3,4],befor:3,beforehand:1,being:1,below:[0,1,8],beta:2,between:[3,7],bia:[1,3,7],block:[0,3,7],bool:[1,2,3,4,5,6,7],both:[0,3,4,5,7],bottleneck:[3,6],build:[0,1],built:7,call:[0,1,2,3,6,7],callabl:[1,3,7],can:[0,1,3,4,7,8],capabl:[0,1],care:[1,2,3,6,7],ceil:5,ceil_mod:5,channel:[3,4,5,7],channelshuffl:3,channelsplit:3,checkerboard:7,circular:[3,7],closer:6,cloud:0,com:7,compat:3,competit:0,compil:1,comput:[1,2,3,4,5,6,7],concaten:[1,3],connect:[3,7],consecut:3,consid:[3,8],control:5,conv1d:0,conv2d:3,conv3d:3,conv:[0,3],conveni:3,convolut:[0,1,5,7],convolv:[3,7],convpixelshuffl:7,convtranspos:3,copi:1,correct:3,correspond:0,could:[1,3],count_include_pad:5,counterpart:4,cpu:2,creat:[1,3,4,6],creation:8,cuda:[0,8],cudnn7:0,cumul:4,current:[0,3,7],custom:[0,1,3],customiz:3,dai:8,data:1,dead:6,decor:1,deep:[3,6],def:1,defin:[1,2,3,6,7],denomin:4,dens:3,depend:5,depth:6,depthwis:3,depthwiseconv:3,detail:2,devic:[3,8],devis:6,differ:3,differenti:6,dilat:[3,5,7],dim:[1,3],dimens:[1,3,4,5,6,7],dimension:[0,7],directli:8,distribut:6,divers:3,divis:3,document:[3,8],doe:[4,6],done:3,dropout:[0,6],due:[0,3,7],dure:[3,4,6],duti:8,each:[0,5],earli:6,easili:3,effect:2,effici:[0,2,3],either:[3,4,5,6],element:[2,3,5,6,7],emploi:6,empti:3,enabl:0,environ:8,eps:4,equal:[3,6],equat:3,error:3,especi:1,etc:[0,8],eval:4,even:3,everi:[1,2,3,6,7],exactli:[3,4],exampl:[0,1,7],except:[0,1,3],excit:[0,3],exclud:1,exp:2,expand:3,expect:4,explicitli:[1,3],express:3,extens:3,extra:3,extra_repr:3,extract:6,extrem:3,eye_:1,f_0:3,f_1:3,f_2:3,f_n:3,fact:6,factor:[3,7],fals:[2,4,5,6],featur:[0,1,5,6],fewer:[3,6],file:7,find:[0,8],fire:3,first:[1,3,4,5,6],floor:5,follow:[0,7],foo:1,form:2,former:[1,2,3,6,7],formula:2,forrest:3,forward:[1,2,3,4,6,7],free:7,from:[0,3,4,5,6,7],fulli:0,gao:[3,6],get:[0,5,7,8],github:[0,7,8],globalavgpool:5,globalmaxpool:5,goal:8,got:[3,7],greater:3,group:[3,4,7],groupnorm:4,half:3,hard_sigmoid:2,hard_swish:2,hardsigmoid:[2,3],hardswish:[2,3],hardtanh:2,has:[0,3,4,6,7],have:1,height:7,help:[6,8],here:1,hidden:3,hidden_channel:3,higher:6,hook:[1,2,3,6,7],host:8,howard:2,http:[7,8],huang:[3,6],iandola:3,icnr:7,icnr_initi:7,ident:3,ignor:[1,2,3,6,7],imag:[0,3,7],imagenet:0,implement:[3,7],implicit:5,improv:[3,8],in_channel:[3,7],in_featur:1,includ:[3,4,5,6],increas:7,index:1,indic:5,infer:[0,1,4,6],inferr:0,inform:[3,8],init:[1,7],initi:[1,3,4,7],inplac:[2,6],input:[0,1,3,4,5,6,7],instanc:[1,2,3,4,6,7],instancenorm:4,instead:[0,1,2,3,5,6,7],intern:2,intervent:0,invert:3,invertedresidualbottleneck:3,just:[0,3,4],kaiming_normal_:7,keep:3,kept:6,kera:[0,1],kernel:[3,7],kernel_s:[0,3,5,7],keyword:1,knowledg:3,kwarg:1,lack:0,lambda:1,last:5,later:[5,6],latest:0,latter:[1,2,3,6,7],layer:[0,1,3,4,5,6,7],learn:3,learnabl:[3,4,7],least:3,leav:6,length:[1,3],less:6,level:[3,6],librari:[0,8],lighter:0,like:[0,3,4,7],line:3,linear:[0,1,3,4],list:[0,1],low:6,main:0,mainli:1,mani:3,map:3,mark:3,max:[2,5],maximum:5,maxpool1d:5,maxpool2d:5,maxpool3d:5,maxpool:5,maxunpool:5,mean:[4,5],measur:8,method:[0,1,3],might:6,min:2,mingx:3,mini:4,mix:3,mnasnet:3,mobil:3,mobilenetv2:3,mobilenetv3:[2,3],mode:[3,4],model:[1,3,8],modul:[1,8],module_class:1,momentum:4,more:[2,3,6,8],most:0,mostli:0,move:4,mpoli:3,multi:3,multipl:[1,3],multipli:[2,3],must:1,mymodul:1,name:0,natur:3,nearest:7,necessari:1,need:[1,2,3,6,7],neighbour:7,net:[0,3,6],network:[3,6,8],neural:[3,8],neuron:[6,8],next:3,nightli:0,nightly_10:0,nightly_18:0,nightly_:0,nois:6,non:3,none:[3,4,5,7],nonetheless:6,normal:[0,1,3,6],note:7,now:1,num_channel:4,num_featur:4,num_group:4,number:[3,4,7],numer:4,nvidia:0,odd:3,offici:0,one:[0,1,2,3,6,7],ones:3,onli:[3,7],oper:[2,3,5,6,8],opriont:5,oprtion:5,option:[1,2,3,4,5,6,7],order:3,orient:8,origin:[2,3,6,7],other:[1,3,6,8],otherwis:[3,4],out:6,out_channel:[3,7],out_featur:1,output:[3,4,5,6,7],output_pad:3,over:[4,5,8],overhead:[0,1],overridden:[1,2,3,6,7],own:[1,3],packag:0,pad:[0,3,5,7],padding_mod:[3,7],paper:3,parallel:3,param:1,paramet:[1,2,3,4,5,6,7],part:3,partial:3,pass:[1,2,3,4,5,6,7],per:3,percentag:3,perform:[1,2,3,5,6,7,8],pixel:[5,7],pixelshuffl:7,place:[2,6],platform:3,plot:8,point:7,poli:3,poly_modul:3,polyincept:3,polynet:[0,3],pool:[0,1,3],possibl:1,post_build:[1,7],prajit:2,prefix:0,present:7,preserv:[1,3],previou:[1,3,4],primit:1,print:3,prior:7,probabl:6,problem:6,process:8,produc:[3,7],project:3,pronounc:6,propos:[2,3,6],provid:[0,1,3],proxi:1,pull:[0,7],pursuit:3,pytorch:[0,1,3,4,7,8],qualifi:0,ramachandran:2,randn:1,randomli:6,rang:2,rate:0,ratio:3,read:8,readm:0,realli:1,recip:[1,2,3,6,7],record:8,recurr:0,reduc:3,regist:[1,2,3,6,7],regular:[0,1],reimplement:3,rel:6,releas:0,relu6:3,relu:3,remaind:3,remov:1,represent:3,requir:[0,1],rescal:3,research:[3,6],reshap:[0,1,3],reshuffl:3,residu:3,resiz:7,resnet:3,resolut:7,respect:[3,4,5,6],respons:3,result:3,return_indic:5,run:[1,2,3,4,6,7],running_mean:4,running_var:4,runtim:0,same:[0,1,3,5,7,8],sandler:3,save:1,scale:7,search:[2,3],see:[0,2,3,7],seed:8,seen:0,self:1,separ:[3,4],separableconv:3,sequenti:1,set:[3,4],shape:[0,1,3,4,5,6],should:[1,2,3,6,7,8],shuffl:3,shufflenet:3,side:[3,5,7],sigmoid:[2,3],significantli:0,silent:[1,2,3,6,7],similar:[2,7],similarli:[0,1],simpl:[1,3,4],sinc:[1,2,3,6,7],singl:[0,3],size:[3,5,7,8],skip:[3,6],smaller:3,some:[1,6],sota:0,sourc:[1,2,3,4,5,6,7],space:[3,7],spatial:7,specif:8,specifi:[1,3,6],split:3,squash:3,squeez:[0,3],squeeze_excit:3,squeeze_excitation_activ:3,squeeze_excitation_hidden:3,squeeze_excitation_sigmoid:3,squeezeexcit:[0,3],squeezenet:3,stabil:4,standard:[1,3,4,6],standardnormalnois:[0,6],start:[7,8],statist:4,stochast:6,stochasticdepth:[0,6],str:[1,3,7],strangelinear:1,stride:[3,5,7],string:[3,7],structur:3,sub:7,subclass:[1,2,3,6,7],suggest:[3,6],sum:3,support:[0,1,7],surviv:6,swish:2,system:8,szymonmaszk:[0,8],tag:0,tailor:8,take:[1,2,3,5,6,7],taken:3,tan:3,target:8,task:8,techniqu:6,tensor:[1,2,3,5,6,7],than:[1,3],them:[0,1,2,3,6,7],thi:[0,1,2,3,4,6,7,8],those:[0,3],though:6,three:1,through:3,thrown:3,time:[1,3],torch:[0,1,2,3,5,6,7],torchscript:0,total:3,track:4,track_running_stat:4,train:[4,6],transfer:3,transform:[0,3],transpos:3,tune:6,tupl:[1,3,7],two:[3,7],type:[1,2,3,5,7],ubuntu18:0,ubuntu:0,union:[3,7],untouch:6,updat:1,upsampl:[0,1],upscale_factor:7,usag:3,use:[0,1,3,5],used:[1,2,3,4,5,6],useful:[1,6],user:[0,1,3],uses:4,using:[0,1,3,7],usual:[1,3,5],valu:[1,3,4,5,7],variabl:[1,3],varianc:4,variou:0,veri:3,version:[0,6,7],via:3,view:1,visual:8,waypoli:3,weight:[1,7],well:[0,1,4,8],when:[1,3,4,5],where:3,whether:[2,3],which:[1,3],whose:6,width:7,window:5,wise:[2,3],wish:0,within:[1,2,3,6,7],without:[0,3],work:[0,1,3,4],would:3,xiangyu:3,xingcheng:3,you:[0,1,3,8],your:[1,3,8],zero:[0,3,5,6,7],zhang:3},titles:["torchlayers","torchlayers package","torchlayers.activations module","torchlayers.convolution module","torchlayers.normalization module","torchlayers.pooling module","torchlayers.regularization module","torchlayers.upsample module","Related projects"],titleterms:{activ:2,convolut:3,cpu:0,docker:0,gpu:0,instal:0,modul:[0,2,3,4,5,6,7],normal:4,packag:1,pip:0,pool:5,project:8,regular:6,relat:8,submodul:1,torchfunc:8,torchlay:[0,1,2,3,4,5,6,7],upsampl:7}}) \ No newline at end of file diff --git a/tests/general_test.py b/tests/general_test.py index fc21c6d..4297035 100644 --- a/tests/general_test.py +++ b/tests/general_test.py @@ -11,13 +11,15 @@ def forward(self, tensor): return tensor, tensor, tensor -@torchlayers.Infer() -class CustomLinear(torch.nn.Linear): +class _CustomLinearImpl(torch.nn.Linear): def __init__(self, in_features, out_features, bias: bool = True): super().__init__(in_features, out_features, bias) self.some_params = torch.nn.Parameter(torch.randn(2, out_features)) +CustomLinear = torchlayers.infer(_CustomLinearImpl) + + @pytest.fixture def model(): return torchlayers.Sequential( @@ -57,3 +59,9 @@ def test_custom_inferable_parameters(): layer = CustomLinear(32) layer(torch.rand(16, 64)) assert layer.some_params.shape == (2, 32) + + +def test_custom_inferable_build(): + layer = CustomLinear(32) + layer = torchlayers.build(layer, torch.rand(16, 64)) + assert layer.some_params.shape == (2, 32) diff --git a/torchlayers/__init__.py b/torchlayers/__init__.py index 6026d6e..0794ca9 100644 --- a/torchlayers/__init__.py +++ b/torchlayers/__init__.py @@ -9,7 +9,7 @@ pooling, regularization, upsample) from ._version import __version__ -__all__ = ["build", "Infer", "Lambda", "Reshape", "Concatenate"] +__all__ = ["build", "infer", "Lambda", "Reshape", "Concatenate"] def build(module, *args, **kwargs): @@ -71,16 +71,14 @@ def run_post(module): return module -class Infer: +def infer(module_class, index: str = 1): """Allows custom user modules to infer input shape. Input shape should be the first argument after `self`. Usually used as class decorator, e.g.:: - # Remember it's a class, it has to be instantiated - @torchlayers.Infer() - class StrangeLinear(torch.nn.Linear): + class _StrangeLinearImpl(torch.nn.Linear): def __init__(self, in_features, out_features, bias: bool = True): super().__init__(in_features, out_features, bias) self.params = torch.nn.Parameter(torch.randn(out_features)) @@ -88,77 +86,75 @@ def __init__(self, in_features, out_features, bias: bool = True): def forward(self, inputs): super().forward(inputs) + self.params + # Now you can use shape inference of in_features + StrangeLinear = torchlayers.infer(_StrangeLinearImpl) + # in_features can be inferred layer = StrangeLinear(out_features=64) Parameters ---------- + module_class: torch.nn.Module + Class of module to be updated with shape inference capabilities. + index: int, optional Index into `tensor.shape` input which should be inferred, e.g. tensor.shape[1]. Default: `1` (`0` being batch dimension) """ - def __init__(self, index: int = 1): - self.index: int = index - - def __call__(self, module_class): - init_arguments = [ - str(argument) - for argument in inspect.signature(module_class.__init__).parameters.values() - ] - - # Other argument than self - if len(init_arguments) > 1: - name = module_class.__name__ - infered_module = type( - name, (torch.nn.Module,), {_dev_utils.infer.MODULE_CLASS: module_class}, - ) - parsed_arguments, uninferable_arguments = _dev_utils.infer.parse_arguments( - init_arguments, infered_module - ) - - setattr( - infered_module, - "__init__", - _dev_utils.infer.create_init(parsed_arguments), - ) - - setattr( - infered_module, - "forward", - _dev_utils.infer.create_forward( - _dev_utils.infer.MODULE, - _dev_utils.infer.MODULE_CLASS, - parsed_arguments, - self.index, - ), - ) - setattr( - infered_module, - "__repr__", - _dev_utils.infer.create_repr( - _dev_utils.infer.MODULE, **uninferable_arguments - ), - ) - setattr( - infered_module, - "__getattr__", - _dev_utils.infer.create_getattr(_dev_utils.infer.MODULE), - ) - - setattr( - infered_module, - "__reduce__", - _dev_utils.infer.create_reduce( - _dev_utils.infer.MODULE, parsed_arguments - ), - ) - - return infered_module - - return module_class + init_arguments = [ + str(argument) + for argument in inspect.signature(module_class.__init__).parameters.values() + ] + + # Other argument than self + if len(init_arguments) > 1: + name = module_class.__name__ + infered_module = type( + name, (torch.nn.Module,), {_dev_utils.infer.MODULE_CLASS: module_class}, + ) + parsed_arguments, uninferable_arguments = _dev_utils.infer.parse_arguments( + init_arguments, infered_module + ) + + setattr( + infered_module, "__init__", _dev_utils.infer.create_init(parsed_arguments), + ) + + setattr( + infered_module, + "forward", + _dev_utils.infer.create_forward( + _dev_utils.infer.MODULE, + _dev_utils.infer.MODULE_CLASS, + parsed_arguments, + index, + ), + ) + setattr( + infered_module, + "__repr__", + _dev_utils.infer.create_repr( + _dev_utils.infer.MODULE, **uninferable_arguments + ), + ) + setattr( + infered_module, + "__getattr__", + _dev_utils.infer.create_getattr(_dev_utils.infer.MODULE), + ) + + setattr( + infered_module, + "__reduce__", + _dev_utils.infer.create_reduce(_dev_utils.infer.MODULE, parsed_arguments), + ) + + return infered_module + + return module_class class Lambda(torch.nn.Module): @@ -300,7 +296,7 @@ def _getattr(name): module_class = _getattr(name) if name in _inferable.torch.all() + _inferable.custom.all(): - return Infer(_dev_utils.helpers.get_per_module_index(module_class))( - module_class + return infer( + module_class, _dev_utils.helpers.get_per_module_index(module_class) ) return module_class