Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Dec 4, 2023
1 parent 85b5cc4 commit b6b1936
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
Note: needs to work with any input type, including JAX tracer ones.
"""
print("PROMOTE:", val, type(val))
# int / float special cases
if isinstance(val, float):
return ScaledArray(data=np.array(1, dtype=np.float32), scale=np.float32(val))
Expand Down
4 changes: 2 additions & 2 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

def check_scalar_scales(*args: ScaledArray):
"""Check all ScaledArrays have scalar scaling."""
print(args)
# print(args)
for val in args:
assert np.ndim(val.scale) == 0

Expand Down Expand Up @@ -109,7 +109,6 @@ def scaled_rev(val: ScaledArray, dimensions: Sequence[int]) -> ScaledArray:
@core.register_scaled_lax_op
def scaled_pad(val: ScaledArray, padding_value: Any, padding_config: Any) -> ScaledArray:
# Only supporting constant zero padding for now.
print(padding_value)
assert np.all(is_static_zero(padding_value))
# assert float(padding_value) == 0.0
return ScaledArray(lax.pad(val.data, np.array(0, val.dtype), padding_config), val.scale)
Expand Down Expand Up @@ -157,6 +156,7 @@ def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray:
A, B = as_scaled_array((A, B)) # type:ignore
check_scalar_scales(A, B)
A, B = promote_scale_types(A, B)
# print(A, B)
assert np.issubdtype(A.scale.dtype, np.floating)
# TODO: what happens to `sqrt` for non-floating scale?
output_scale = lax.sqrt(A.scale**2 + B.scale**2)
Expand Down

0 comments on commit b6b1936

Please sign in to comment.