diff --git a/test/distributions/test_bernstein_flow.py b/test/distributions/test_bernstein_flow.py index 8db69ef..11a50df 100644 --- a/test/distributions/test_bernstein_flow.py +++ b/test/distributions/test_bernstein_flow.py @@ -228,9 +228,11 @@ def test_student_t(self): dtype=dtype, base_distribution=student_t, thetas_constrain_fn=get_thetas_constrain_fn( - low=-25, high=25, allow_flexible_bounds=True + low=-35, high=35, allow_flexible_bounds=True ), scale_base_distribution=False, + clip_to_bernstein_domain=False, + bb_class=BernsteinBijectorLinearExtrapolate, ) self.f(normal_dist, trans_dist) @@ -243,10 +245,14 @@ def test_weibull(self): order=10, dtype=dtype, base_distribution=weibull, - thetas_constrain_fn=get_thetas_constrain_fn(low=1e-10, high=50), + thetas_constrain_fn=get_thetas_constrain_fn(low=1e-12, high=100), + bb_class=BernsteinBijectorLinearExtrapolate, + clip_to_bernstein_domain=False, scale_base_distribution=False, + shift_data=False, + scale_data=False, ) - self.f(normal_dist, trans_dist) + self.f(normal_dist, trans_dist, stay_in_domain=True) @pytest.mark.skip def test_small_numbers(self):