Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590154234
  • Loading branch information
froystig authored and JAXopt authors committed Dec 12, 2023
1 parent 445b241 commit 58bac0a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/perturbations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def test_rank_finite_diff(self):
delta_num = (sq_loss_plus_h - sq_loss_minus_h) / (2 * eps)
delta_lin = jnp.sum(gradient_square_rank * h)

self.assertArraysAllClose(delta_num, delta_lin, atol=5e-2)
self.assertArraysAllClose(delta_num, delta_lin, atol=1e-1, rtol=.5)


class PerturbationsMaxTest(test_util.JaxoptTestCase):
Expand Down Expand Up @@ -571,7 +571,7 @@ def test_noise_iid(self, control_variate):
rngs_batch)
self.assertArraysAllClose(pert_scalar_repeat[0],
pert_scalar_repeat[1],
atol=2e-2)
atol=1e-1, rtol=1e-2)
delta_noise = pert_scalar_repeat[0] - pert_scalar_repeat[1]
self.assertNotAlmostEqual(jnp.linalg.norm(delta_noise), 0)

Expand Down Expand Up @@ -752,7 +752,7 @@ def test_noise_iid(self):
pert_repeat = jax.vmap(pert_fun)(theta_batch_repeat,
rngs_batch)
self.assertArraysAllClose(pert_repeat[0], pert_repeat[1],
atol=2e-2)
atol=5e-2, rtol=5e-2)
delta_noise = pert_repeat[0] - pert_repeat[1]
self.assertNotAlmostEqual(jnp.linalg.norm(delta_noise), 0)

Expand Down

0 comments on commit 58bac0a

Please sign in to comment.