
Commit

feat(GAN): add Conditional GAN (CGAN)
marcpinet committed Dec 9, 2024
1 parent 8cc464c commit 5460a72
Showing 5 changed files with 182 additions and 88 deletions.
Binary file added cgan_samples_epoch_1.png

Large diffs are not rendered by default.

140 changes: 94 additions & 46 deletions neuralnetlib/models.py
@@ -2314,6 +2314,7 @@ class GAN(BaseModel):
def __init__(
self,
latent_dim: int = 100,
n_classes: int | None = None,
gradient_clip_threshold: float = 0.1,
enable_padding: bool = False,
padding_size: int = 32,
@@ -2327,6 +2328,7 @@ def __init__(
super().__init__(gradient_clip_threshold, enable_padding, padding_size, random_state)

self.latent_dim = latent_dim
self.n_classes = n_classes
self.generator = None
self.discriminator = None
self.generator_optimizer = None
@@ -2445,10 +2447,26 @@ def backward_pass(self, error: np.ndarray):

self.generator.backward_pass(error)

def _generate_latent_points(self, n_samples: int) -> np.ndarray:
def _generate_latent_points(self, n_samples: int, labels: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
rng = np.random.default_rng(self.random_state)
latent_points = rng.normal(0, 1, (n_samples, self.latent_dim))
return latent_points

if self.n_classes is not None:
if labels is None:
labels = rng.integers(0, self.n_classes, n_samples)
elif labels.ndim == 2 and labels.shape[1] == self.n_classes:
return np.concatenate([latent_points, labels], axis=1), labels

one_hot_labels = np.zeros((n_samples, self.n_classes))
if labels.ndim == 1:
one_hot_labels[np.arange(n_samples), labels] = 1
else:
one_hot_labels = labels

latent_points = np.concatenate([latent_points, one_hot_labels], axis=1)
return latent_points, one_hot_labels

return latent_points, None

def _apply_spectral_norm(self, model: 'Sequential'):
if not self.use_spectral_norm:
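The new _generate_latent_points conditions the generator input by appending a one-hot class vector to each latent vector, so a conditional generator has to accept inputs of width latent_dim + n_classes. A minimal standalone sketch of that step (shapes and names here are illustrative, not taken from the library):

import numpy as np

# Illustrative sizes: 100-dim latent space, 10 classes, 4 samples.
latent_dim, n_classes, n_samples = 100, 10, 4
rng = np.random.default_rng(0)

z = rng.normal(0, 1, (n_samples, latent_dim))    # latent points
labels = rng.integers(0, n_classes, n_samples)   # integer class labels
one_hot = np.zeros((n_samples, n_classes))
one_hot[np.arange(n_samples), labels] = 1        # one-hot encode
z_cond = np.concatenate([z, one_hot], axis=1)    # shape (4, 110): generator input

If labels are already one-hot (2-D with n_classes columns), the helper concatenates them as-is; with integer labels it encodes them as above, and when no labels are supplied it samples random classes.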
Expand Down Expand Up @@ -2495,7 +2513,8 @@ def _ensure_initialized(self, input_data: np.ndarray):
def train_on_batch(
self,
real_samples: np.ndarray,
batch_size: int,
labels: np.ndarray | None = None,
batch_size: int = 32,
n_critic: int = 1
) -> tuple[float, float]:
rng = np.random.default_rng(self.random_state)
@@ -2504,60 +2523,71 @@ for _ in range(n_critic):
for _ in range(n_critic):
idx = rng.choice(len(real_samples), batch_size, replace=False)
real_batch = real_samples[idx]
batch_labels = labels[idx] if labels is not None else None

noise = rng.standard_normal(size=(batch_size, self.latent_dim))
noise, gen_labels = self._generate_latent_points(batch_size, batch_labels)
fake_batch = self.generator.forward_pass(noise, training=False)

combined_batch = np.concatenate([real_batch, fake_batch])
combined_labels = np.zeros((2 * batch_size, 1))
combined_labels[:batch_size] = 1.0

if self.n_classes is not None and batch_labels is not None:
# integer labels are one-hot encoded so they stack cleanly with gen_labels
if batch_labels.ndim == 1:
batch_labels = np.eye(self.n_classes)[batch_labels]
combined_cond = np.concatenate([batch_labels, gen_labels])
combined_batch = np.concatenate([combined_batch, combined_cond], axis=1)

self.discriminator.y_true = combined_labels
predictions = self.discriminator.forward_pass(
combined_batch, training=True)
predictions = self.discriminator.forward_pass(combined_batch, training=True)
d_loss = self.discriminator_loss(combined_labels, predictions)
d_grad = self.discriminator_loss.derivative(
combined_labels, predictions)
d_grad = self.discriminator_loss.derivative(combined_labels, predictions)
self.discriminator.backward_pass(d_grad)

d_loss_total += d_loss

d_loss_avg = d_loss_total / n_critic

noise = rng.standard_normal(size=(batch_size, self.latent_dim))
noise, gen_labels = self._generate_latent_points(batch_size)
fake_samples = self.generator.forward_pass(noise, training=True)

if self.n_classes is not None:
fake_samples_with_cond = np.concatenate([fake_samples, gen_labels], axis=1)
else:
fake_samples_with_cond = fake_samples

target_labels = np.ones((batch_size, 1))

disc_predictions = self.discriminator.forward_pass(
fake_samples, training=False)
disc_predictions = self.discriminator.forward_pass(fake_samples_with_cond, training=False)
self.discriminator.y_true = target_labels
g_loss = self.generator_loss(target_labels, disc_predictions)
g_grad = self.generator_loss.derivative(
target_labels, disc_predictions)
g_grad = self.generator_loss.derivative(target_labels, disc_predictions)

d_grad = self.discriminator.backward_pass(g_grad, compute_only=True)
if self.n_classes is not None:
d_grad = d_grad[:, :-self.n_classes]
self.generator.backward_pass(d_grad, gan=True)

return d_loss_avg, g_loss
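In the conditional branch above, the discriminator scores each sample with its one-hot label appended, and those label columns are sliced off the gradient (d_grad[:, :-self.n_classes]) before it is propagated back into the generator. A small sketch of the two shape manipulations involved, assuming flattened 28x28 samples and 10 classes:

import numpy as np

batch, sample_dim, n_classes = 32, 784, 10
fake_samples = np.random.rand(batch, sample_dim)                 # generator output
one_hot = np.eye(n_classes)[np.random.randint(0, n_classes, batch)]

disc_input = np.concatenate([fake_samples, one_hot], axis=1)     # (32, 794) fed to the discriminator

d_grad = np.random.rand(batch, sample_dim + n_classes)           # stand-in for the input gradient
g_grad = d_grad[:, :-n_classes]                                  # drop the label columns -> (32, 784)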

def fit(
self,
x_train: np.ndarray,
epochs: int = 100,
batch_size: int | None = None,
n_critic: int = 5,
verbose: bool = True,
metrics: list | None = None,
random_state: int | None = None,
validation_data: tuple | None = None,
validation_split: float | None = None,
callbacks: list = [],
plot_generated: bool = False,
plot_interval: int = 1,
fixed_noise: np.ndarray | None = None,
n_gen_samples: int | None = None
) -> dict:
self,
x_train: np.ndarray,
y_train: np.ndarray | None = None,
epochs: int = 100,
batch_size: int | None = None,
n_critic: int = 5,
verbose: bool = True,
metrics: list | None = None,
random_state: int | None = None,
validation_data: tuple | None = None,
validation_split: float | None = None,
callbacks: list = [],
plot_generated: bool = False,
plot_interval: int = 1,
fixed_noise: np.ndarray | None = None,
fixed_labels: np.ndarray | None = None,
n_gen_samples: int | None = None
) -> dict:

history = History({
'discriminator_loss': [],
@@ -2574,6 +2604,7 @@ def fit(
validation_data = (x_val, None)

x_train = np.array(x_train) if not isinstance(x_train, np.ndarray) else x_train
y_train = np.array(y_train) if y_train is not None and not isinstance(y_train, np.ndarray) else y_train

if metrics is not None:
metrics = [Metric(m) for m in metrics]
@@ -2586,7 +2617,7 @@ def fit(

if plot_generated and fixed_noise is None:
rng = np.random.default_rng(self.random_state)
fixed_noise = rng.standard_normal(size=(64, self.latent_dim))
fixed_noise = rng.normal(0, 1, (80, self.latent_dim))

callbacks = callbacks if callbacks is not None else []

@@ -2613,8 +2644,8 @@ def fit(
callback.on_epoch_begin(epoch, epoch_logs)

start_time = time.time()
x_train_shuffled = shuffle(
x_train,
x_train_shuffled, y_train_shuffled = shuffle(
x_train, y_train,
random_state=random_state if random_state is not None else self.random_state
)
d_error = 0
@@ -2631,6 +2662,7 @@ def fit(
for j in range(0, x_train.shape[0], batch_size):
batch_index = j // batch_size
x_batch = x_train_shuffled[j:j + batch_size]
y_batch = y_train_shuffled[j:j + batch_size] if y_train is not None else None

batch_logs = {
'batch': batch_index,
@@ -2642,13 +2674,13 @@ def fit(
callback.on_batch_begin(batch_index, batch_logs)

d_loss, g_loss = self.train_on_batch(
x_batch, min(batch_size, len(x_batch)), n_critic)
x_batch, y_batch, min(batch_size, len(x_batch)), n_critic)
d_error += d_loss
g_error += g_loss

batch_metrics = {}
if metrics is not None:
noise = self._generate_latent_points(len(x_batch))
noise, _ = self._generate_latent_points(len(x_batch), y_batch)
generated_samples = self.forward_pass(noise, training=False)
for metric in metrics:
metric_value = metric(generated_samples, x_batch)
@@ -2688,11 +2720,11 @@ def fit(
metric_values[k] /= num_batches

else:
d_error, g_error = self.train_on_batch(x_train, len(x_train), n_critic)
d_error, g_error = self.train_on_batch(x_train, y_train, len(x_train), n_critic)

if metrics is not None:
noise = self._generate_latent_points(
len(x_train) if n_gen_samples is None else n_gen_samples)
noise, _ = self._generate_latent_points(
len(x_train) if n_gen_samples is None else n_gen_samples, y_train)
generated_samples = self.forward_pass(noise, training=False)

for metric in metrics:
Expand Down Expand Up @@ -2765,21 +2797,37 @@ def fit(
return history

def _plot_samples(self, noise: np.ndarray, epoch: int):
generated = self.forward_pass(noise, training=False)
if self.n_classes is not None:
samples_per_class = noise.shape[0] // self.n_classes
labels = np.repeat(np.arange(self.n_classes), samples_per_class)
one_hot_labels = np.zeros((len(labels), self.n_classes))
one_hot_labels[np.arange(len(labels)), labels] = 1
latent_points = np.concatenate([noise, one_hot_labels], axis=1)
else:
latent_points = noise

generated = self.generator.forward_pass(latent_points, training=False)

height, width = self.image_dimensions
sample = generated[0].reshape(height, width)

plt.figure(figsize=(4, 4))
plt.imshow(sample, cmap='gray_r', interpolation='nearest')
n_rows = samples_per_class if self.n_classes else 8
n_cols = self.n_classes if self.n_classes else 8
figure = np.zeros((height * n_rows, width * n_cols))

for i in range(n_rows):
for j in range(n_cols):
sample_idx = j * n_rows + i  # column j holds class j, row i its i-th sample
sample = generated[sample_idx].reshape(height, width)
figure[i * height:(i + 1) * height, j * width:(j + 1) * width] = sample

plt.figure(figsize=(10, 8))
plt.imshow(figure, cmap='gray_r', interpolation='nearest')
plt.axis('off')
plt.tight_layout()

plt.savefig(f'video{str(epoch).zfill(2)}.png')
plt.tight_layout(pad=0)
plt.savefig(f'video{str(epoch).zfill(2)}.png', bbox_inches='tight', pad_inches=0)
plt.close()
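With n_classes set, the sample grid is laid out as samples_per_class rows by n_classes columns, one class per column, which is presumably why the default fixed noise above grows from 64 to 80 vectors (8 x 10 for a 10-class dataset such as MNIST). The layout arithmetic, spelled out:

# Assuming 80 fixed noise vectors and 10 classes, as in the defaults above:
n_noise, n_classes = 80, 10
samples_per_class = n_noise // n_classes          # 8
n_rows, n_cols = samples_per_class, n_classes     # 8 rows, 10 columns
assert n_rows * n_cols == n_noise                 # every noise vector appears once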

def predict(self, n_samples: int, temperature: float = 1.0) -> np.ndarray:
latent_points = self._generate_latent_points(n_samples)
def predict(self, n_samples: int, labels: np.ndarray | None = None, temperature: float = 1.0) -> np.ndarray:
latent_points, _ = self._generate_latent_points(n_samples, labels)
return self.generator.predict(latent_points, temperature)

def evaluate(
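A hypothetical end-to-end use of the new conditional API. Only the GAN constructor, fit, and predict signatures come from this diff; the data, the generator/discriminator construction, and any compile step are library-specific assumptions and are not shown here:

import numpy as np
from neuralnetlib.models import GAN

# Toy stand-in data: flattened 28x28 images with integer class labels.
x_train = np.random.rand(1000, 784).astype(np.float32)
y_train = np.random.randint(0, 10, 1000)

gan = GAN(latent_dim=100, n_classes=10)
# ... build and attach the generator and discriminator here (not covered by this commit) ...

history = gan.fit(x_train, y_train, epochs=10, batch_size=64, n_critic=1)

# Class-conditional sampling via the new labels argument: ten samples of class 3.
samples = gan.predict(10, labels=np.full(10, 3))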
