Skip to content

Commit

Permalink
fix(gan): metrics handling
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpinet committed Dec 10, 2024
1 parent c178566 commit 6fbf884
Showing 1 changed file with 32 additions and 16 deletions.
48 changes: 32 additions & 16 deletions neuralnetlib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,10 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray,
Returns:
Dictionary containing the training history of metrics (loss and any other metrics)
"""

if hasattr(self, 'metrics') and self.metrics is not None:
metrics = self.metrics

history = History({
'loss': [],
'val_loss': []
Expand Down Expand Up @@ -1279,6 +1283,10 @@ def fit(self, x_train: np.ndarray,
Returns:
Dictionary containing the training history
"""

if hasattr(self, 'metrics') and self.metrics is not None:
metrics = self.metrics

history = History({
'loss': [],
'val_loss': []
Expand Down Expand Up @@ -1982,6 +1990,9 @@ def fit(self, x_train: np.ndarray | list, y_train: np.ndarray | list,
validation_split: float | None = None,
callbacks: list = []) -> dict:

if hasattr(self, 'metrics') and self.metrics is not None:
metrics = self.metrics

history = History({
'loss': [],
'val_loss': []
Expand Down Expand Up @@ -2704,6 +2715,9 @@ def fit(
visualization_grid: tuple[int, int] = (8, 8)
) -> dict:

if hasattr(self, 'metrics') and self.metrics is not None:
metrics = self.metrics

history = History({
'discriminator_loss': [],
'generator_loss': [],
Expand All @@ -2730,23 +2744,25 @@ def fit(
history[f'val_discriminator_{metric.name}'] = []
history[f'val_generator_{metric.name}'] = []


if plot_generated:
if fixed_noise is None:
n_samples = visualization_grid[0] * visualization_grid[1]
if self.n_classes is not None:
base_samples = visualization_grid[0] * visualization_grid[1]
n_samples = ((base_samples + self.n_classes - 1) // self.n_classes) * self.n_classes
n_rows = int(np.sqrt(n_samples))
while n_samples % n_rows != 0:
n_rows += 1
n_cols = n_samples // n_rows
visualization_grid = (n_rows, n_cols)
else:
n_samples = visualization_grid[0] * visualization_grid[1]

rng = np.random.default_rng(self.random_state)
fixed_noise = rng.normal(0, 1, (n_samples, self.latent_dim))
if self.n_classes is not None and fixed_labels is None:
n_samples = fixed_noise.shape[0]
samples_per_class = n_samples // self.n_classes

samples_distribution = np.full(self.n_classes, samples_per_class)

remainder = n_samples % self.n_classes
if remainder > 0:
samples_distribution[:remainder] += 1

fixed_labels = np.repeat(np.arange(self.n_classes), samples_distribution)
if self.n_classes is not None and fixed_labels is None:
samples_per_class = fixed_noise.shape[0] // self.n_classes
fixed_labels = np.repeat(np.arange(self.n_classes), samples_per_class)

callbacks = callbacks if callbacks is not None else []

Expand Down Expand Up @@ -2818,8 +2834,8 @@ def fit(

batch_metrics = {}
if metrics is not None:
noise = self._generate_latent_points(len(x_batch), y_batch)
generated_samples = self.forward_pass(noise, training=False)
latent_points, _ = self._generate_latent_points(len(x_batch), y_batch)
generated_samples = self.forward_pass(latent_points, training=False)
for metric in metrics:
metric_value = metric(generated_samples, x_batch)
metric_values[f'generator_{metric.name}'] += metric_value
Expand Down Expand Up @@ -2861,9 +2877,9 @@ def fit(
d_error, g_error = self.train_on_batch(x_train, y_train, len(x_train), n_critic)

if metrics is not None:
noise = self._generate_latent_points(
latent_points, _ = self._generate_latent_points(
len(x_train) if n_gen_samples is None else n_gen_samples, y_train)
generated_samples = self.forward_pass(noise, training=False)
generated_samples = self.forward_pass(latent_points, training=False)

for metric in metrics:
metric_value = metric(generated_samples, x_train)
Expand Down

0 comments on commit 6fbf884

Please sign in to comment.