Add digitize op for Discretization layer #641
Conversation
Thank you for the PR! 👍
keras_core/ops/numpy.py (Outdated)

        return backend.numpy.digitize(x, bins)

    def compute_output_spec(self, x, bins):
        return KerasTensor(x.shape, dtype=x.dtype)
Surely the dtype should be int? What is it?
x.dtype can be float or int, so should I return KerasTensor(x.shape, dtype="int32")? Actually I wanted to make the dtype int, but I saw that there are other ops where the input can be float and the output dtype is set to the input dtype, so I thought the output dtype had to match the input.
It's not an open question where you can make a choice. The dtype you pass here should match the dtype that is actually returned when you run the op. What is that dtype?
Ok, I get it now. It's int64 for numpy.
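For reference, this is easy to check directly with numpy, whose digitize returns indices with the platform-dependent np.intp dtype (int64 on most 64-bit systems):

```python
import numpy as np

x = np.array([0.2, 6.4, 3.0, 1.6])
bins = np.array([0.0, 1.0, 2.5, 4.0, 10.0])

indices = np.digitize(x, bins)
# np.digitize returns bin indices with dtype np.intp, which is
# int64 on most 64-bit platforms (narrower on 32-bit builds).
print(indices.dtype)
```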
It must be the same dtype in all backends (if it isn't, that's a bug and we need to cast)
I have used standardize_dtype to test the return dtypes, but it changes the numpy dtype from int64 to int32 because of this line:

    if dtype == "int":

where np.dtype("int64") == "int" returns True, so the body of that condition executes and returns int32 instead of int64, which makes the test fail. I can open a pull request to fix this, if you like: I would just move the if statement that currently sits below the mentioned line above it.
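A minimal sketch of the pattern being described (hypothetical code, not the actual keras_core source): NumPy's lenient dtype equality makes the generic "int" branch swallow int64 before the int64 branch is reached, and reordering the checks avoids that.

```python
import numpy as np

def buggy_standardize(dtype):
    # On platforms whose default int is 64-bit (e.g. 64-bit Linux),
    # np.dtype("int64") == "int" is True, so int64 never reaches
    # its own branch below.
    if dtype == "int":
        return "int32"
    if dtype == "int64":
        return "int64"
    return str(dtype)

def fixed_standardize(dtype):
    # Check the exact-width name first, then the generic alias.
    if dtype == "int64":
        return "int64"
    if dtype == "int":
        return "int32"
    return str(dtype)

print(buggy_standardize(np.dtype("int64")))  # the bug: int64 becomes "int32"
print(fixed_standardize(np.dtype("int64")))
```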
Also, jax doesn't enable int64 until jax_enable_x64 is set to True via jax.config.update("jax_enable_x64", True). I think if we want to unify the dtypes we should enable x64 when the backend is set to jax.
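For context, the opt-in looks like this (config fragment; it must run before any jax arrays are created):

```python
import jax

# jax defaults to 32-bit types; 64-bit dtypes (including int64)
# are only available after setting this flag, before any arrays
# are created.
jax.config.update("jax_enable_x64", True)
```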
Got it. In that case, just pass dtype="int" in compute_output_spec, and in the unit tests check that x.dtype matches backend.standardize_dtype("int").
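The suggestion amounts to something like the following sketch, using a minimal stand-in for KerasTensor (the real class lives in keras_core and has a richer API):

```python
from dataclasses import dataclass

@dataclass
class KerasTensor:  # minimal stand-in for keras_core's KerasTensor
    shape: tuple
    dtype: str

def compute_output_spec(x, bins):
    # Bin indices are integral regardless of x.dtype; declaring the
    # generic "int" lets each backend standardize it to its own
    # default integer width (int32 vs int64).
    return KerasTensor(x.shape, dtype="int")

spec = compute_output_spec(KerasTensor((3, 4), "float32"), bins=[0.0, 1.0, 2.5])
print(spec.shape, spec.dtype)  # (3, 4) int
```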
I did that, but I had to cast the pytorch output to int32 for the test to pass. I think fixing standardize_dtype() and enabling int64 by default for jax would be a better solution. Can you give that a look?
JAX doesn't make that possible.
Thanks for the contribution!
This op will help implement the Discretization layer. The output of the op is unified to match the tensorflow output. There are some things to consider:
1. jax and numpy can digitize when the bins are monotonically decreasing, but torch and tensorflow can't: torch returns undefined output, and tensorflow raises an error.
2. jax, numpy, and torch have a right arg but tensorflow doesn't, so I didn't add it.
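To illustrate points 1 and 2 with numpy (jax.numpy behaves the same way in these cases):

```python
import numpy as np

x = np.array([0.5, 2.5, 7.5])

# 1. numpy (and jax) accept monotonically decreasing bins;
#    torch and tensorflow do not.
dec = np.digitize(x, np.array([10.0, 4.0, 2.5, 1.0, 0.0]))

# 2. `right` controls which side of each bin edge is closed;
#    tensorflow's bucketize has no equivalent flag.
inc = np.array([0.0, 1.0, 2.5, 4.0, 10.0])
left = np.digitize(x, inc, right=False)   # bins[i-1] <= x < bins[i]
right = np.digitize(x, inc, right=True)   # bins[i-1] <  x <= bins[i]

print(dec, left, right)  # note 2.5 lands in a different bin per `right`
```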