
Convert neural machine translation with transformer to keras-core #623

Conversation

freedomtan
Contributor

  1. Remove `import tensorflow as tf`; use `tf_strings` and `tf_data` instead.
  2. `import keras_core as keras` and replace tf ops with `keras_core.ops` where possible.
  3. It actually works with `KERAS_BACKEND=torch`, though not with `jax`.

diff: 25b60eb
colab notebook

Member

@fchollet fchollet left a comment

Thanks for the PR! 👍

Member

@fchollet fchollet left a comment

Thanks for the update. Is this example backend-agnostic (minus the use of tf.data)?

@freedomtan freedomtan closed this Jul 27, 2023
@freedomtan
Contributor Author

freedomtan commented Jul 27, 2023

Thanks for the update. Is this example backend-agnostic (minus the use of tf.data)?

Yes and no. It works with `KERAS_BACKEND=torch` but not with `jax` (because of the `tile` op, I think; I don't know how to fix it).

@freedomtan freedomtan reopened this Jul 27, 2023
@fchollet
Member

Yes and no. It works with `KERAS_BACKEND=torch` but not with `jax` (because of the `tile` op, I think; I don't know how to fix it).

What's the failure?

@freedomtan
Contributor Author

Yes and no. It works with `KERAS_BACKEND=torch` but not with `jax` (because of the `tile` op, I think; I don't know how to fix it).

What's the failure?

See the Colab notebook.

Running on my MacBook, there are more error messages:

.....
Exception ignored in: <function AtomicFunction.__del__ at 0x179629ee0>
Traceback (most recent call last):
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py", line 218, in __del__
TypeError: 'NoneType' object is not subscriptable
Exception ignored in: <function AtomicFunction.__del__ at 0x179629ee0>
Traceback (most recent call last):
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py", line 218, in __del__
TypeError: 'NoneType' object is not subscriptable

@fchollet
Member

Are you able to isolate the failure with the tile op? E.g. calling the tile op in JAX with arrays with the same shape as here? Are the inputs shapes (of the op) the same in the other backends?

@freedomtan
Contributor Author

Are you able to isolate the failure with the tile op? E.g. calling the tile op in JAX with arrays with the same shape as here? Are the inputs shapes (of the op) the same in the other backends?

I don't know how to do it in JAX; I haven't really used JAX. :-)

For input shapes:

  • TensorFlow:
    • mask: [1, None, None]
    • mult: [3]
  • JAX:
    • mask: (1, 83, 83)
    • mult: (3,)
  • torch:
    • mask: [1, 83, 83]
    • mult: [3]

@freedomtan freedomtan force-pushed the convert_neural_machine_translation_with_transformer branch from 11e75fa to c7ea016 Compare August 1, 2023 04:30
@freedomtan
Contributor Author

freedomtan commented Aug 1, 2023

Here is a small example derived from `TransformerDecoder.get_causal_attention_mask()`.
It works with `KERAS_BACKEND` set to `tensorflow`, `torch`, or `numpy`, but fails with `jax`.

import keras_core as keras
from keras_core import ops

foo = keras.Input((None, 256))

class MyLayer(keras.layers.Layer):
    def call(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

f = MyLayer()(foo)
print(f"succeed: f = {f}")

Error message

Using JAX backend.
Traceback (most recent call last):
  File "/private/tmp/foobar.py", line 20, in <module>
    f = MyLayer()(foo)
  File "/Users/freedom/work/keras-core/keras_core/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/freedom/work/keras-core/keras_core/backend/jax/core.py", line 208, in compute_output_spec
    _, jax_out = jax.make_jaxpr(wrapped_fn, return_shape=True)(
RuntimeError: Exception encountered when calling MyLayer.call().

Could not automatically infer the output shape / dtype of 'my_layer' (of type MyLayer). Either the `MyLayer.call()` method is incorrect, or you need to implement the `MyLayer.compute_output_spec() / compute_output_shape()` method. Error encountered:

Shapes must be 1D sequences of concrete values of integer type, got (None, None, 256).

Arguments received by MyLayer.call():
  • args=('<KerasTensor shape=(None, None, 256), dtype=float32, name=keras_tensor>',)
  • kwargs=<class 'inspect._empty'>
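As an aside, the mask construction itself is straightforward once shapes are concrete. A plain NumPy sketch (my own illustration, with made-up small shapes) of what `get_causal_attention_mask` is meant to produce:

```python
import numpy as np

# Hypothetical concrete shapes, to show the intended result: a
# lower-triangular causal mask tiled across the batch dimension.
batch_size, seq_len = 2, 4
i = np.arange(seq_len)[:, None]
j = np.arange(seq_len)
mask = (i >= j).astype("int32")            # (4, 4) lower-triangular
mask = mask.reshape(1, seq_len, seq_len)   # (1, 4, 4)
mult = np.concatenate([np.expand_dims(batch_size, -1), np.array([1, 1])])
out = np.tile(mask, mult)                  # (2, 4, 4)
print(out.shape)
print(out[0])
```

The failure above is purely about tracing these shapes symbolically, not about the math: every position attends only to itself and earlier positions.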

@freedomtan
Contributor Author

If we change the example to use fixed input shapes:

import keras_core as keras
from keras_core import ops

foo = keras.Input((20, 256), batch_size=64)

class MyLayer(keras.layers.Layer):
    def call(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

f = MyLayer()(foo)
print(f"succeed: f = {f}")

It works for `KERAS_BACKEND` in `{'tensorflow', 'torch', 'numpy'}`; for `'jax'`, the error messages are:

Using JAX backend.
Traceback (most recent call last):
  File "/tmp/foo.py", line 20, in <module>
    f = MyLayer()(foo)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/keras_core/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmp/foo.py", line 18, in call
    return ops.tile(mask, mult)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1796, in tile
    result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1173, in broadcast_to
    return util._broadcast_to(array, shape)
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 388, in _broadcast_to
    if not isinstance(shape, tuple) and np.ndim(shape) == 0:
  File "<__array_function__ internals>", line 180, in ndim
  File "/Users/freedom/tf-master/lib/python3.9/site-packages/numpy/core/fromnumeric.py", line 3156, in ndim
    return asarray(a).ndim
RuntimeError: Exception encountered when calling MyLayer.call().

Could not automatically infer the output shape / dtype of 'my_layer' (of type MyLayer). Either the `MyLayer.call()` method is incorrect, or you need to implement the `MyLayer.compute_output_spec() / compute_output_shape()` method. Error encountered:

The numpy.ndarray conversion method __array__() was called on traced array with shape int32[].
The error occurred while tracing the function wrapped_fn at /Users/freedom/tf-master/lib/python3.9/site-packages/keras_core/src/backend/jax/core.py:138 for make_jaxpr. This value became a tracer due to JAX operations on these lines:

  operation a:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] b
    from line /Users/freedom/tf-master/lib/python3.9/site-packages/keras_core/src/backend/jax/numpy.py:227 (expand_dims)

  operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /Users/freedom/tf-master/lib/python3.9/site-packages/keras_core/src/backend/jax/core.py:40 (convert_to_tensor)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Arguments received by MyLayer.call():
  • args=('<KerasTensor shape=(64, 20, 256), dtype=float32, name=keras_tensor>',)
  • kwargs=<class 'inspect._empty'>

@fchollet fchollet added the bug label Aug 5, 2023
@fchollet
Member

fchollet commented Aug 5, 2023

I looked into it and this is quite blatantly a JAX bug.

    return backend.numpy.tile(x, repeats)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/keras_core/src/backend/jax/numpy.py", line 523, in tile
    return jnp.tile(x, repeats)
  File "/Users/fchollet/Library/Python/3.10/lib/python/site-packages/jax/_src/numpy/lax_numpy.py", line 1796, in tile
    result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
  File "/Users/fchollet/Library/Python/3.10/lib/python/site-packages/jax/_src/numpy/lax_numpy.py", line 1173, in broadcast_to
    return util._broadcast_to(array, shape)
  File "/Users/fchollet/Library/Python/3.10/lib/python/site-packages/jax/_src/numpy/util.py", line 388, in _broadcast_to
    if not isinstance(shape, tuple) and np.ndim(shape) == 0:
  File "<__array_function__ internals>", line 200, in ndim
  File "/Users/fchollet/Library/Python/3.10/lib/python/site-packages/numpy/core/fromnumeric.py", line 3187, in ndim
    return asarray(a).ndim

As far as I can tell, what's going on is that `jnp.tile()` does not expect to be called with a `reps` argument that is not static.

The layer works eagerly but would not work in any symbolic JAX transformation. I guess you could file a JAX bug report...
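For reference, a minimal pure-JAX repro (my own construction, mirroring the shapes discussed above) might look like this: `jnp.tile` accepts a concrete array `reps` eagerly, but fails when `reps` is built from a traced value.

```python
import jax
import jax.numpy as jnp

mask = jnp.ones((1, 83, 83), dtype=jnp.int32)

# Eagerly, a concrete array reps is fine:
out = jnp.tile(mask, jnp.array([64, 1, 1]))

# Under tracing, a reps derived from a traced scalar cannot be made
# concrete, and jnp.tile raises (a TracerArrayConversionError in my runs):
def f(mask, batch_size):
    mult = jnp.concatenate(
        [jnp.expand_dims(batch_size, -1), jnp.array([1, 1])], axis=0
    )
    return jnp.tile(mask, mult)

failed = False
try:
    jax.make_jaxpr(f)(mask, jnp.asarray(64))
except Exception:
    failed = True
print(out.shape, failed)
```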

@freedomtan
Contributor Author

Do we have an option to disable jit for a specific code block in Keras Core? If I disable jit in `get_causal_attention_mask()` with `jax.ensure_compile_time_eval()`, the JAX backend also works.

def get_causal_attention_mask(self, inputs):
    with jax.ensure_compile_time_eval():
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)
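A stripped-down sketch of the same workaround under plain `jax.jit` (my own illustration; note that under `jax.jit` input shapes are static anyway — the context manager matters in Keras Core's symbolic tracing, where `batch_size` becomes a tracer):

```python
import jax
import jax.numpy as jnp

@jax.jit
def causal_mask(x):
    # Inside this block, jnp ops run eagerly at trace time, so
    # mask and mult are concrete arrays rather than tracers.
    with jax.ensure_compile_time_eval():
        batch, seq = x.shape[0], x.shape[1]
        i = jnp.arange(seq)[:, None]
        j = jnp.arange(seq)
        mask = (i >= j).astype(jnp.int32).reshape(1, seq, seq)
        mult = jnp.concatenate(
            [jnp.asarray([batch]), jnp.array([1, 1])], axis=0
        )
        return jnp.tile(mask, mult)

out = causal_mask(jnp.ones((4, 5, 8)))
print(out.shape)  # (4, 5, 5)
```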

I updated the JAX notebook.

@fchollet
Member

fchollet commented Aug 7, 2023

Do we have an option to disable jit for a specific code block in Keras Core?

This is not doable in a cross-backend way, unfortunately. Also, it will lead to a considerable drop in performance even if the block is small.

I think our best choice here is to leave the example as-is, but add a comment paragraph explaining that JAX is not supported due to the behavior of ops.tile in JAX when tracing, and suggesting your workaround (ensure_compile_time_eval).

return self.layernorm_3(out_2 + proj_output)

def get_causal_attention_mask(self, inputs):
# due to a bug in jnp.tile(), abstract/jit symbols don't work
Member

@fchollet fchollet Aug 7, 2023

Let's do the following at the start of the example:

# We set the backend to TensorFlow. The code works with
# both `tensorflow` and `torch`. It does not work with JAX
# due to the behavior of `jax.numpy.tile` in a jit scope
# (used in `TransformerDecoder.get_causal_attention_mask()`):
# `tile` in JAX does not support a dynamic `reps` argument.
# You can make the code work in JAX by wrapping the
# inside of the `get_causal_attention_mask()` method in
# a `with jax.ensure_compile_time_eval():` block to
# prevent jit compilation of that code.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

Contributor Author

done

Member

@fchollet fchollet left a comment

LGTM, thank you

@fchollet fchollet merged commit 65f58cc into keras-team:main Aug 8, 2023
6 checks passed