Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jan 4, 2024
1 parent 6015f49 commit d96840c
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax_scaled_arithmetics/lax/scaled_ops_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def scaled_sin(val: ScaledArray) -> ScaledArray:


def scaled_minmax(prim, lhs, rhs):
print("LHS/RHS:", lhs.scale, rhs.scale)
print("LHS/RHS scale:", lhs.scale, rhs.scale)
# print(np.all(is_static_zero(lhs)), np.all(is_static_zero(rhs)))
# return scaled_op_default_translation(prim, [lhs, rhs])

Expand All @@ -268,6 +268,7 @@ def scaled_minmax(prim, lhs, rhs):

raise NotImplementedError("")


@core.register_scaled_lax_op
def scaled_min(lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
return scaled_minmax(lax.min_p, lhs, rhs)
Expand Down

0 comments on commit d96840c

Please sign in to comment.