From 2f2d25e82bbfa48edf43c8efe7d1f444867e6bbf Mon Sep 17 00:00:00 2001 From: Marcel Arpogaus <38564291+MArpogaus@users.noreply.github.com> Date: Fri, 18 Oct 2024 16:26:40 +0200 Subject: [PATCH] fix: calculate correct number of singleton dimensions for bounds --- src/bernstein_flow/math/bernstein.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/bernstein_flow/math/bernstein.py b/src/bernstein_flow/math/bernstein.py index cff6d19..f47e541 100644 --- a/src/bernstein_flow/math/bernstein.py +++ b/src/bernstein_flow/math/bernstein.py @@ -4,7 +4,7 @@ # author : Marcel Arpogaus # # created : 2024-07-10 10:10:18 (Marcel Arpogaus) -# changed : 2024-07-18 11:57:17 (Marcel Arpogaus) +# changed : 2024-10-18 16:22:18 (Marcel Arpogaus) # %% License ################################################################### # Copyright 2024 Marcel Arpogaus @@ -182,10 +182,9 @@ def get_bounds(thetas: tf.Tensor) -> tf.Tensor: x = tf.cast([eps, 1 - eps], dtype=thetas.dtype) # adds singleton dimensions for batch shape - batch_shape = prefer_static.shape(thetas)[:-1] - batch_rank = prefer_static.rank(batch_shape) + batch_dims = prefer_static.rank(thetas) - shape = [...] + [tf.newaxis for _ in range(batch_rank + 1)] + shape = [...] + [tf.newaxis for _ in range(batch_dims - 1)] x = x[shape] return x