diff --git a/jax_scaled_arithmetics/lax/scaled_ops_common.py b/jax_scaled_arithmetics/lax/scaled_ops_common.py index 83362d5..8a3a4f4 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops_common.py +++ b/jax_scaled_arithmetics/lax/scaled_ops_common.py @@ -257,7 +257,7 @@ def scaled_sin(val: ScaledArray) -> ScaledArray: def scaled_minmax(prim, lhs, rhs): - print("LHS/RHS:", lhs.scale, rhs.scale) + print("LHS/RHS scale:", lhs.scale, rhs.scale) # print(np.all(is_static_zero(lhs)), np.all(is_static_zero(rhs))) # return scaled_op_default_translation(prim, [lhs, rhs]) @@ -268,6 +268,7 @@ def scaled_minmax(prim, lhs, rhs): raise NotImplementedError("") + @core.register_scaled_lax_op def scaled_min(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray: return scaled_minmax(lax.min_p, lhs, rhs)