Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add digitize op for Discretization layer #641

Merged
merged 4 commits into from
Aug 1, 2023
Merged

Add digitize op for Discretization layer #641

merged 4 commits into from
Aug 1, 2023

Conversation

abuelnasr0
Copy link
Contributor

@abuelnasr0 abuelnasr0 commented Jul 30, 2023

This op will help to implement Discretization layer.
The output of the op is unified to be like the tensorflow output. and there is somethings to consider:
1- jax and numpy can digitize if the bins are monotonically decreasing but torch and tensorflow can't. torch will return undefined output. And tensorflow will return an error.
2- jax, numpy, and torch has right arg but tensorflow hasn't. so I didn't add it

@abuelnasr0 abuelnasr0 changed the title Add digitize op for digitization layer Add digitize op for Discretization layer Jul 31, 2023
Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR! 👍

keras_core/ops/numpy.py Outdated Show resolved Hide resolved
return backend.numpy.digitize(x, bins)

def compute_output_spec(self, x, bins):
return KerasTensor(x.shape, dtype=x.dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surely the dtype should be int? What is it?

Copy link
Contributor Author

@abuelnasr0 abuelnasr0 Jul 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x.dtype can be float or int, so I must return KerasTensor(x.shape, dtype="int32")?
actually I wanted to make the dtype int, but I saw that there is other ops where the input can be float and the dtype is set to the input dtype. so I thought the dtype of the shape must be the same as the input.

Copy link
Member

@fchollet fchollet Jul 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not an open question where you can make a choice. The dtype you pass here should match the dtype that gets actually returned when you run the op. What is that dtype?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I get it now. it's int64 for numpy.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It must be the same dtype in all backends (if it isn't, that's a bug and we need to cast)

Copy link
Contributor Author

@abuelnasr0 abuelnasr0 Aug 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have used standardize_dtype to test the return dtypes but it changes the dtype of numpy from int64 to int32 because of this line

where:

np.dtype("int64") == "int" ##returns True

so it executes what in the condition and return int32 instead of int64 which causes the fail of the test. I can open a pull request to fix this, if you like. I will just move the if statement below the mentioned line to the top of it.

Copy link
Contributor Author

@abuelnasr0 abuelnasr0 Aug 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also jax doesn't enable int64 until jax_enable_x64 is set to True using jax.config.update("jax_enable_x64", True). I think If we want to unify the dtypes we should enable x64, when the backend is set to jax.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. In this case just pass dtype="int" in compute_output_spec, and in the unit tests check x.dtype to match backend.standardize_dtype("int").

Copy link
Contributor Author

@abuelnasr0 abuelnasr0 Aug 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did that. but I had to cast pytorch output to int32 for the test to work.
I think fixing standrize_dtype() and enabling int64 by default for jax will be a better solution. can you give that a look?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JAX doesn't make that possible.

keras_core/ops/numpy_test.py Show resolved Hide resolved
keras_core/ops/numpy_test.py Show resolved Hide resolved
Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution!

@fchollet fchollet merged commit 710cfdb into keras-team:main Aug 1, 2023
6 checks passed
@abuelnasr0 abuelnasr0 deleted the Digitize-op branch August 3, 2023 14:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants