Inception v1 wip
ChanLumerico committed May 24, 2024
1 parent a62e4cf commit 966247a
Showing 6 changed files with 301 additions and 9 deletions.
3 changes: 3 additions & 0 deletions luma/neural/_layers/_pool.py
@@ -410,3 +410,6 @@ def out_shape(self, in_shape: Tuple[int]) -> Tuple[int]:
out_width = 1 + (padded_w - self.filter_size) // self.stride

return (batch_size, channels, out_depth, out_height, out_width)


# TODO: Implement _GlobalAvgPool
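As a rough illustration of what this TODO might entail, here is a minimal, hypothetical sketch of a global average pooling layer in the same channels-first convention as the existing pooling code. The class name comes from the TODO itself; the method names and attributes are assumptions modelled on the `out_shape` signature above, not the library's actual API.

```py
from typing import Tuple

import numpy as np


class _GlobalAvgPool:  # hypothetical sketch of the layer named in the TODO
    """Averages each channel over its full spatial extent:
    (batch_size, channels, H, W) -> (batch_size, channels, 1, 1)."""

    def forward(self, X: np.ndarray) -> np.ndarray:
        self.in_shape_ = X.shape
        return X.mean(axis=(2, 3), keepdims=True)

    def backward(self, d_out: np.ndarray) -> np.ndarray:
        _, _, h, w = self.in_shape_
        # Every spatial position receives an equal 1 / (H * W) share of the gradient.
        return np.ones(self.in_shape_) * d_out / (h * w)

    def out_shape(self, in_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        batch_size, channels, *_ = in_shape
        return (batch_size, channels, 1, 1)
```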
205 changes: 205 additions & 0 deletions luma/neural/_models/_inception.py
@@ -0,0 +1,205 @@
from typing import Self, override
from dataclasses import asdict

from luma.core.super import Estimator, Evaluator, Optimizer, Supervised
from luma.interface.typing import Matrix, Tensor, Vector
from luma.interface.util import InitUtil
from luma.metric.classification import Accuracy

from luma.neural.base import Loss, NeuralModel
from luma.neural.loss import CrossEntropy
from luma.neural.block import (
InceptionBlock,
InceptionBlockArgs,
)
from luma.neural.layer import (
Convolution2D,
Pooling2D,
Activation,
Dropout,
Dense,
Flatten,
Sequential,
)


__all__ = ("_Inception_V1",)


class _Inception_V1(Estimator, Supervised, NeuralModel):
def __init__(
self,
optimizer: Optimizer,
activation: Activation.FuncType = Activation.ReLU,
loss: Loss = CrossEntropy(),
initializer: InitUtil.InitStr = None,
out_features: int = 1000,
batch_size: int = 128,
n_epochs: int = 100,
learning_rate: float = 0.01,
valid_size: float = 0.1,
lambda_: float = 0.0,
dropout_rate: float = 0.4,
early_stopping: bool = False,
patience: int = 10,
shuffle: bool = True,
random_state: int | None = None,
deep_verbose: bool = False,
) -> None:
self.activation = activation
self.optimizer = optimizer
self.loss = loss
self.initializer = initializer
self.out_features = out_features
self.lambda_ = lambda_
self.dropout_rate = dropout_rate
self.shuffle = shuffle
self.random_state = random_state
self._fitted = False

super().__init__(
batch_size,
n_epochs,
learning_rate,
valid_size,
early_stopping,
patience,
deep_verbose,
)
super().__init_model__()
self.model = Sequential()
self.optimizer.set_params(learning_rate=self.learning_rate)
self.model.set_optimizer(optimizer=self.optimizer)

self.feature_sizes_ = [
[3, 64, 64, 192],
[192, 256, 480, 512, 512, 512, 528, 832, 832],
# [1024, self.out_features],
]
self.feature_shapes_ = [
self._get_feature_shapes(self.feature_sizes_[0]),
self._get_feature_shapes(self.feature_sizes_[1]),
]

self.set_param_ranges(
{
"out_features": ("0<,+inf", int),
"batch_size": ("0<,+inf", int),
"n_epochs": ("0<,+inf", int),
"learning_rate": ("0<,+inf", None),
"valid_size": ("0<,<1", None),
"dropout_rate": ("0,1", None),
"lambda_": ("0,+inf", None),
"patience": (f"0<,+inf", int),
}
)
self.check_param_ranges()
self._build_model()

def _build_model(self) -> None:
base_args = {
"initializer": self.initializer,
"optimizer": self.optimizer,
"lambda_": self.lambda_,
"random_state": self.random_state,
}
incep_args = InceptionBlockArgs(
activation=self.activation,
do_batch_norm=False,
**base_args,
)

self.model.extend(
Convolution2D(3, 64, filter_size=7, stride=2, padding=3, **base_args),
self.activation(),
Pooling2D(3, stride=2, mode="max", padding="same"),
)

self.model.extend(
Convolution2D(64, 64, filter_size=1, padding="valid", **base_args),
self.activation(),
Convolution2D(64, 192, filter_size=3, padding="valid", **base_args),
self.activation(),
Pooling2D(3, stride=2, mode="max", padding="same"),
)

self.model.extend(
(
"Inception_3a",
InceptionBlock(192, 64, 96, 128, 16, 32, 32, **asdict(incep_args)),
),
(
"Inception_3b",
InceptionBlock(256, 128, 128, 192, 32, 96, 64, **asdict(incep_args)),
),
Pooling2D(3, stride=2, mode="max", padding="same"),
deep_add=False,
)

self.model.extend(
(
"Inception_4a",
InceptionBlock(480, 192, 96, 208, 16, 48, 64, **asdict(incep_args)),
),
(
"Inception_4b",
InceptionBlock(512, 160, 112, 224, 24, 64, 64, **asdict(incep_args)),
),
(
"Inception_4c",
InceptionBlock(512, 128, 128, 256, 24, 64, 64, **asdict(incep_args)),
),
(
"Inception_4d",
InceptionBlock(512, 112, 144, 288, 32, 64, 64, **asdict(incep_args)),
),
(
"Inception_4e",
InceptionBlock(528, 256, 160, 320, 32, 128, 128, **asdict(incep_args)),
),
Pooling2D(3, stride=2, mode="max", padding="same"),
deep_add=False,
)

self.model.extend(
(
"Inception_5a",
InceptionBlock(832, 256, 160, 320, 32, 128, 128, **asdict(incep_args)),
),
(
"Inception_5b",
InceptionBlock(832, 384, 192, 384, 48, 128, 128, **asdict(incep_args)),
),
# Pooling2D(7, stride=1, mode="avg"), TODO: Implement Global Avg. Pooling
Dropout(self.dropout_rate, random_state=self.random_state),
deep_add=False,
)

self.model += Flatten()
self.model += Dense(1024, self.out_features, **base_args)

@Tensor.force_dim(4)
def fit(self, X: Tensor, y: Matrix) -> Self:
return super(_Inception_V1, self).fit_nn(X, y)

@override
@Tensor.force_dim(4)
def predict(self, X: Tensor, argmax: bool = True) -> Matrix | Vector:
return super(_Inception_V1, self).predict_nn(X, argmax)

@override
@Tensor.force_dim(4)
def score(
self,
X: Tensor,
y: Matrix,
metric: Evaluator = Accuracy,
argmax: bool = True,
) -> float:
return super(_Inception_V1, self).score_nn(X, y, metric, argmax)


from luma.neural.optimizer import AdamOptimizer

if __name__ == "__main__":
    # WIP smoke test: build the model and print a summary for a 224x224 RGB input.
    model = _Inception_V1(optimizer=AdamOptimizer())
    model.summarize(in_shape=(-1, 3, 224, 224))
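For context, a hypothetical end-to-end call sequence for the new model might look like the following. The random data and the one-hot target shape are assumptions made to match the `Tensor`/`Matrix` annotations on `fit`, not part of this commit.

```py
import numpy as np
from luma.neural.optimizer import AdamOptimizer

# Dummy data: 8 RGB images of 224x224 and 1000-way one-hot targets (assumed format).
X = np.random.rand(8, 3, 224, 224)
y = np.eye(1000)[np.random.randint(0, 1000, size=8)]

model = _Inception_V1(optimizer=AdamOptimizer(), n_epochs=1, batch_size=4)
model.fit(X, y)           # delegates to NeuralModel.fit_nn
preds = model.predict(X)  # argmax=True -> predicted class indices
acc = model.score(X, y)   # defaults to the Accuracy metric
```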
8 changes: 4 additions & 4 deletions luma/neural/_models/_vgg.py
@@ -42,7 +42,7 @@ def __init__(
early_stopping: bool = False,
patience: int = 10,
shuffle: bool = True,
random_state: int = None,
random_state: int | None = None,
deep_verbose: bool = False,
) -> None:
self.activation = activation
@@ -214,7 +214,7 @@ def __init__(
early_stopping: bool = False,
patience: int = 10,
shuffle: bool = True,
random_state: int = None,
random_state: int | None = None,
deep_verbose: bool = False,
) -> None:
self.activation = activation
@@ -398,7 +398,7 @@ def __init__(
early_stopping: bool = False,
patience: int = 10,
shuffle: bool = True,
random_state: int = None,
random_state: int | None = None,
deep_verbose: bool = False,
) -> None:
self.activation = activation
@@ -594,7 +594,7 @@ def __init__(
early_stopping: bool = False,
patience: int = 10,
shuffle: bool = True,
random_state: int = None,
random_state: int | None = None,
deep_verbose: bool = False,
) -> None:
self.activation = activation
78 changes: 77 additions & 1 deletion luma/neural/block.py
@@ -557,9 +557,85 @@ def backward(self, d_out: TensorLike) -> TensorLike:
return d_out


@dataclass
class InceptionBlockArgs:
activation: Activation.FuncType
optimizer: Optimizer | None = None
initializer: InitUtil.InitStr = None
lambda_: float = 0.0
do_batch_norm: bool = True
momentum: float = 0.9
random_state: int | None = None


class InceptionBlock(Sequential):
"""
TODO: Finish docstring
Inception block for neural networks.
An inception block allows for multiple convolutional operations to be
performed in parallel. This structure is inspired by the Inception modules
of Google's Inception network, and it concatenates the outputs of different
convolutions to capture rich and varied features from input data.
Structure
---------
1x1 Branch:
```py
Convolution2D(filter_size=1) -> Optional[BatchNorm2D] -> Activation
```
3x3 Branch:
```py
Convolution2D(filter_size=1) -> Optional[BatchNorm2D] -> Activation ->
Convolution2D(filter_size=3) -> Optional[BatchNorm2D] -> Activation
```
5x5 Branch:
```py
Convolution2D(filter_size=1) -> Optional[BatchNorm2D] -> Activation ->
Convolution2D(filter_size=5) -> Optional[BatchNorm2D] -> Activation
```
Pooling Branch:
```py
Pooling2D(3, 1, mode="max", padding="same") ->
Convolution2D(filter_size=1) -> Optional[BatchNorm2D] -> Activation
```
Parameters
----------
`in_channels` : int
Number of input channels.
`out_1x1` : int
Number of output channels for the 1x1 convolution.
`red_3x3` : int
Number of output channels for the dimension reduction before
the 3x3 convolution.
`out_3x3` : int
Number of output channels for the 3x3 convolution.
`red_5x5` : int
Number of output channels for the dimension reduction before
the 5x5 convolution.
`out_5x5` : int
Number of output channels for the 5x5 convolution.
`out_pool` : int
Number of output channels for the 1x1 convolution after max pooling.
`activation` : FuncType
Type of activation function
`optimizer` : Optimizer, optional, default=None
Type of optimizer for weight update
`initializer` : InitStr, default=None
Type of weight initializer
`lambda_` : float, default=0.0
L2 regularization strength
`do_batch_norm` : bool, default=True
Whether to perform batch normalization
`momentum` : float, default=0.9
Momentum for batch normalization
Notes
-----
- The input `X` must have the form of a 4D-array (`Tensor`).
```py
X.shape = (batch_size, channels, height, width)
```
"""

def __init__(
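To make the concatenation arithmetic in the docstring concrete: the block's output channel count is the sum of the four branch widths, `out_1x1 + out_3x3 + out_5x5 + out_pool`. A quick check against the values used in `_Inception_V1` above (a sketch, not code from this commit):

```py
# "Inception_3a" is built as InceptionBlock(192, 64, 96, 128, 16, 32, 32, ...):
#   in_channels=192, out_1x1=64, red_3x3=96, out_3x3=128,
#   red_5x5=16, out_5x5=32, out_pool=32
out_channels_3a = 64 + 128 + 32 + 32  # = 256
# ...which matches the in_channels=256 passed to "Inception_3b".
```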
14 changes: 11 additions & 3 deletions luma/neural/layer.py
@@ -845,13 +845,21 @@ def add(self, layer: Layer | tuple[str, Layer] | None) -> None:
if self.optimizer is not None:
self.set_optimizer(self.optimizer)

def extend(self, *layers: Self | Layer | tuple[str, Layer] | None) -> None:
def extend(
self,
*layers: Self | Layer | tuple[str, Layer] | None,
deep_add: bool = True,
) -> None:
for layer in layers:
if hasattr(layer, "layers"):
new_layer = layer
if isinstance(layer, tuple):
name, layer = layer
new_layer = (name, layer)
if hasattr(layer, "layers") and deep_add:
for sub_layer in layer.layers:
self.add(sub_layer)
continue
self.add(layer)
self.add(new_layer)

@override
@property
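The reworked `extend` now also accepts `(name, layer)` tuples and, with `deep_add=False`, keeps a composite block as a single named entry instead of flattening it into its sub-layers — this is how `_Inception_V1` registers its named Inception stages. A hedged usage sketch follows; the argument values are copied from the model above, and the keyword-only `activation` call is an assumption about the block's constructor.

```py
from luma.neural.block import InceptionBlock
from luma.neural.layer import Activation, Pooling2D, Sequential

model = Sequential()

# Default deep_add=True: a block that itself exposes `.layers` is flattened,
# so each of its sub-layers is added individually.
model.extend(
    InceptionBlock(192, 64, 96, 128, 16, 32, 32, activation=Activation.ReLU),
)

# deep_add=False: tuples stay intact, so the whole block is registered under
# one name, exactly as in _Inception_V1._build_model.
model.extend(
    ("Inception_3a", InceptionBlock(192, 64, 96, 128, 16, 32, 32, activation=Activation.ReLU)),
    Pooling2D(3, stride=2, mode="max", padding="same"),
    deep_add=False,
)
```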
2 changes: 1 addition & 1 deletion luma/neural/network.py
@@ -7,7 +7,7 @@
from luma.neural.layer import Activation
from luma.neural.loss import CrossEntropy

from ._models import _simple, _lenet, _alex, _vgg
from ._models import _simple, _lenet, _alex, _vgg, _inception


__all__ = (
