Skip to content

Commit

Permalink
Add more tests for the isolated potentials
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed Apr 11, 2024
1 parent e323b96 commit f7d13a7
Showing 1 changed file with 40 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,56 @@
import jax.numpy as np
import numpy as onp

from znnl.loss_functions import ContrastiveIsolatedPotentialLoss
from znnl.distance_metrics import LPNorm
from znnl.loss_functions import (
ContrastiveIsolatedPotentialLoss,
ExponentialRepulsionLoss,
ExternalPotential,
)


class TestContrastiveIsolatedPotentialLoss:
"""
Class for the testing of the contrastive isolated potential loss.
"""

@classmethod
def setup_class(cls):
"""
Prepare the test class
"""
cls.predictions = np.array([[1, 1, 2], [1, 1, 1], [0, 0, 0], [2, 1, 1]])
cls.targets = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]) # one-hot

def test_contrastive_loss(self):
"""
Test the contrastive loss call method
"""
# General case
predictions = np.array([[1, 1, 2], [1, 1, 1], [0, 0, 0], [2, 1, 1]])
targets = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]) # one-hot
contrastive_loss = ContrastiveIsolatedPotentialLoss()
loss = contrastive_loss(self.predictions, self.targets)
loss = contrastive_loss(predictions, targets)
onp.testing.assert_almost_equal(loss, 3.36333, decimal=4)

# Only attractive potential
predictions = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]])
targets = np.array([[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]]) # one-hot
contrastive_loss = ContrastiveIsolatedPotentialLoss(
turn_off_external_potential=True
)
loss = contrastive_loss(predictions, targets)
onp.testing.assert_almost_equal(loss, 0.0, decimal=4)

# Only repulsive potential
predictions = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
targets = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) # one-hot
contrastive_loss = ContrastiveIsolatedPotentialLoss(
turn_off_external_potential=True,
repulsive_pot_fn=ExponentialRepulsionLoss(alpha=1, temp=1),
)
loss = contrastive_loss(predictions, targets)
onp.testing.assert_almost_equal(loss, 1.0, decimal=4)

# Only external potential
predictions = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
targets = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
contrastive_loss = ContrastiveIsolatedPotentialLoss(
turn_off_attractive_potential=True,
turn_off_repulsive_potential=True,
external_pot_fn=ExternalPotential(distance_metric=LPNorm(order=2)),
)
loss = contrastive_loss(predictions, targets)
onp.testing.assert_almost_equal(loss, np.sqrt(3), decimal=4)

0 comments on commit f7d13a7

Please sign in to comment.