From 6fbf88400f7d46aa72b2db171488231481bb738b Mon Sep 17 00:00:00 2001
From: Marc Pinet <52708150+marcpinet@users.noreply.github.com>
Date: Tue, 10 Dec 2024 09:46:52 +0100
Subject: [PATCH] fix(gan): metrics handling

---
 neuralnetlib/models.py | 48 ++++++++++++++++++++++++++++++++----------------
 1 file changed, 32 insertions(+), 16 deletions(-)

diff --git a/neuralnetlib/models.py b/neuralnetlib/models.py
index ef3f2c1..4c5ba59 100644
--- a/neuralnetlib/models.py
+++ b/neuralnetlib/models.py
@@ -270,6 +270,10 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray,
         Returns:
             Dictionary containing the training history of metrics (loss and any other metrics)
         """
+
+        if hasattr(self, 'metrics') and self.metrics is not None:
+            metrics = self.metrics
+
         history = History({
             'loss': [],
             'val_loss': []
@@ -1279,6 +1283,10 @@ def fit(self, x_train: np.ndarray,
         Returns:
             Dictionary containing the training history
         """
+
+        if hasattr(self, 'metrics') and self.metrics is not None:
+            metrics = self.metrics
+
         history = History({
             'loss': [],
             'val_loss': []
@@ -1982,6 +1990,9 @@ def fit(self, x_train: np.ndarray | list, y_train: np.ndarray | list,
               validation_split: float | None = None,
               callbacks: list = []) -> dict:
 
+        if hasattr(self, 'metrics') and self.metrics is not None:
+            metrics = self.metrics
+
         history = History({
             'loss': [],
             'val_loss': []
@@ -2704,6 +2715,9 @@ def fit(
         visualization_grid: tuple[int, int] = (8, 8)
     ) -> dict:
 
+        if hasattr(self, 'metrics') and self.metrics is not None:
+            metrics = self.metrics
+
         history = History({
             'discriminator_loss': [],
             'generator_loss': [],
@@ -2730,23 +2744,25 @@ def fit(
                 history[f'val_discriminator_{metric.name}'] = []
                 history[f'val_generator_{metric.name}'] = []
 
-
         if plot_generated:
             if fixed_noise is None:
-                n_samples = visualization_grid[0] * visualization_grid[1]
+                if self.n_classes is not None:
+                    base_samples = visualization_grid[0] * visualization_grid[1]
+                    n_samples = ((base_samples + self.n_classes - 1) // self.n_classes) * self.n_classes
+                    n_rows = int(np.sqrt(n_samples))
+                    while n_samples % n_rows != 0:
+                        n_rows += 1
+                    n_cols = n_samples // n_rows
+                    visualization_grid = (n_rows, n_cols)
+                else:
+                    n_samples = visualization_grid[0] * visualization_grid[1]
+
                 rng = np.random.default_rng(self.random_state)
                 fixed_noise = rng.normal(0, 1, (n_samples, self.latent_dim))
 
-            if self.n_classes is not None and fixed_labels is None:
-                n_samples = fixed_noise.shape[0]
-                samples_per_class = n_samples // self.n_classes
-                samples_distribution = np.full(self.n_classes, samples_per_class)
-
-                remainder = n_samples % self.n_classes
-                if remainder > 0:
-                    samples_distribution[:remainder] += 1
-
-                fixed_labels = np.repeat(np.arange(self.n_classes), samples_distribution)
+            if self.n_classes is not None and fixed_labels is None:
+                samples_per_class = fixed_noise.shape[0] // self.n_classes
+                fixed_labels = np.repeat(np.arange(self.n_classes), samples_per_class)
 
         callbacks = callbacks if callbacks is not None else []
@@ -2818,8 +2834,8 @@ def fit(
                     batch_metrics = {}
 
                     if metrics is not None:
-                        noise = self._generate_latent_points(len(x_batch), y_batch)
-                        generated_samples = self.forward_pass(noise, training=False)
+                        latent_points, _ = self._generate_latent_points(len(x_batch), y_batch)
+                        generated_samples = self.forward_pass(latent_points, training=False)
                         for metric in metrics:
                             metric_value = metric(generated_samples, x_batch)
                             metric_values[f'generator_{metric.name}'] += metric_value
@@ -2861,9 +2877,9 @@ def fit(
                 d_error, g_error = self.train_on_batch(x_train, y_train, len(x_train),
                                                        n_critic)
 
                 if metrics is not None:
-                    noise = self._generate_latent_points(
+                    latent_points, _ = self._generate_latent_points(
                         len(x_train) if n_gen_samples is None else n_gen_samples, y_train)
-                    generated_samples = self.forward_pass(noise, training=False)
+                    generated_samples = self.forward_pass(latent_points, training=False)
                     for metric in metrics:
                         metric_value = metric(generated_samples, x_train)
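
For review context: a standalone sketch of the visualization-grid sizing
introduced in the @@ -2730 hunk. The arithmetic mirrors the diff; the
aligned_grid wrapper itself is hypothetical, added here only so the
computation can be run in isolation (it is not a neuralnetlib function).

    import numpy as np

    def aligned_grid(visualization_grid, n_classes):
        # Round the sample count up to the next multiple of n_classes so
        # every class receives the same number of fixed samples.
        base_samples = visualization_grid[0] * visualization_grid[1]
        n_samples = ((base_samples + n_classes - 1) // n_classes) * n_classes
        # Re-factor n_samples into a grid: rows start at int(sqrt(n)) and
        # climb to the first divisor of n_samples.
        n_rows = int(np.sqrt(n_samples))
        while n_samples % n_rows != 0:
            n_rows += 1
        return n_samples, (n_rows, n_samples // n_rows)

    # (8, 8) with 10 classes: 64 rounds up to 70, and rows climb from
    # int(sqrt(70)) = 8 to the first divisor, 10, giving a 10 x 7 grid.
    print(aligned_grid((8, 8), 10))  # -> (70, (10, 7))

Because n_samples is now always a multiple of n_classes, fixed_labels
reduces to a plain np.repeat; the old remainder-distribution block becomes
dead code, which is why the hunk drops it.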
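The call-site changes in the last two hunks read like a tuple-unpacking
fix. A minimal sketch, assuming _generate_latent_points returns a
(noise, labels) pair as the patched call sites suggest; the stand-in
below is hypothetical, not the library's implementation.

    import numpy as np

    def _generate_latent_points(n_samples, labels=None, latent_dim=100):
        # Hypothetical stand-in returning the 2-tuple the patch unpacks.
        rng = np.random.default_rng(0)
        return rng.normal(0, 1, (n_samples, latent_dim)), labels

    # Before: noise = _generate_latent_points(...) bound the whole tuple,
    # so forward_pass(noise, ...) received (ndarray, labels) rather than
    # the noise array. After: the first element is unpacked explicitly
    # and the labels are discarded.
    latent_points, _ = _generate_latent_points(32)
    assert isinstance(latent_points, np.ndarray)
    assert latent_points.shape == (32, 100)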