-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
71f4b25
commit 8617bca
Showing
4 changed files
with
155 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from typing import Tuple, override | ||
|
||
from luma.core.super import Optimizer | ||
from luma.interface.typing import Tensor, TensorLike | ||
from luma.interface.util import InitUtil | ||
|
||
from luma.neural.layer import * | ||
from luma.neural.autoprop import LayerNode, LayerGraph | ||
|
||
|
||
class _InvertedRes(LayerGraph):
    """MobileNetV2-style inverted residual block.

    Builds a small layer graph: an optional 1x1 pointwise expansion
    (only when ``expand != 1``), followed by a 3x3 depthwise conv and a
    1x1 linear pointwise projection (``dw_pw_lin``). When ``stride == 1``
    and ``in_channels == out_channels``, the input is added to the output
    through a sum-merge shortcut at the terminal node.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    out_channels : int
        Number of channels produced by the final 1x1 projection.
    stride : int, default 1
        Stride of the depthwise conv; must be 1 or 2.
    expand : int, default 1
        Channel expansion factor for the hidden pointwise conv.
    activation : Activation.FuncType, default Activation.ReLU6
        Activation applied after the expansion and depthwise stages
        (the final projection is linear, per the inverted-residual design).
    optimizer : Optimizer | None, default None
        If given, attached to all layers after the graph is built.
    initializer : InitUtil.InitStr, default None
        Weight initializer spec forwarded to every conv layer.
    lambda_ : float, default 0.0
        L2 regularization strength forwarded to every conv layer.
    do_batch_norm : bool, default True
        Whether to insert BatchNorm2D after the expansion and depthwise convs.
    momentum : float, default 0.9
        BatchNorm running-statistics momentum.
    random_state : int | None, default None
        Seed forwarded to every conv layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        expand: int = 1,
        activation: Activation.FuncType = Activation.ReLU6,
        optimizer: Optimizer | None = None,
        initializer: InitUtil.InitStr = None,
        lambda_: float = 0.0,
        do_batch_norm: bool = True,
        momentum: float = 0.9,
        random_state: int | None = None,
    ) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.expand = expand
        self.activation = activation
        self.optimizer = optimizer
        self.initializer = initializer
        self.lambda_ = lambda_
        self.do_batch_norm = do_batch_norm
        self.momentum = momentum

        # Common keyword args shared by every conv layer in this block.
        self.basic_args = {
            "initializer": initializer,
            "lambda_": lambda_,
            "random_state": random_state,
        }

        assert self.stride in [1, 2]
        # Residual shortcut is only valid when spatial size and channel
        # count are preserved end-to-end.
        self.do_shortcut = stride == 1 and in_channels == out_channels
        self.hid_channels = int(round(in_channels * expand))

        self.init_nodes()
        super(_InvertedRes, self).__init__(
            graph={
                self.rt_: [self.dw_pw_lin],
                self.dw_pw_lin: [self.tm_],
            },
            root=self.rt_,
            term=self.tm_,
        )

        if self.expand != 1:
            # Reroute root through the pointwise expansion node:
            # rt_ -> pw_ -> dw_pw_lin (instead of rt_ -> dw_pw_lin).
            self[self.rt_].clear()
            self[self.rt_].append(self.pw_)
            # BUGFIX: was `self.graph[self.pw]` — the node attribute is
            # `pw_` (trailing underscore); `self.pw` raised AttributeError
            # whenever expand != 1.
            self.graph[self.pw_] = [self.dw_pw_lin]

        if self.do_shortcut:
            # Identity edge from root directly to the sum-merge terminal.
            self[self.rt_].append(self.tm_)

        self.build()
        if optimizer is not None:
            self.set_optimizer(optimizer)

    def init_nodes(self) -> None:
        """Create the graph nodes: root, expansion, main path, terminal."""
        self.rt_ = LayerNode(Identity(), name="rt_")
        # 1x1 pointwise expansion (used only when expand != 1).
        self.pw_ = LayerNode(
            Sequential(
                Conv2D(
                    self.in_channels,
                    self.hid_channels,
                    1,
                    padding="valid",
                    **self.basic_args,
                ),
                (
                    BatchNorm2D(self.hid_channels, self.momentum)
                    if self.do_batch_norm
                    else None
                ),
                self.activation(),
            ),
            name="pw_",
        )
        # 3x3 depthwise conv + linear 1x1 projection (no activation after
        # the projection — the "linear bottleneck").
        self.dw_pw_lin = LayerNode(
            Sequential(
                DepthConv2D(
                    self.hid_channels,
                    3,
                    self.stride,
                    padding="valid" if self.stride == 2 else "same",
                    **self.basic_args,
                ),
                (
                    BatchNorm2D(self.hid_channels, self.momentum)
                    if self.do_batch_norm
                    else None
                ),
                self.activation(),
                Conv2D(
                    self.hid_channels,
                    self.out_channels,
                    1,
                    padding="valid",
                    **self.basic_args,
                ),
            ),
            name="dw_pw_lin",
        )
        # Terminal merges incoming branches by summation (residual add).
        self.tm_ = LayerNode(Identity(), merge_mode="sum", name="tm_")

    @Tensor.force_dim(4)
    def forward(self, X: TensorLike, is_train: bool = False) -> TensorLike:
        """Run the block forward on a 4D (N, C, H, W) tensor."""
        return super().forward(X, is_train)

    @Tensor.force_dim(4)
    def backward(self, d_out: TensorLike) -> TensorLike:
        """Backpropagate a 4D gradient through the block."""
        return super().backward(d_out)

    @override
    def out_shape(self, in_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return the (N, C_out, H // stride, W // stride) output shape."""
        batch_size, _, height, width = in_shape
        return (
            batch_size,
            self.out_channels,
            height // self.stride,
            width // self.stride,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters