From 261d1078ad5ecc54339b080654aaf19ed481371c Mon Sep 17 00:00:00 2001 From: Marc Pinet <52708150+marcpinet@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:15:48 +0100 Subject: [PATCH] refactor(gan): some improvements --- neuralnetlib/models.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/neuralnetlib/models.py b/neuralnetlib/models.py index 4c5ba59..b501579 100644 --- a/neuralnetlib/models.py +++ b/neuralnetlib/models.py @@ -2514,6 +2514,15 @@ def compile( print( f"Inferred image dimensions: {self._image_height}x{self._image_width}") + if self.n_classes is not None: + input_shape = discriminator.layers[0].input_shape[1] + expected_shape = self._image_height * self._image_width + self.n_classes + if input_shape != expected_shape: + raise ValueError( + f"Discriminator input shape ({input_shape}) does not match " + f"expected shape for conditional GAN ({expected_shape})" + ) + last_dense = None for layer in generator.layers: if isinstance(layer, Dense): @@ -2965,8 +2974,14 @@ def _plot_samples( if labels is None: samples_per_class = noise.shape[0] // self.n_classes labels = np.repeat(np.arange(self.n_classes), samples_per_class) + one_hot_labels = np.zeros((len(labels), self.n_classes)) one_hot_labels[np.arange(len(labels)), labels] = 1 + + sorted_indices = np.argsort(labels) + noise = noise[sorted_indices] + one_hot_labels = one_hot_labels[sorted_indices] + latent_points = np.concatenate([noise, one_hot_labels], axis=1) else: latent_points = noise @@ -2985,6 +3000,15 @@ def _plot_samples( plt.figure(figsize=(10, 8)) plt.imshow(figure, cmap='gray_r', interpolation='nearest') plt.axis('off') + + if self.n_classes is not None: + samples_per_class = n_cols + for i in range(n_rows): + plt.text(-width/2, i * height + height/2, + f'Class {i}', + horizontalalignment='right', + verticalalignment='center') + plt.tight_layout(pad=0) plt.savefig(f'video{str(epoch).zfill(2)}.png', bbox_inches='tight', pad_inches=0) plt.close()