Skip to content

Commit

Permalink
added scaled_pow
Browse files Browse the repository at this point in the history
  • Loading branch information
samho committed Nov 9, 2023
1 parent 212c5d1 commit 7b7f227
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def scaled_cbrt(A: ScaledArray) -> ScaledArray:
pass

def scaled_div(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(A.data * B.data, A.scale * B.scale)
return ScaledArray(A.data / B.data, A.scale / B.scale)

def scaled_dot(A: ScaledArray, B: ScaledArray) -> ScaledArray:
pass
Expand All @@ -48,7 +48,9 @@ def scaled_min(A: ScaledArray, B: ScaledArray) -> ScaledArray:
pass

def scaled_pow(A: ScaledArray, B: ScaledArray) -> ScaledArray:
pass
output_data = lax.pow(lax.pow(A.data, B.data), B.scale) #I think doing in this order means we won't overflow
output_scale = lax.pow(lax.pow(A.scale, B.data), B.scale)
return ScaledArray(output_data, output_scale)

def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray:
output_scale = lax.sqrt(A.scale ** 2 + B.scale ** 2)
Expand Down

0 comments on commit 7b7f227

Please sign in to comment.