Skip to content

Commit

Permalink
_models/resnet.py args bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ChanLumerico committed Aug 11, 2024
1 parent 5d3dbb7 commit b5692bf
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions luma/neural/_models/resnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, Self, override, List, Optional
from dataclasses import asdict
from dataclasses import asdict, dataclass

from luma.core.super import Estimator, Evaluator, Optimizer, Supervised
from luma.interface.typing import Matrix, Tensor, TensorLike, Vector
Expand Down Expand Up @@ -30,7 +30,7 @@ def _make_layer(
n_blocks: int,
layer_num: int,
conv_base_args: dict,
res_base_args: dict,
res_base_args: dataclass,
stride: int = 1,
) -> tuple[Sequential, int]:
downsampling: Optional[Sequential] = None
Expand All @@ -51,15 +51,15 @@ def _make_layer(
out_channels,
stride,
downsampling,
**res_base_args,
**asdict(res_base_args),
)
layers: list = [(f"ResNetConv{layer_num}_1", first_block)]

in_channels = out_channels * block.expansion
for i in range(1, n_blocks):
new_block = (
f"ResNetConv{layer_num}_{i + 1}",
block(in_channels, out_channels, **res_base_args),
block(in_channels, out_channels, **asdict(res_base_args)),
)
layers.append(new_block)

Expand Down Expand Up @@ -151,7 +151,7 @@ def build_model(self) -> None:
Pooling2D(3, 2, "max", "same"),
)
self.layer_2, in_channels = _make_layer(
64, 64, BasicBlock, 2, 2, base_args, asdict(res_args)
64, 64, BasicBlock, 2, 2, base_args, res_args
)
self.layer_3, in_channels = _make_layer(
in_channels, 128, BasicBlock, 2, 3, base_args, res_args, stride=2
Expand Down

0 comments on commit b5692bf

Please sign in to comment.