Skip to content

Commit

Permalink
fix: calculate correct number of singleton dimensions for bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
MArpogaus committed Oct 18, 2024
1 parent 00c83f8 commit 2f2d25e
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/bernstein_flow/math/bernstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# author : Marcel Arpogaus <znepry.necbtnhf@tznvy.pbz>
#
# created : 2024-07-10 10:10:18 (Marcel Arpogaus)
# changed : 2024-07-18 11:57:17 (Marcel Arpogaus)
# changed : 2024-10-18 16:22:18 (Marcel Arpogaus)

# %% License ###################################################################
# Copyright 2024 Marcel Arpogaus
Expand Down Expand Up @@ -182,10 +182,9 @@ def get_bounds(thetas: tf.Tensor) -> tf.Tensor:
x = tf.cast([eps, 1 - eps], dtype=thetas.dtype)

# adds singleton dimensions for batch shape
batch_shape = prefer_static.shape(thetas)[:-1]
batch_rank = prefer_static.rank(batch_shape)
batch_dims = prefer_static.rank(thetas)

shape = [...] + [tf.newaxis for _ in range(batch_rank + 1)]
shape = [...] + [tf.newaxis for _ in range(batch_dims - 1)]
x = x[shape]

return x
Expand Down

0 comments on commit 2f2d25e

Please sign in to comment.