diff --git a/luma/__import__.py b/luma/__import__.py index 5870c82..edb45b7 100644 --- a/luma/__import__.py +++ b/luma/__import__.py @@ -83,6 +83,9 @@ GlobalAvgPooling1D, GlobalAvgPooling2D, GlobalAvgPooling3D, + AdaptiveAvgPooling1D, + AdaptiveAvgPooling2D, + AdaptiveAvgPooling3D, LpPooling1D, LpPooling2D, LpPooling3D, @@ -304,6 +307,7 @@ Convolution1D, Convolution2D, Convolution3D, Pooling1D, Pooling2D, Pooling3D, GlobalAvgPooling1D, GlobalAvgPooling2D, GlobalAvgPooling3D, + AdaptiveAvgPooling1D, AdaptiveAvgPooling2D, AdaptiveAvgPooling3D, LpPooling1D, LpPooling2D, LpPooling3D Dropout, Dropout1D, Dropout2D, Dropout3D, BatchNorm1D, BatchNorm2D, BatchNorm3D, diff --git a/luma/neural/_layers/pool.py b/luma/neural/_layers/pool.py index 06e0d5f..6980d7a 100644 --- a/luma/neural/_layers/pool.py +++ b/luma/neural/_layers/pool.py @@ -13,6 +13,9 @@ "_GlobalAvgPool1D", "_GlobalAvgPool2D", "_GlobalAvgPool3D", + "_AdaptiveAvgPool1D", + "_AdaptiveAvgPool2D", + "_AdaptiveAvgPool3D", "_LpPool1D", "_LpPool2D", "_LpPool3D", @@ -502,6 +505,176 @@ def out_shape(self, in_shape: Tuple[int]) -> Tuple[int]: return (batch_size, channels, 1, 1, 1) +class _AdaptiveAvgPool1D(Layer): + def __init__(self, out_size: int | Tuple[int]) -> None: + super().__init__() + self.out_size = out_size + + @Tensor.force_dim(3) + def forward(self, X: Tensor, is_train: bool = False) -> Tensor: + _ = is_train + self.input_ = X + batch_size, channels, width = X.shape + target_width = self.out_size + + out = np.zeros((batch_size, channels, target_width)) + + for i in range(target_width): + start = int(np.floor(i * width / target_width)) + end = int(np.ceil((i + 1) * width / target_width)) + + out[:, :, i] = np.mean(X[:, :, start:end], axis=2) + + return out + + @Tensor.force_dim(3) + def backward(self, d_out: Tensor) -> Tensor: + X = self.input_ + _, _, width = X.shape + target_width = self.out_size + + dX = np.zeros_like(X) + for i in range(target_width): + start = int(np.floor(i * width / target_width)) + end = int(np.ceil((i + 1) * width / target_width)) + + dX[:, :, start:end] += d_out[:, :, i][:, :, None] / (end - start) + + self.dX = dX + return self.dX + + def out_shape(self, in_shape: Tuple[int]) -> Tuple[int]: + batch_size, channels, _ = in_shape + return (batch_size, channels, self.out_size) + + +class _AdaptiveAvgPool2D(Layer): + def __init__(self, out_size: Tuple[int, int]) -> None: + super().__init__() + self.out_size = out_size + + @Tensor.force_dim(4) + def forward(self, X: Tensor, is_train: bool = False) -> Tensor: + _ = is_train + self.input_ = X + batch_size, channels, height, width = X.shape + target_height, target_width = self.out_size + + out = np.zeros((batch_size, channels, target_height, target_width)) + + for i in range(target_height): + for j in range(target_width): + h_start = int(np.floor(i * height / target_height)) + h_end = int(np.ceil((i + 1) * height / target_height)) + w_start = int(np.floor(j * width / target_width)) + w_end = int(np.ceil((j + 1) * width / target_width)) + + out[:, :, i, j] = np.mean( + X[:, :, h_start:h_end, w_start:w_end], axis=(2, 3) + ) + + return out + + @Tensor.force_dim(4) + def backward(self, d_out: Tensor) -> Tensor: + X = self.input_ + _, _, height, width = X.shape + target_height, target_width = self.out_size + + dX = np.zeros_like(X) + for i in range(target_height): + for j in range(target_width): + h_start = int(np.floor(i * height / target_height)) + h_end = int(np.ceil((i + 1) * height / target_height)) + w_start = int(np.floor(j * width / target_width)) + w_end = 
int(np.ceil((j + 1) * width / target_width)) + + dX[:, :, h_start:h_end, w_start:w_end] += d_out[:, :, i, j][ + :, :, None, None + ] / ((h_end - h_start) * (w_end - w_start)) + + self.dX = dX + return self.dX + + def out_shape(self, in_shape: Tuple[int]) -> Tuple[int]: + batch_size, channels, _, _ = in_shape + return (batch_size, channels, *self.out_size) + + +class _AdaptiveAvgPool3D(Layer): + def __init__(self, out_size: Tuple[int, int, int]) -> None: + super().__init__() + self.out_size = out_size + + @Tensor.force_dim(5) + def forward(self, X: Tensor, is_train: bool = False) -> Tensor: + _ = is_train + self.input_ = X + batch_size, channels, depth, height, width = X.shape + target_depth, target_height, target_width = self.out_size + + out = np.zeros( + ( + batch_size, + channels, + target_depth, + target_height, + target_width, + ) + ) + for d in range(target_depth): + d_start = int(np.floor(d * depth / target_depth)) + d_end = int(np.ceil((d + 1) * depth / target_depth)) + + for i in range(target_height): + h_start = int(np.floor(i * height / target_height)) + h_end = int(np.ceil((i + 1) * height / target_height)) + + for j in range(target_width): + w_start = int(np.floor(j * width / target_width)) + w_end = int(np.ceil((j + 1) * width / target_width)) + + out[:, :, d, i, j] = np.mean( + X[:, :, d_start:d_end, h_start:h_end, w_start:w_end], + axis=(2, 3, 4), + ) + + return out + + @Tensor.force_dim(5) + def backward(self, d_out: Tensor) -> Tensor: + X = self.input_ + _, _, depth, height, width = X.shape + target_depth, target_height, target_width = self.out_size + + dX = np.zeros_like(X) + + for d in range(target_depth): + d_start = int(np.floor(d * depth / target_depth)) + d_end = int(np.ceil((d + 1) * depth / target_depth)) + + for i in range(target_height): + h_start = int(np.floor(i * height / target_height)) + h_end = int(np.ceil((i + 1) * height / target_height)) + + for j in range(target_width): + w_start = int(np.floor(j * width / target_width)) + w_end = int(np.ceil((j + 1) * width / target_width)) + + dX[:, :, d_start:d_end, h_start:h_end, w_start:w_end] += d_out[ + :, :, d, i, j + ][:, :, None, None, None] / ( + (d_end - d_start) * (h_end - h_start) * (w_end - w_start) + ) + + self.dX = dX + return self.dX + + def out_shape(self, in_shape: Tuple[int]) -> Tuple[int]: + batch_size, channels, _, _, _ = in_shape + return (batch_size, channels, *self.out_size) + + class _LpPool1D(Layer): def __init__( self, diff --git a/luma/neural/_models/__init__.py b/luma/neural/_models/__init__.py index 40063fa..2dff88f 100644 --- a/luma/neural/_models/__init__.py +++ b/luma/neural/_models/__init__.py @@ -29,3 +29,4 @@ import luma.neural._models.alex as alex import luma.neural._models.vgg as vgg import luma.neural._models.incep as incep +import luma.neural._models.resnet as resnet diff --git a/luma/neural/_models/alex.py b/luma/neural/_models/alex.py index bcddc43..55c62e8 100644 --- a/luma/neural/_models/alex.py +++ b/luma/neural/_models/alex.py @@ -204,7 +204,7 @@ def build_model(self) -> None: lambda_=self.lambda_, random_state=self.random_state, ) - + input_shape: tuple = (-1, 3, 227, 227) @Tensor.force_shape(input_shape) @@ -412,7 +412,7 @@ def build_model(self) -> None: lambda_=self.lambda_, random_state=self.random_state, ) - + input_shape: tuple = (-1, 3, 227, 227) @Tensor.force_shape(input_shape) diff --git a/luma/neural/_models/incep.py b/luma/neural/_models/incep.py index b6807f6..034f711 100644 --- a/luma/neural/_models/incep.py +++ b/luma/neural/_models/incep.py @@ -9,7 +9,7 @@ 
from luma.neural.base import NeuralModel from luma.neural.block import ( - IncepBlockArgs, + BaseBlockArgs, IncepBlock, IncepResBlock, ) @@ -105,7 +105,7 @@ def build_model(self) -> None: "lambda_": self.lambda_, "random_state": self.random_state, } - incep_args = IncepBlockArgs( + incep_args = BaseBlockArgs( activation=self.activation, do_batch_norm=False, **base_args, @@ -179,7 +179,7 @@ def build_model(self) -> None: self.model += Flatten() self.model += Dense(1024, self.out_features, **base_args) - + input_shape: tuple = (-1, 3, 224, 224) @Tensor.force_shape(input_shape) @@ -273,7 +273,7 @@ def build_model(self) -> None: "lambda_": self.lambda_, "random_state": self.random_state, } - incep_args = IncepBlockArgs( + incep_args = BaseBlockArgs( activation=self.activation, do_batch_norm=False, **base_args, @@ -385,7 +385,7 @@ def build_model(self) -> None: Dropout(self.dropout_rate, self.random_state), Dense(2048, self.out_features, **base_args), ) - + input_shape: tuple = (-1, 3, 299, 299) @Tensor.force_shape(input_shape) @@ -482,7 +482,7 @@ def build_model(self) -> None: "lambda_": self.lambda_, "random_state": self.random_state, } - incep_args = IncepBlockArgs( + incep_args = BaseBlockArgs( activation=self.activation, do_batch_norm=True, **base_args, @@ -600,7 +600,7 @@ def build_model(self) -> None: Dropout(self.dropout_rate, self.random_state), Dense(2048, self.out_features, **base_args), ) - + input_shape: tuple = (-1, 3, 299, 299) @Tensor.force_shape(input_shape) @@ -688,7 +688,7 @@ def __init__( self.build_model() def build_model(self) -> None: - incep_args = IncepBlockArgs( + incep_args = BaseBlockArgs( activation=self.activation, initializer=self.initializer, lambda_=self.lambda_, @@ -726,7 +726,7 @@ def build_model(self) -> None: Dropout(self.dropout_rate, self.random_state), Dense(1536, self.out_features), ) - + input_shape: tuple = (-1, 3, 299, 299) @Tensor.force_shape(input_shape) @@ -814,7 +814,7 @@ def __init__( self.build_model() def build_model(self) -> None: - incep_args = IncepBlockArgs( + incep_args = BaseBlockArgs( activation=self.activation, initializer=self.initializer, lambda_=self.lambda_, @@ -851,7 +851,7 @@ def build_model(self) -> None: Dropout(self.dropout_rate, self.random_state), Dense(1792, self.out_features), ) - + input_shape: tuple = (-1, 3, 299, 299) @Tensor.force_shape(input_shape) @@ -939,7 +939,7 @@ def __init__( self.build_model() def build_model(self) -> None: - incep_args = IncepBlockArgs( + incep_args = BaseBlockArgs( activation=self.activation, initializer=self.initializer, lambda_=self.lambda_, @@ -979,7 +979,7 @@ def build_model(self) -> None: Dropout(self.dropout_rate, self.random_state), Dense(2272, self.out_features), ) - + input_shape: tuple = (-1, 3, 299, 299) @Tensor.force_shape(input_shape) diff --git a/luma/neural/_models/lenet.py b/luma/neural/_models/lenet.py index ba37f07..e5cbf74 100644 --- a/luma/neural/_models/lenet.py +++ b/luma/neural/_models/lenet.py @@ -112,7 +112,7 @@ def build_model(self) -> None: lambda_=self.lambda_, random_state=self.random_state, ) - + input_shape: tuple = (-1, 1, 28, 28) @Tensor.force_shape(input_shape) @@ -246,7 +246,7 @@ def build_model(self) -> None: lambda_=self.lambda_, random_state=self.random_state, ) - + input_shape: tuple = (-1, 1, 32, 32) @Tensor.force_shape(input_shape) diff --git a/luma/neural/_models/resnet.py b/luma/neural/_models/resnet.py new file mode 100644 index 0000000..fcf32a8 --- /dev/null +++ b/luma/neural/_models/resnet.py @@ -0,0 +1,199 @@ +from typing import Any, Self, 
override, List, Optional +from dataclasses import asdict, dataclass + +from luma.core.super import Estimator, Evaluator, Optimizer, Supervised +from luma.interface.typing import Matrix, Tensor, TensorLike, Vector +from luma.interface.util import InitUtil +from luma.metric.classification import Accuracy + +from luma.neural.base import NeuralModel +from luma.neural.block import ResNetBlock, BaseBlockArgs +from luma.neural.layer import ( + Convolution2D, + Pooling2D, + AdaptiveAvgPooling2D, + BatchNorm2D, + Activation, + Dense, + Flatten, + Sequential, +) + +BasicBlock = ResNetBlock.Basic +Bottleneck = ResNetBlock.Bottleneck + + +def _make_layer( + in_channels: int, + out_channels: int, + block: ResNetBlock, + n_blocks: int, + layer_num: int, + conv_base_args: dict, + res_base_args: dataclass, + stride: int = 1, +) -> tuple[Sequential, int]: + downsampling: Optional[Sequential] = None + if stride != 1 or in_channels != out_channels * block.expansion: + downsampling = Sequential( + Convolution2D( + in_channels, + out_channels * block.expansion, + 1, + stride, + **conv_base_args, + ), + BatchNorm2D(out_channels * block.expansion), + ) + + first_block = block( + in_channels, + out_channels, + stride, + downsampling, + **asdict(res_base_args), + ) + layers: list = [(f"ResNetConv{layer_num}_1", first_block)] + + in_channels = out_channels * block.expansion + for i in range(1, n_blocks): + new_block = ( + f"ResNetConv{layer_num}_{i + 1}", + block(in_channels, out_channels, **asdict(res_base_args)), + ) + layers.append(new_block) + + return Sequential(*layers), in_channels + + +class _ResNet_18(Estimator, Supervised, NeuralModel): + def __init__( + self, + activation: Activation.FuncType = Activation.ReLU, + initializer: InitUtil.InitStr = None, + out_features: int = 1000, + batch_size: int = 128, + n_epochs: int = 100, + valid_size: float = 0.1, + lambda_: float = 0.0, + momentum: float = 0.9, + early_stopping: bool = False, + patience: int = 10, + shuffle: bool = True, + random_state: int | None = None, + deep_verbose: bool = False, + ) -> None: + self.activation = activation + self.initializer = initializer + self.out_features = out_features + self.lambda_ = lambda_ + self.momentum = momentum + self.shuffle = shuffle + self.random_state = random_state + self._fitted = False + + super().__init__( + batch_size, + n_epochs, + valid_size, + early_stopping, + patience, + shuffle, + random_state, + deep_verbose, + ) + super().init_model() + self.model = Sequential() + + self.feature_sizes_ = [ + [3, 64], + [64, 64, 64, 64], + [128, 128, 128, 128], + [256, 256, 256, 256], + [512, 512, 512, 512], + ] + self.feature_shapes_ = [ + self._get_feature_shapes(sizes) for sizes in self.feature_sizes_ + ] + + self.set_param_ranges( + { + "out_features": ("0<,+inf", int), + "batch_size": ("0<,+inf", int), + "n_epochs": ("0<,+inf", int), + "valid_size": ("0<,<1", None), + "momentum": ("0,1", None), + "dropout_rate": ("0,1", None), + "lambda_": ("0,+inf", None), + "patience": ("0<,+inf", int), + } + ) + self.check_param_ranges() + self.build_model() + + def build_model(self) -> None: + base_args = { + "initializer": self.initializer, + "lambda_": self.lambda_, + "random_state": self.random_state, + } + res_args = BaseBlockArgs( + activation=self.activation, + do_batch_norm=True, + momentum=self.momentum, + **base_args, + ) + + self.model.extend( + Convolution2D(3, 64, 7, 2, 3, **base_args), + BatchNorm2D(64, self.momentum), + self.activation(), + Pooling2D(3, 2, "max", "same"), + ) + self.layer_2, in_channels = 
_make_layer( + 64, 64, BasicBlock, 2, 2, base_args, res_args + ) + self.layer_3, in_channels = _make_layer( + in_channels, 128, BasicBlock, 2, 3, base_args, res_args, stride=2 + ) + self.layer_4, in_channels = _make_layer( + in_channels, 256, BasicBlock, 2, 4, base_args, res_args, stride=2 + ) + self.layer_5, in_channels = _make_layer( + in_channels, 512, BasicBlock, 2, 5, base_args, res_args, stride=2 + ) + + self.model.extend( + self.layer_2, + self.layer_3, + self.layer_4, + self.layer_5, + deep_add=True, + ) + self.model.extend( + AdaptiveAvgPooling2D((1, 1)), + Flatten(), + Dense(512 * BasicBlock.expansion, self.out_features, **base_args), + ) + + input_shape: tuple = (-1, 3, 224, 224) + + @Tensor.force_shape(input_shape) + def fit(self, X: Tensor, y: Matrix) -> Self: + return super(_ResNet_18, self).fit_nn(X, y) + + @override + @Tensor.force_shape(input_shape) + def predict(self, X: Tensor, argmax: bool = True) -> Matrix | Vector: + return super(_ResNet_18, self).predict_nn(X, argmax) + + @override + @Tensor.force_shape(input_shape) + def score( + self, + X: Tensor, + y: Matrix, + metric: Evaluator = Accuracy, + argmax: bool = True, + ) -> float: + return super(_ResNet_18, self).score_nn(X, y, metric, argmax) diff --git a/luma/neural/_models/vgg.py b/luma/neural/_models/vgg.py index bf60d2d..9642e1b 100644 --- a/luma/neural/_models/vgg.py +++ b/luma/neural/_models/vgg.py @@ -167,7 +167,7 @@ def build_model(self) -> None: lambda_=self.lambda_, random_state=self.random_state, ) - + input_shape: tuple = (-1, 3, 224, 224) @Tensor.force_shape(input_shape) @@ -346,7 +346,7 @@ def build_model(self) -> None: lambda_=self.lambda_, random_state=self.random_state, ) - + input_shape: tuple = (-1, 3, 224, 224) @Tensor.force_shape(input_shape) @@ -740,7 +740,7 @@ def build_model(self) -> None: lambda_=self.lambda_, random_state=self.random_state, ) - + input_shape: tuple = (-1, 3, 224, 224) @Tensor.force_shape(input_shape) diff --git a/luma/neural/_specials/resnet.py b/luma/neural/_specials/resnet.py index 67b4a52..273fd31 100644 --- a/luma/neural/_specials/resnet.py +++ b/luma/neural/_specials/resnet.py @@ -1,4 +1,4 @@ -from typing import Tuple, override +from typing import Tuple, override, ClassVar from luma.core.super import Optimizer from luma.interface.typing import Tensor, TensorLike @@ -56,6 +56,8 @@ def __init__( if optimizer is not None: self.set_optimizer(optimizer) + expansion: ClassVar[int] = 1 + def init_nodes(self) -> None: self.rt_ = LayerNode(Identity(), name="rt_") self.conv_ = LayerNode( @@ -71,12 +73,14 @@ def init_nodes(self) -> None: self.activation(), Convolution2D( self.out_channels, - self.out_channels, + self.out_channels * _Basic.expansion, 3, - self.stride, **self.basic_args ), - BatchNorm2D(self.out_channels, self.momentum), + BatchNorm2D( + self.out_channels * _Basic.expansion, + self.momentum, + ), ), name="conv_", ) @@ -100,11 +104,7 @@ def backward(self, d_out: TensorLike) -> TensorLike: @override def out_shape(self, in_shape: Tuple[int]) -> Tuple[int]: - batch_size, _, height, width = in_shape - if self.downsampling: - _, _, height, width = self.downsampling.out_shape(in_shape) - - return batch_size, self.out_channels, height, width + return self.conv_.out_shape(in_shape) class _Bottleneck(LayerGraph): @@ -113,7 +113,6 @@ def __init__( in_channels: int, out_channels: int, stride: int = 1, - expansion: int = 4, downsampling: LayerLike | None = None, activation: Activation.FuncType = Activation.ReLU, optimizer: Optimizer | None = None, @@ -126,7 +125,6 @@ def 
__init__( self.in_channels = in_channels self.out_channels = out_channels self.stride = stride - self.expansion = expansion self.downsampling = downsampling self.activation = activation self.optimizer = optimizer @@ -152,16 +150,18 @@ def __init__( term=self.sum_, ) + self.build() + if optimizer is not None: + self.set_optimizer(optimizer) + + expansion: ClassVar[int] = 4 + def init_nodes(self) -> None: self.rt_ = LayerNode(Identity(), name="rt_") self.conv_ = LayerNode( Sequential( Convolution2D( - self.in_channels, - self.out_channels, - 1, - self.stride, - **self.basic_args + self.in_channels, self.out_channels, 1, **self.basic_args ), BatchNorm2D(self.out_channels, self.momentum), self.activation(), @@ -176,13 +176,12 @@ def init_nodes(self) -> None: self.activation(), Convolution2D( self.out_channels, - self.out_channels * self.expansion, + self.out_channels * _Bottleneck.expansion, 1, - self.stride, **self.basic_args ), BatchNorm2D( - self.out_channels * self.expansion, + self.out_channels * _Bottleneck.expansion, self.momentum, ), ), @@ -208,13 +207,4 @@ def backward(self, d_out: TensorLike) -> TensorLike: @override def out_shape(self, in_shape: Tuple[int]) -> Tuple[int]: - batch_size, _, height, width = in_shape - if self.downsampling: - _, _, height, width = self.downsampling.out_shape(in_shape) - - return ( - batch_size, - self.out_channels * self.expansion, - height, - width, - ) + return self.conv_.out_shape(in_shape) diff --git a/luma/neural/block.py b/luma/neural/block.py index 608d2ab..032ad16 100644 --- a/luma/neural/block.py +++ b/luma/neural/block.py @@ -549,7 +549,7 @@ def backward(self, d_out: TensorLike) -> TensorLike: @dataclass -class IncepBlockArgs: +class BaseBlockArgs: activation: Activation.FuncType optimizer: Optimizer | None = None initializer: InitUtil.InitStr = None @@ -895,29 +895,29 @@ class ResNetBlock: References ---------- `ResNet-(18, 34, 50, 101, 152)` : - [1] He, Kaiming, et al. “Deep Residual Learning for Image - Recognition.” Proceedings of the IEEE Conference on Computer + [1] He, Kaiming, et al. “Deep Residual Learning for Image + Recognition.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 770-778. """ - class Basic(_specials.resnet._Basic): + class Basic(_specials.resnet._Basic): """ Basic convolution block used in `ResNet-18` and `ResNet-34`. - + Parameters ---------- `downsampling` : LayerLike, optional An additional layer to the input signal which reduces its grid size to perform a downsampling - + See [1] also for additional information. """ - class Bottleneck(_specials.resnet._Bottleneck): + class Bottleneck(_specials.resnet._Bottleneck): """ Bottleneck block used in `ResNet-(50, 101, 152)`. - + Parameters ---------- `downsampling` : LayerLike, optional @@ -925,6 +925,6 @@ class Bottleneck(_specials.resnet._Bottleneck): its grid size to perform a downsampling `expansion` : int, default=4 Expanding factor for the number of output channels. - + See [1] also for additional information. 
""" diff --git a/luma/neural/layer.py b/luma/neural/layer.py index e084f66..402b02c 100644 --- a/luma/neural/layer.py +++ b/luma/neural/layer.py @@ -18,6 +18,9 @@ "GlobalAvgPooling1D", "GlobalAvgPooling2D", "GlobalAvgPooling3D", + "AdaptiveAvgPooling1D", + "AdaptiveAvgPooling2D", + "AdaptiveAvgPooling3D", "LpPooling1D", "LpPooling2D", "LpPooling3D", @@ -434,6 +437,63 @@ def __init__(self) -> None: super().__init__() +class AdaptiveAvgPooling1D(_layers.pool._AdaptiveAvgPool1D): + """ + Adaptive average pooling layer for 1-dimensional data. + + Adaptive Average Pooling adjusts input dimensions to produce a fixed-size + output by averaging over dynamically sized regions. It's useful for consistent + output sizes in neural networks, regardless of input shape. + + Parameters + ---------- + `out_size` : int or tuple of int + An output shape to be fixed + + """ + + def __init__(self, out_size: int | Tuple[int]) -> None: + super().__init__(out_size) + + +class AdaptiveAvgPooling2D(_layers.pool._AdaptiveAvgPool2D): + """ + Adaptive average pooling layer for 2-dimensional data. + + Adaptive Average Pooling adjusts input dimensions to produce a fixed-size + output by averaging over dynamically sized regions. It's useful for consistent + output sizes in neural networks, regardless of input shape. + + Parameters + ---------- + `out_size` : int or tuple of int + An output shape to be fixed + + """ + + def __init__(self, out_size: Tuple[int]) -> None: + super().__init__(out_size) + + +class AdaptiveAvgPooling3D(_layers.pool._AdaptiveAvgPool3D): + """ + Adaptive average pooling layer for 3-dimensional data. + + Adaptive Average Pooling adjusts input dimensions to produce a fixed-size + output by averaging over dynamically sized regions. It's useful for consistent + output sizes in neural networks, regardless of input shape. + + Parameters + ---------- + `out_size` : int or tuple of int + An output shape to be fixed + + """ + + def __init__(self, out_size: Tuple[int]) -> None: + super().__init__(out_size) + + class LpPooling1D(_layers.pool._LpPool1D): """ Lp pooling layer for 1-dimensional data. diff --git a/luma/neural/model.py b/luma/neural/model.py index 47f9e03..f9b3bb0 100644 --- a/luma/neural/model.py +++ b/luma/neural/model.py @@ -24,6 +24,7 @@ "Inception_V4", "InceptionResNet_V1", "InceptionResNet_V2", + "ResNet_18", ) @@ -1856,3 +1857,102 @@ def __init__( random_state, deep_verbose, ) + + +class ResNet_18(_models.resnet._ResNet_18): + """ + ResNet18 is a 18-layer deep neural network that uses residual blocks + to improve training by learning residuals, helping prevent vanishing + gradients and enabling better performance in image recognition tasks. 
+
+    Structure
+    ---------
+    Input:
+    ```py
+    Tensor[..., 3, 224, 224]
+    ```
+    Residual Blocks:
+    ```py
+    Convolution2D(3, 64, filter_size=7, stride=2) -> Pooling2D(3, 2, "max") ->  # conv1
+
+    2x ResNetBlock.Basic(64, 64) ->  # conv2
+    ResNetBlock.Basic(64, 128, stride=2), ResNetBlock.Basic(128, 128) ->  # conv3
+    ResNetBlock.Basic(128, 256, stride=2), ResNetBlock.Basic(256, 256) ->  # conv4
+    ResNetBlock.Basic(256, 512, stride=2), ResNetBlock.Basic(512, 512) ->  # conv5
+
+    AdaptiveAvgPooling2D((1, 1)) ->  # avg pool
+    ```
+    Fully Connected Layers:
+    ```py
+    Flatten -> Dense(512, 1000)
+    ```
+    Output:
+    ```py
+    Matrix[..., 1000]
+    ```
+    Parameter Size:
+    ```txt
+    11,688,512 weights, 5,800 biases -> 11,694,312 params
+    ```
+    Parameters
+    ----------
+    `activation` : FuncType, default=Activation.ReLU
+        Type of activation function
+    `initializer` : InitStr, default=None
+        Type of weight initializer
+    `out_features` : int, default=1000
+        Number of output features
+    `batch_size` : int, default=128
+        Size of a single mini-batch
+    `n_epochs` : int, default=100
+        Number of epochs for training
+    `valid_size` : float, default=0.1
+        Fractional size of validation set
+    `lambda_` : float, default=0.0
+        L2 regularization strength
+    `early_stopping` : bool, default=False
+        Whether to early-stop the training when the valid score stagnates
+    `patience` : int, default=10
+        Number of epochs to wait until early-stopping
+    `shuffle` : bool, default=True
+        Whether to shuffle the data at the beginning of every epoch
+
+    References
+    ----------
+    1. He, Kaiming, et al. “Deep Residual Learning for Image Recognition.”
+    Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition
+    (CVPR), 2016, pp. 770-778.
+
+    """
+
+    def __init__(
+        self,
+        activation: Activation.FuncType = Activation.ReLU,
+        initializer: InitUtil.InitStr = None,
+        out_features: int = 1000,
+        batch_size: int = 128,
+        n_epochs: int = 100,
+        valid_size: float = 0.1,
+        lambda_: float = 0,
+        momentum: float = 0.9,
+        early_stopping: bool = False,
+        patience: int = 10,
+        shuffle: bool = True,
+        random_state: int | None = None,
+        deep_verbose: bool = False,
+    ) -> None:
+        super().__init__(
+            activation,
+            initializer,
+            out_features,
+            batch_size,
+            n_epochs,
+            valid_size,
+            lambda_,
+            momentum,
+            early_stopping,
+            patience,
+            shuffle,
+            random_state,
+            deep_verbose,
+        )
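
For reviewers who want to sanity-check the new pooling arithmetic outside the library, the sketch below re-implements the floor/ceil bin mapping used by `_AdaptiveAvgPool2D.forward` in plain NumPy and prints the resulting shapes. It is an illustrative, untested companion to the patch: the helper name `adaptive_avg_pool2d` is not part of luma, and the commented lines at the end show hypothetical usage of the new public `AdaptiveAvgPooling2D` wrapper and `ResNet_18` model, assuming the import paths added by this patch.

```py
import numpy as np

def adaptive_avg_pool2d(X: np.ndarray, out_size: tuple) -> np.ndarray:
    # Output cell (i, j) averages the input window
    #   rows [floor(i*H/out_h), ceil((i+1)*H/out_h)),
    #   cols [floor(j*W/out_w), ceil((j+1)*W/out_w)),
    # which is the same mapping used in _AdaptiveAvgPool2D.forward.
    N, C, H, W = X.shape
    out_h, out_w = out_size
    out = np.zeros((N, C, out_h, out_w))
    for i in range(out_h):
        h0, h1 = int(np.floor(i * H / out_h)), int(np.ceil((i + 1) * H / out_h))
        for j in range(out_w):
            w0, w1 = int(np.floor(j * W / out_w)), int(np.ceil((j + 1) * W / out_w))
            out[:, :, i, j] = X[:, :, h0:h1, w0:w1].mean(axis=(2, 3))
    return out

X = np.random.rand(2, 512, 7, 7)
print(adaptive_avg_pool2d(X, (1, 1)).shape)  # (2, 512, 1, 1) -- behaves like global average pooling
print(adaptive_avg_pool2d(X, (3, 3)).shape)  # (2, 512, 3, 3) -- bins overlap when 7 is not divisible by 3

# Hypothetical usage of the public API added by this patch (untested sketch):
# from luma.neural.layer import AdaptiveAvgPooling2D
# from luma.neural.model import ResNet_18
#
# pool = AdaptiveAvgPooling2D((1, 1))
# pool.out_shape((-1, 512, 7, 7))   # -> (-1, 512, 1, 1)
# model = ResNet_18(batch_size=128, n_epochs=10)
# model.fit(X_train, y_train)       # X_train: Tensor[..., 3, 224, 224]
```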
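A second, smaller sketch shows how `_make_layer` and the new `expansion` class attribute determine the per-stage channel sizes listed in the `ResNet_18` docstring. The classes and the `stage_channels` helper below are illustrative stand-ins, not part of luma; they only mirror the channel bookkeeping of the real blocks.

```py
# Illustrative stand-ins: only the `expansion` class attribute matters here.
class BasicBlock:
    expansion = 1   # second 3x3 conv keeps out_channels * 1

class Bottleneck:
    expansion = 4   # final 1x1 conv widens to out_channels * 4

def stage_channels(block, in_channels, out_channels, n_blocks):
    # Mirrors _make_layer: the first block maps in_channels -> out_channels * expansion
    # (and carries any stride / downsampling); the remaining blocks keep that width.
    pairs = [(in_channels, out_channels * block.expansion)]
    in_channels = out_channels * block.expansion
    pairs += [(in_channels, in_channels)] * (n_blocks - 1)
    return pairs

channels = 64  # output of the stem convolution in _ResNet_18.build_model
for stage, width in zip(range(2, 6), (64, 128, 256, 512)):
    pairs = stage_channels(BasicBlock, channels, width, 2)
    channels = pairs[-1][1]
    print(f"conv{stage}: {pairs}")
# conv2: [(64, 64), (64, 64)]
# conv3: [(64, 128), (128, 128)]   <- first block carries the stride-2 downsampling
# conv4: [(128, 256), (256, 256)]
# conv5: [(256, 512), (512, 512)]
```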