Skip to content

Commit

Permalink
_Mobile_V3_Large
Browse files Browse the repository at this point in the history
  • Loading branch information
ChanLumerico committed Sep 7, 2024
1 parent 1822b60 commit 32254d0
Showing 1 changed file with 58 additions and 1 deletion.
59 changes: 58 additions & 1 deletion luma/neural/model/mobile.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def __init__(
)
self.check_param_ranges()
self.build_model()

def build_model(self) -> None:
inverted_res_config: List[list] = [
[3, 16, 16, False, "RE", 1],
Expand All @@ -512,4 +512,61 @@ def build_model(self) -> None:
[5, 960, 160, True, "HS", 1],
[5, 960, 160, True, "HS", 1],
]
base_args = {
"initializer": self.initializer,
"lambda_": self.lambda_,
"random_state": self.random_state,
}

self.model.extend(
Conv2D(3, 16, 3, 2, **base_args),
BatchNorm2D(16, self.momentum),
Activation.HardSwish(),
)
in_ = 16
for i, (f, exp, out, b, a, s) in enumerate(inverted_res_config):
block = InvertedRes_SE if b else InvertedRes
act = Activation.HardSwish if a == "HS" else Activation.ReLU
self.model += (
f"InvRes_{i + 1}",
block(in_, out, f, s, exp, activation=act, **base_args),
)
in_ = out

self.model.extend(
Conv2D(160, 960, 1, 1, **base_args),
BatchNorm2D(960, self.momentum),
Activation.HardSwish(),
)
self.model.extend(
GlobalAvgPool2D(),
Conv2D(960, 1280, 1, 1, **base_args),
Activation.HardSwish(),
)
self.model.extend(
Flatten(),
Dropout(self.dropout_rate),
Dense(1280, self.out_features, **base_args),
)

input_shape: ClassVar[int] = (-1, 3, 224, 224)

@Tensor.force_shape(input_shape)
def fit(self, X: Tensor, y: Matrix) -> Self:
return super(_Mobile_V3_Large, self).fit_nn(X, y)

@override
@Tensor.force_shape(input_shape)
def predict(self, X: Tensor, argmax: bool = True) -> Matrix | Vector:
return super(_Mobile_V3_Large, self).predict_nn(X, argmax)

@override
@Tensor.force_shape(input_shape)
def score(
self,
X: Tensor,
y: Matrix,
metric: Evaluator = Accuracy,
argmax: bool = True,
) -> float:
return super(_Mobile_V3_Large, self).score_nn(X, y, metric, argmax)

0 comments on commit 32254d0

Please sign in to comment.