From 70a7fa3f641f11303cf000ed4b24813c21a610a6 Mon Sep 17 00:00:00 2001 From: Marcel Arpogaus <38564291+MArpogaus@users.noreply.github.com> Date: Tue, 3 Oct 2023 15:22:31 +0200 Subject: [PATCH] fixes tests for weibul and student-t base distribustions --- test/distributions/test_bernstein_flow.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/distributions/test_bernstein_flow.py b/test/distributions/test_bernstein_flow.py index 8db69ef..11a50df 100644 --- a/test/distributions/test_bernstein_flow.py +++ b/test/distributions/test_bernstein_flow.py @@ -228,9 +228,11 @@ def test_student_t(self): dtype=dtype, base_distribution=student_t, thetas_constrain_fn=get_thetas_constrain_fn( - low=-25, high=25, allow_flexible_bounds=True + low=-35, high=35, allow_flexible_bounds=True ), scale_base_distribution=False, + clip_to_bernstein_domain=False, + bb_class=BernsteinBijectorLinearExtrapolate, ) self.f(normal_dist, trans_dist) @@ -243,10 +245,14 @@ def test_weibull(self): order=10, dtype=dtype, base_distribution=weibull, - thetas_constrain_fn=get_thetas_constrain_fn(low=1e-10, high=50), + thetas_constrain_fn=get_thetas_constrain_fn(low=1e-12, high=100), + bb_class=BernsteinBijectorLinearExtrapolate, + clip_to_bernstein_domain=False, scale_base_distribution=False, + shift_data=False, + scale_data=False, ) - self.f(normal_dist, trans_dist) + self.f(normal_dist, trans_dist, stay_in_domain=True) @pytest.mark.skip def test_small_numbers(self):