Convert neural machine translation with transformer to keras-core #623
Conversation
Thanks for the PR! 👍
examples/keras_io/tensorflow/nlp/neural_machine_translation_with_transformer.py
Thanks for the update. Is this example backend-agnostic (minus the use of tf.data)?
Yes and no. It works with `KERAS_BACKEND=torch`, but not for JAX (because of the `tile` op, I think; I don't know how to fix it).
What's the failure?
In the Colab notebook. Running on my MacBook, there are more error messages.
Are you able to isolate the failure?
I don't know how to do it in JAX. I didn't really use JAX :-) For input shapes:
(force-pushed from 11e75fa to c7ea016)
Here is a small example derived from the example:

```python
import keras_core as keras
from keras_core import ops

foo = keras.Input((None, 256))

class MyLayer(keras.layers.Layer):
    def call(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

f = MyLayer()(foo)
print(f"succeed: f = {f}")
```

Error message
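For reference, the causal mask this snippet builds can be sketched in plain NumPy (a hypothetical stand-in for the `ops` calls, not part of the PR; variable names mirror the layer above):

```python
import numpy as np

batch_size, sequence_length = 2, 4

# Lower-triangular causal mask: position i may attend to positions j <= i.
i = np.arange(sequence_length)[:, None]
j = np.arange(sequence_length)
mask = (i >= j).astype("int32")  # shape (T, T)
mask = mask.reshape(1, sequence_length, sequence_length)

# Tile along the batch dimension, as the layer does with ops.tile.
mult = np.concatenate([[batch_size], [1, 1]], axis=0)  # -> [B, 1, 1]
tiled = np.tile(mask, mult)  # shape (B, T, T)

print(tiled.shape)
print(tiled[0])
```

With concrete NumPy values this is trivially fine; the failure above only appears once `batch_size` becomes a symbolic/traced value.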
If we turn the example to have fixed input shapes:

```python
import keras_core as keras
from keras_core import ops

foo = keras.Input((20, 256), batch_size=64)

class MyLayer(keras.layers.Layer):
    def call(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

f = MyLayer()(foo)
print(f"succeed: f = {f}")
```

It works for JAX.
I looked into it and this is quite blatantly a JAX bug.
As far as I can tell, what's going on is that `tile` in JAX does not support a dynamic `reps` argument. The layer works eagerly but would not work in any symbolic JAX transformation. I guess you could file a JAX bug report...
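To isolate the failure outside of Keras Core, a minimal pure-JAX sketch (a hypothetical reproducer, assuming current `jax.numpy` semantics) would be:

```python
import jax
import jax.numpy as jnp

mask = jnp.ones((1, 4, 4), dtype="int32")
reps = jnp.array([2, 1, 1])

def tile_mask(mask, reps):
    return jnp.tile(mask, reps)

# Eagerly, `reps` holds concrete values, so tiling works.
out = tile_mask(mask, reps)
assert out.shape == (2, 4, 4)

# Under jit, `reps` is an abstract tracer; jnp.tile must convert it
# into concrete Python ints to build the output shape, which raises.
try:
    jax.jit(tile_mask)(mask, reps)
    failed = False
except Exception:  # typically a tracer concretization error
    failed = True
print(failed)
```

This matches the eager-vs-symbolic behavior described above: output shapes in a jit scope must be known at trace time, so a traced `reps` cannot be concretized.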
Do we have an option to disable jit for a specific code block in Keras Core? If I disable jit in `get_causal_attention_mask`, it works:

```python
def get_causal_attention_mask(self, inputs):
    with jax.ensure_compile_time_eval():
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)
```

I updated the jax notebook.
This is not doable in a cross-backend way, unfortunately. Also, it will lead to a considerable drop in performance even if the block is small. I think our best choice here is to leave the example as-is, but add a comment paragraph explaining that JAX is not supported due to the behavior of `jax.numpy.tile` in a jit scope.
```python
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        # due to a bug in jax.tile(), abstract/jit symbols don't work
```
Let's do the following at the start of the example:

```python
# We set the backend to TensorFlow. The code works with
# both `tensorflow` and `torch`. It does not work with JAX
# due to the behavior of `jax.numpy.tile` in a jit scope
# (used in `TransformerDecoder.get_causal_attention_mask()`):
# `tile` in JAX does not support a dynamic `reps` argument.
# You can make the code work in JAX by wrapping the
# inside of the `get_causal_attention_mask` method in
# a `with jax.ensure_compile_time_eval():` context manager
# to prevent jit compilation of that block.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"
```
done
LGTM, thank you
- `import tensorflow as tf`; use `tf_strings` and `tf_data`
- `import keras_core as keras` and replace tf ops with `keras_core.ops` when possible
- `KERAS_BACKEND=torch` works, not for jax though
- diff: 25b60eb
- colab notebook