Skip to content

Commit

Permalink
MobileNetBlock eip
Browse files Browse the repository at this point in the history
  • Loading branch information
ChanLumerico committed Aug 24, 2024
1 parent 8617bca commit a2c4aa9
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
3 changes: 2 additions & 1 deletion luma/__import__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
IncepResBlock,
ResNetBlock,
XceptionBlock,
MobileNetBlock,
)
from luma.neural.model import (
SimpleMLP,
Expand Down Expand Up @@ -349,7 +350,7 @@
ConvBlock1D, ConvBlock2D, ConvBlock3D,
SeparableConv1D, SeparableConv2D, SeparableConv3D,
DenseBlock, IncepBlock, IncepResBlock, ResNetBlock,
XceptionBlock
XceptionBlock, MobileNetBlock

LayerNode, LayerGraph

Expand Down
13 changes: 8 additions & 5 deletions luma/neural/block/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from luma.interface.typing import ClassType
from luma.interface.util import InitUtil

from luma.neural.layer import *
from luma.neural.block import (
incep_v1,
incep_v2,
Expand All @@ -42,13 +41,14 @@
"IncepResBlock",
"ResNetBlock",
"XceptionBlock",
"MobileNetBlock",
)


@dataclass
class ConvBlockArgs:
filter_size: Tuple[int, ...] | int
activation: Activation.FuncType
activation: callable
optimizer: Optimizer | None = None
initializer: InitUtil.InitStr = None
padding: Tuple[int, ...] | int | Literal["same", "valid"] = "same"
Expand Down Expand Up @@ -359,7 +359,7 @@ class SeparableConv3D(standard._SeparableConv3D):

@dataclass
class DenseBlockArgs:
activation: Activation.FuncType
activation: callable
optimizer: Optimizer | None = None
initializer: InitUtil.InitStr = None
lambda_: float = 0.0
Expand Down Expand Up @@ -410,7 +410,7 @@ class DenseBlock(standard._DenseBlock):

@dataclass
class BaseBlockArgs:
activation: Activation.FuncType
activation: callable
optimizer: Optimizer | None = None
initializer: InitUtil.InitStr = None
lambda_: float = 0.0
Expand Down Expand Up @@ -868,4 +868,7 @@ class Exit(xception._Exit):
"""


class InvertedResBlock(mobile._InvertedRes): ...
@ClassType.non_instantiable()
class MobileNetBlock:

class InvertedRes(mobile._InvertedRes): ...
2 changes: 1 addition & 1 deletion luma/neural/block/mobile.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
if self.expand != 1:
self[self.rt_].clear()
self[self.rt_].append(self.pw_)
self.graph[self.pw] = [self.dw_pw_lin]
self.graph[self.pw_] = [self.dw_pw_lin]

if self.do_shortcut:
self[self.rt_].append(self.tm_)
Expand Down

0 comments on commit a2c4aa9

Please sign in to comment.