feat+fix(gan+compiles): metrics in compile + fixed labels in GAN
marcpinet committed Dec 10, 2024
1 parent 2357dfd commit 6bcd35d
Showing 1 changed file with 50 additions and 15 deletions.
65 changes: 50 additions & 15 deletions neuralnetlib/models.py
@@ -57,7 +57,7 @@ def train_on_batch(self, x_batch: np.ndarray, y_batch: np.ndarray) -> float:
pass

@abstractmethod
def compile(self, loss_function, optimizer, verbose: bool = False):
def compile(self, loss_function, optimizer, verbose: bool = False, metrics: list[Metric] = None):
pass

@abstractmethod
@@ -139,11 +139,12 @@ def add(self, layer: Layer):
f"Invalid activation function: {activation_attr}")
self.layers.append(activation)

def compile(self, loss_function: LossFunction | str, optimizer: Optimizer | str, verbose: bool = False):
def compile(self, loss_function: LossFunction | str, optimizer: Optimizer | str, verbose: bool = False, metrics: list[Metric] = None):
self.loss_function = loss_function if isinstance(loss_function, LossFunction) else LossFunction.from_name(
loss_function)
self.optimizer = optimizer if isinstance(
optimizer, Optimizer) else Optimizer.from_name(optimizer)
self.metrics = metrics
if verbose:
print(str(self))
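
This `metrics` parameter is added to each `compile` in the file and stored on the model for later use (the GAN's fit already records per-metric history entries such as val_generator_{metric.name}, as seen further down). A rough, self-contained sketch of the pattern, using a hypothetical SimpleAccuracy and TinyModel as stand-ins rather than neuralnetlib's actual classes:

import numpy as np

class SimpleAccuracy:
    """Hypothetical stand-in for a Metric: a name plus a __call__(y_pred, y_true)."""
    name = "accuracy"

    def __call__(self, y_pred: np.ndarray, y_true: np.ndarray) -> float:
        return float(np.mean(np.argmax(y_pred, axis=-1) == np.argmax(y_true, axis=-1)))

class TinyModel:
    """Illustrative stub, not the library's Model class."""

    def compile(self, loss_function, optimizer, verbose=False, metrics=None):
        # Same pattern as the diff: store loss, optimizer and the (possibly None) metric list.
        self.loss_function = loss_function
        self.optimizer = optimizer
        self.metrics = metrics

    def evaluate_metrics(self, y_pred, y_true):
        # Guard against metrics=None, then report each metric under its name.
        return {m.name: m(y_pred, y_true) for m in (self.metrics or [])}

model = TinyModel()
model.compile(loss_function="bce", optimizer="adam", metrics=[SimpleAccuracy()])
print(model.evaluate_metrics(np.eye(3), np.eye(3)))  # -> {'accuracy': 1.0}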

@@ -912,7 +913,8 @@ def compile(self,
decoder_loss: LossFunction | str = None,
encoder_optimizer: Optimizer | str = None,
decoder_optimizer: Optimizer | str = None,
verbose: bool = False):
verbose: bool = False,
metrics: list[Metric] = None):

if encoder_loss is None:
encoder_loss = decoder_loss
@@ -936,6 +938,8 @@ def compile(self,
self.decoder_optimizer = decoder_optimizer if isinstance(
decoder_optimizer, Optimizer) else Optimizer.from_name(decoder_optimizer)

self.metrics = metrics

if verbose:
print(str(self))

@@ -1746,12 +1750,16 @@ def prepare_data(self, x_train: np.ndarray, y_train: np.ndarray) -> tuple:
def compile(self,
loss_function: LossFunction | str,
optimizer: Optimizer | str,
verbose: bool = False) -> None:
verbose: bool = False,
metrics: list | None = None) -> None:

self.loss_function = loss_function if isinstance(
loss_function, LossFunction) else LossFunction.from_name(loss_function)
self.optimizer = optimizer if isinstance(
optimizer, Optimizer) else Optimizer.from_name(optimizer)

self.metrics = metrics

if verbose:
print(str(self))

@@ -2483,7 +2491,8 @@ def compile(
generator_optimizer: Optimizer | str,
discriminator_optimizer: Optimizer | str,
loss_function: LossFunction | str = 'bce',
verbose: bool = False
verbose: bool = False,
metrics: list | None = None
):
self.generator = generator
self.discriminator = discriminator
@@ -2534,6 +2543,8 @@ def compile(
self.discriminator.loss_function = self.discriminator_loss
self.discriminator.optimizer = self.discriminator_optimizer

self.metrics = metrics

if verbose:
print(str(self))
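
The GAN `compile` above follows the same pattern but also pushes the loss and a per-network optimizer down onto the generator and discriminator before storing `metrics`. A rough stand-alone sketch of that wiring (TinyGAN and the SimpleNamespace sub-models are illustrative stubs, not neuralnetlib's classes; the real method keeps separate generator/discriminator loss attributes derived from the single loss_function argument):

from types import SimpleNamespace

class TinyGAN:
    """Illustrative stub mirroring the wiring shown in the diff."""

    def compile(self, generator, discriminator, generator_optimizer,
                discriminator_optimizer, loss_function="bce", metrics=None):
        self.generator = generator
        self.discriminator = discriminator
        # One loss function, but each sub-network keeps its own optimizer (and optimizer state).
        self.generator.loss_function = loss_function
        self.generator.optimizer = generator_optimizer
        self.discriminator.loss_function = loss_function
        self.discriminator.optimizer = discriminator_optimizer
        self.metrics = metrics

gan = TinyGAN()
gan.compile(SimpleNamespace(), SimpleNamespace(), "adam", "adam", metrics=[])
print(gan.discriminator.optimizer, gan.metrics)  # -> adam []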

@@ -2689,7 +2700,8 @@ def fit(
plot_interval: int = 1,
fixed_noise: np.ndarray | None = None,
fixed_labels: np.ndarray | None = None,
n_gen_samples: int | None = None
n_gen_samples: int | None = None,
visualization_grid: tuple[int, int] = (8, 8)
) -> dict:

history = History({
@@ -2718,9 +2730,23 @@ def fit(
history[f'val_discriminator_{metric.name}'] = []
history[f'val_generator_{metric.name}'] = []

if plot_generated and fixed_noise is None:
rng = np.random.default_rng(self.random_state)
fixed_noise = rng.normal(0, 1, (80, self.latent_dim))

if plot_generated:
if fixed_noise is None:
n_samples = visualization_grid[0] * visualization_grid[1]
rng = np.random.default_rng(self.random_state)
fixed_noise = rng.normal(0, 1, (n_samples, self.latent_dim))
if self.n_classes is not None and fixed_labels is None:
n_samples = fixed_noise.shape[0]
samples_per_class = n_samples // self.n_classes

samples_distribution = np.full(self.n_classes, samples_per_class)

remainder = n_samples % self.n_classes
if remainder > 0:
samples_distribution[:remainder] += 1

fixed_labels = np.repeat(np.arange(self.n_classes), samples_distribution)

callbacks = callbacks if callbacks is not None else []
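
The label fix in the block above is plain NumPy: the number of fixed noise vectors now follows the visualization grid, and when that count is not a multiple of n_classes the remainder is spread over the first few classes so every sample still gets a label. A standalone sketch of that distribution logic (grid and class counts chosen for illustration):

import numpy as np

def balanced_fixed_labels(n_samples: int, n_classes: int) -> np.ndarray:
    """Spread n_samples over n_classes as evenly as possible; the first classes absorb the remainder."""
    samples_distribution = np.full(n_classes, n_samples // n_classes)
    remainder = n_samples % n_classes
    if remainder > 0:
        samples_distribution[:remainder] += 1
    return np.repeat(np.arange(n_classes), samples_distribution)

# An 8x8 visualization grid with 10 classes: 64 = 10*6 + 4, so classes 0-3 get 7 samples each.
labels = balanced_fixed_labels(8 * 8, 10)
print(np.bincount(labels))  # -> [7 7 7 7 6 6 6 6 6 6]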

@@ -2881,7 +2907,7 @@ def fit(
f'- val_g_loss: {format_number(val_g_loss)}', end='')

if plot_generated and (epoch + 1) % plot_interval == 0:
self._plot_samples(fixed_noise, epoch + 1)
self._plot_samples(fixed_noise, epoch + 1, fixed_labels, visualization_grid)

stop_training = False
for callback in callbacks:
@@ -2908,10 +2934,21 @@ def fit(

return history

def _plot_samples(self, noise: np.ndarray, epoch: int):
def _plot_samples(
self,
noise: np.ndarray,
epoch: int,
labels: np.ndarray | None = None,
grid_size: tuple[int, int] = (8, 8)
):
n_rows, n_cols = grid_size
if noise.shape[0] != n_rows * n_cols:
raise ValueError("The number of samples must match the grid size.")

if self.n_classes is not None:
samples_per_class = noise.shape[0] // self.n_classes
labels = np.repeat(np.arange(self.n_classes), samples_per_class)
if labels is None:
samples_per_class = noise.shape[0] // self.n_classes
labels = np.repeat(np.arange(self.n_classes), samples_per_class)
one_hot_labels = np.zeros((len(labels), self.n_classes))
one_hot_labels[np.arange(len(labels)), labels] = 1
latent_points = np.concatenate([noise, one_hot_labels], axis=1)
@@ -2921,8 +2958,6 @@ def _plot_samples(self, noise: np.ndarray, epoch: int):
generated = self.generator.forward_pass(latent_points, training=False)

height, width = self.image_dimensions
n_rows = samples_per_class if self.n_classes else 8
n_cols = self.n_classes if self.n_classes else 8
figure = np.zeros((height * n_rows, width * n_cols))

for i in range(n_rows):
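_plot_samples now receives the precomputed `fixed_labels` and a `grid_size` instead of hard-coding an 8x8 layout: labels are one-hot encoded, concatenated onto the noise to form the conditional generator input, and the generated images are tiled row by row into one figure. A self-contained sketch of that assembly (random arrays stand in for the generator output; all sizes are illustrative):

import numpy as np

rng = np.random.default_rng(0)
n_rows, n_cols = 4, 5          # grid_size
n_classes, latent_dim = 5, 16
height, width = 28, 28         # image_dimensions

noise = rng.normal(0, 1, (n_rows * n_cols, latent_dim))
labels = np.repeat(np.arange(n_classes), (n_rows * n_cols) // n_classes)

# One-hot encode the labels and append them to the noise to form the conditional input.
one_hot = np.zeros((len(labels), n_classes))
one_hot[np.arange(len(labels)), labels] = 1
latent_points = np.concatenate([noise, one_hot], axis=1)   # shape (20, 21)

# Stand-in for generator.forward_pass(latent_points, training=False).
generated = rng.random((len(latent_points), height, width))

# Tile the samples row by row into a single (n_rows*height, n_cols*width) image.
figure = np.zeros((height * n_rows, width * n_cols))
for i in range(n_rows):
    for j in range(n_cols):
        figure[i * height:(i + 1) * height, j * width:(j + 1) * width] = generated[i * n_cols + j]

print(figure.shape)  # -> (112, 140)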
