diff --git a/neuralnetlib/models.py b/neuralnetlib/models.py index 748bbf1..27aff75 100644 --- a/neuralnetlib/models.py +++ b/neuralnetlib/models.py @@ -57,7 +57,7 @@ def train_on_batch(self, x_batch: np.ndarray, y_batch: np.ndarray) -> float: pass @abstractmethod - def compile(self, loss_function, optimizer, verbose: bool = False): + def compile(self, loss_function, optimizer, verbose: bool = False, metrics: list[Metric] = None): pass @abstractmethod @@ -139,11 +139,12 @@ def add(self, layer: Layer): f"Invalid activation function: {activation_attr}") self.layers.append(activation) - def compile(self, loss_function: LossFunction | str, optimizer: Optimizer | str, verbose: bool = False): + def compile(self, loss_function: LossFunction | str, optimizer: Optimizer | str, verbose: bool = False, metrics: list[Metric] = None): self.loss_function = loss_function if isinstance(loss_function, LossFunction) else LossFunction.from_name( loss_function) self.optimizer = optimizer if isinstance( optimizer, Optimizer) else Optimizer.from_name(optimizer) + self.metrics = metrics if verbose: print(str(self)) @@ -912,7 +913,8 @@ def compile(self, decoder_loss: LossFunction | str = None, encoder_optimizer: Optimizer | str = None, decoder_optimizer: Optimizer | str = None, - verbose: bool = False): + verbose: bool = False, + metrics: list[Metric] = None): if encoder_loss is None: encoder_loss = decoder_loss @@ -936,6 +938,8 @@ def compile(self, self.decoder_optimizer = decoder_optimizer if isinstance( decoder_optimizer, Optimizer) else Optimizer.from_name(decoder_optimizer) + self.metrics = metrics + if verbose: print(str(self)) @@ -1746,12 +1750,16 @@ def prepare_data(self, x_train: np.ndarray, y_train: np.ndarray) -> tuple: def compile(self, loss_function: LossFunction | str, optimizer: Optimizer | str, - verbose: bool = False) -> None: + verbose: bool = False, + metrics: list | None = None) -> None: + self.loss_function = loss_function if isinstance( loss_function, LossFunction) else LossFunction.from_name(loss_function) self.optimizer = optimizer if isinstance( optimizer, Optimizer) else Optimizer.from_name(optimizer) + self.metrics = metrics + if verbose: print(str(self)) @@ -2483,7 +2491,8 @@ def compile( generator_optimizer: Optimizer | str, discriminator_optimizer: Optimizer | str, loss_function: LossFunction | str = 'bce', - verbose: bool = False + verbose: bool = False, + metrics: list | None = None ): self.generator = generator self.discriminator = discriminator @@ -2534,6 +2543,8 @@ def compile( self.discriminator.loss_function = self.discriminator_loss self.discriminator.optimizer = self.discriminator_optimizer + self.metrics = metrics + if verbose: print(str(self)) @@ -2689,7 +2700,8 @@ def fit( plot_interval: int = 1, fixed_noise: np.ndarray | None = None, fixed_labels: np.ndarray | None = None, - n_gen_samples: int | None = None + n_gen_samples: int | None = None, + visualization_grid: tuple[int, int] = (8, 8) ) -> dict: history = History({ @@ -2718,9 +2730,23 @@ def fit( history[f'val_discriminator_{metric.name}'] = [] history[f'val_generator_{metric.name}'] = [] - if plot_generated and fixed_noise is None: - rng = np.random.default_rng(self.random_state) - fixed_noise = rng.normal(0, 1, (80, self.latent_dim)) + + if plot_generated: + if fixed_noise is None: + n_samples = visualization_grid[0] * visualization_grid[1] + rng = np.random.default_rng(self.random_state) + fixed_noise = rng.normal(0, 1, (n_samples, self.latent_dim)) + if self.n_classes is not None and fixed_labels is None: + n_samples = fixed_noise.shape[0] + samples_per_class = n_samples // self.n_classes + + samples_distribution = np.full(self.n_classes, samples_per_class) + + remainder = n_samples % self.n_classes + if remainder > 0: + samples_distribution[:remainder] += 1 + + fixed_labels = np.repeat(np.arange(self.n_classes), samples_distribution) callbacks = callbacks if callbacks is not None else [] @@ -2881,7 +2907,7 @@ def fit( f'- val_g_loss: {format_number(val_g_loss)}', end='') if plot_generated and (epoch + 1) % plot_interval == 0: - self._plot_samples(fixed_noise, epoch + 1) + self._plot_samples(fixed_noise, epoch + 1, fixed_labels, visualization_grid) stop_training = False for callback in callbacks: @@ -2908,10 +2934,21 @@ def fit( return history - def _plot_samples(self, noise: np.ndarray, epoch: int): + def _plot_samples( + self, + noise: np.ndarray, + epoch: int, + labels: np.ndarray | None = None, + grid_size: tuple[int, int] = (8, 8) + ): + n_rows, n_cols = grid_size + if noise.shape[0] != n_rows * n_cols: + raise ValueError("The number of samples must match the grid size.") + if self.n_classes is not None: - samples_per_class = noise.shape[0] // self.n_classes - labels = np.repeat(np.arange(self.n_classes), samples_per_class) + if labels is None: + samples_per_class = noise.shape[0] // self.n_classes + labels = np.repeat(np.arange(self.n_classes), samples_per_class) one_hot_labels = np.zeros((len(labels), self.n_classes)) one_hot_labels[np.arange(len(labels)), labels] = 1 latent_points = np.concatenate([noise, one_hot_labels], axis=1) @@ -2921,8 +2958,6 @@ def _plot_samples(self, noise: np.ndarray, epoch: int): generated = self.generator.forward_pass(latent_points, training=False) height, width = self.image_dimensions - n_rows = samples_per_class if self.n_classes else 8 - n_cols = self.n_classes if self.n_classes else 8 figure = np.zeros((height * n_rows, width * n_cols)) for i in range(n_rows):