Skip to content

Commit

Permalink
refactor(gan): some improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpinet committed Dec 10, 2024
1 parent 5dc15e8 commit 261d107
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions neuralnetlib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2514,6 +2514,15 @@ def compile(
print(
f"Inferred image dimensions: {self._image_height}x{self._image_width}")

if self.n_classes is not None:
input_shape = discriminator.layers[0].input_shape[1]
expected_shape = self._image_height * self._image_width + self.n_classes
if input_shape != expected_shape:
raise ValueError(
f"Discriminator input shape ({input_shape}) does not match "
f"expected shape for conditional GAN ({expected_shape})"
)

last_dense = None
for layer in generator.layers:
if isinstance(layer, Dense):
Expand Down Expand Up @@ -2965,8 +2974,14 @@ def _plot_samples(
if labels is None:
samples_per_class = noise.shape[0] // self.n_classes
labels = np.repeat(np.arange(self.n_classes), samples_per_class)

one_hot_labels = np.zeros((len(labels), self.n_classes))
one_hot_labels[np.arange(len(labels)), labels] = 1

sorted_indices = np.argsort(labels)
noise = noise[sorted_indices]
one_hot_labels = one_hot_labels[sorted_indices]

latent_points = np.concatenate([noise, one_hot_labels], axis=1)
else:
latent_points = noise
Expand All @@ -2985,6 +3000,15 @@ def _plot_samples(
plt.figure(figsize=(10, 8))
plt.imshow(figure, cmap='gray_r', interpolation='nearest')
plt.axis('off')

if self.n_classes is not None:
samples_per_class = n_cols
for i in range(n_rows):
plt.text(-width/2, i * height + height/2,
f'Class {i}',
horizontalalignment='right',
verticalalignment='center')

plt.tight_layout(pad=0)
plt.savefig(f'video{str(epoch).zfill(2)}.png', bbox_inches='tight', pad_inches=0)
plt.close()
Expand Down

0 comments on commit 261d107

Please sign in to comment.