[JAX FE] Document Support of JAX/Flax models by OpenVINO
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
rkazants committed Sep 10, 2024
1 parent 94a9675 commit fe89fa5
Showing 4 changed files with 259 additions and 7 deletions.
5 changes: 3 additions & 2 deletions docs/articles_en/openvino-workflow/model-preparation.rst
@@ -22,6 +22,7 @@ OpenVINO supports the following model formats:
* TensorFlow Lite,
* ONNX,
* PaddlePaddle,
* JAX/Flax (experimental feature),
* OpenVINO IR.

The easiest way to obtain a model is to download it from an online database, such as
@@ -61,7 +62,7 @@ The easiest way to obtain a model is to download it from an online database, suc
CLI tool. For more details, see the
:doc:`Model Conversion API Transition Guide <../documentation/legacy-features/transition-legacy-conversion-api>`.

For PyTorch models, `Python API <#convert-a-model-with-python-convert-model>`__ is the only
For PyTorch and JAX/Flax models, `Python API <#convert-a-model-with-python-convert-model>`__ is the only
conversion option.

Different model representations
@@ -280,7 +281,7 @@ formats to
which can then be read, compiled, and run by the final inference application.

.. note::
PyTorch models cannot be converted with ``ovc``, use ``openvino.convert_model`` instead.
PyTorch and JAX/Flax models cannot be converted with ``ovc``; use ``openvino.convert_model`` instead.

Additional Resources
####################
@@ -49,9 +49,9 @@ which are described below.

- ``example_input`` parameter available in Python ``openvino.convert_model`` only is
intended to trace the model to obtain its graph representation. This parameter is crucial
for converting PyTorch models and may sometimes be required for TensorFlow models.
For more details, refer to the :doc:`PyTorch Model Conversion <convert-model-pytorch>`
or :doc:`TensorFlow Model Conversion <convert-model-tensorflow>`.
for converting PyTorch and Flax models and may sometimes be required for TensorFlow models.
For more details, refer to the :doc:`PyTorch Model Conversion <convert-model-pytorch>`,
:doc:`TensorFlow Model Conversion <convert-model-tensorflow>` or :doc:`JAX/Flax Model Conversion <convert-model-jax>`.

- ``input`` parameter to set or override shapes for model inputs. It configures dynamic
and static dimensions in model inputs depending on your inference requirements. For more
@@ -0,0 +1,177 @@
(Experimental) Converting a JAX/Flax Model
==========================================


.. meta::
   :description: Learn how to convert a model from the
                 JAX/Flax format to the OpenVINO Model.


The ``openvino.convert_model`` function supports the following JAX/Flax model object types:

* ``jax._src.core.ClosedJaxpr``
* ``flax.linen.Module``

A ``jax._src.core.ClosedJaxpr`` object is created by tracing a Python function with the ``jax.make_jaxpr`` function.
Here is an example of creating a ``jax._src.core.ClosedJaxpr`` object and converting it to an OpenVINO model:

.. code-block:: py
   :force:

   import jax
   import jax.numpy as jnp
   import openvino as ov

   # define a JAX function
   def jax_func(x, y):
       return jax.lax.tanh(jax.lax.max(x, y))

   # 1. Create ClosedJaxpr object
   x = jnp.array([1.0, 2.0])
   y = jnp.array([-1.0, 10.0])
   jaxpr = jax.make_jaxpr(jax_func)(x, y)

   # 2. Convert to OpenVINO
   ov_model = ov.convert_model(jaxpr)

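
Before converting, it can help to inspect what ``jax.make_jaxpr`` actually recorded. The following is a minimal sketch that uses only JAX and makes no OpenVINO call; the variable names ``closed_jaxpr``, ``primitive_names``, ``input_shapes`` and ``output_shapes`` are illustrative and do not come from the original example.

```python
import jax
import jax.numpy as jnp

# Same function as in the conversion example above.
def jax_func(x, y):
    return jax.lax.tanh(jax.lax.max(x, y))

x = jnp.array([1.0, 2.0])
y = jnp.array([-1.0, 10.0])

# Tracing records the primitives applied to abstract inputs, not concrete values.
closed_jaxpr = jax.make_jaxpr(jax_func)(x, y)

# List the recorded primitives and the traced input/output shapes.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
input_shapes = [aval.shape for aval in closed_jaxpr.in_avals]
output_shapes = [aval.shape for aval in closed_jaxpr.out_avals]

print(closed_jaxpr)
print(primitive_names, input_shapes, output_shapes)
```

The shapes of the example arrays become part of the traced graph, so the example inputs should match the shapes expected at inference time.
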
Here is an example of the simplest ``flax.linen.Module`` model conversion:

.. code-block:: py
   :force:

   import flax.linen as nn
   import jax
   import jax.numpy as jnp
   import openvino as ov

   # define a Flax module
   class SimpleModule(nn.Module):
       features: int

       @nn.compact
       def __call__(self, x):
           return nn.Dense(features=self.features)(x)

   module = SimpleModule(features=4)

   # create example_input as used for training
   example_input = jnp.ones((2, 3))

   # prepare parameters to initialize the module
   # they can also be loaded using pickle or flax.serialization
   key = jax.random.PRNGKey(0)
   params = module.init(key, example_input)
   module = module.bind(params)

   ov_model = ov.convert_model(module, example_input=example_input)

When using ``flax.linen.Module`` as an input model, ``openvino.convert_model`` requires the
``example_input`` parameter to be specified. Internally, it triggers the model tracing during
the model conversion process, using the capabilities of the ``jax.make_jaxpr`` function.

The ``__call__`` method of a ``flax.linen.Module`` object can also have extra custom flags
(like ``training``) in its input signature. In this case, you need to create a helper function
whose input signature contains only the input data, without custom flags or other unrelated parameters.
The example below handles such a case:

.. code-block:: py
   :force:

   import jax
   import jax.numpy as jnp
   import openvino as ov
   from flax import linen as nn

   class SimpleModuleWithExtraFlag(nn.Module):
       features: int

       @nn.compact
       def __call__(self, x, training):
           x = nn.Dense(self.features)(x)
           x = nn.BatchNorm(use_running_average=not training)(x)
           return x

   # 1. Initialize the model
   module = SimpleModuleWithExtraFlag(features=10)
   key = jax.random.PRNGKey(0)
   input_data = jnp.ones((4, 5))  # Batch of 4 samples, each with 5 features
   params = module.init(key, input_data, training=False)

   # 2. Create helper function with only input data parameter
   def helper_function(x):
       return module.apply(params, x, training=False)

   # 3. Trace the helper function
   jaxpr = jax.make_jaxpr(helper_function)(input_data)

   # 4. Convert to OpenVINO
   ov_model = ov.convert_model(jaxpr)

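
The reason a helper function is needed can be seen with plain JAX: a Python flag is resolved while tracing, so each fixed value produces its own graph and the flag never becomes a graph input. The sketch below uses a hypothetical function ``layer`` (not taken from the example above) to illustrate this.

```python
import jax
import jax.numpy as jnp

# Hypothetical function with one data input and one Python-level flag.
def layer(x, training):
    return jax.lax.tanh(x) if training else x

x = jnp.ones((4, 5))

# Fix the flag in helper functions so that only the data input is traced.
def eval_helper(x):
    return layer(x, training=False)

def train_helper(x):
    return layer(x, training=True)

eval_jaxpr = jax.make_jaxpr(eval_helper)(x)
train_jaxpr = jax.make_jaxpr(train_helper)(x)

# The branch taken at trace time determines which primitives are recorded.
eval_ops = [eqn.primitive.name for eqn in eval_jaxpr.jaxpr.eqns]
train_ops = [eqn.primitive.name for eqn in train_jaxpr.jaxpr.eqns]
num_traced_inputs = len(eval_jaxpr.in_avals)
print(eval_ops, train_ops, num_traced_inputs)
```

The flag value chosen in the helper should therefore match the intended inference behavior, since it cannot be changed after conversion.
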
.. note::

   In the examples above, the ``openvino.save_model`` function is not used because there are no
   JAX-specific details regarding the usage of this function. In all examples, the converted
   OpenVINO model can be saved to IR by calling ``ov.save_model(ov_model, 'model.xml')`` as usual.

Exporting a JAX/Flax Model to TensorFlow SavedModel Format
##########################################################

An alternative way to convert a JAX/Flax model is to first export it to TensorFlow SavedModel format
with ``jax.experimental.jax2tf.convert`` and then convert the resulting SavedModel directory to an OpenVINO model
with ``openvino.convert_model``. This can serve as a fallback if a model cannot be
converted directly from JAX/Flax to OpenVINO as described in the sections above.

1. Refer to the `JAX and TensorFlow interoperation <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md>`__
   guide to learn how to export models from JAX to SavedModel format.
2. Follow the :doc:`Convert a TensorFlow model <convert-model-tensorflow>` chapter to produce an OpenVINO model.

Here is an illustration of using these two steps together:

.. code-block:: py
   :force:

   import flax.linen as nn
   import jax
   import jax.experimental.jax2tf as jax2tf
   import jax.numpy as jnp
   import openvino as ov
   import tensorflow as tf

   # define a Flax module
   class SimpleModule(nn.Module):
       features: int

       @nn.compact
       def __call__(self, x):
           return nn.Dense(features=self.features)(x)

   flax_module = SimpleModule(features=4)

   # prepare parameters to initialize the module
   # they can also be loaded using pickle or flax.serialization
   example_input = jnp.ones((2, 3))
   key = jax.random.PRNGKey(0)
   params = flax_module.init(key, example_input)
   module = flax_module.bind(params)

   # 1. Export to SavedModel
   # create TF function from the bound module and wrap it into TF Module
   tf_function = tf.function(jax2tf.convert(module, native_serialization=False), autograph=False,
                             input_signature=[tf.TensorSpec(shape=[2, 3], dtype=tf.float32)])
   tf_module = tf.Module()
   tf_module.f = tf_function
   tf.saved_model.save(tf_module, 'saved_model')

   # 2. Convert to OpenVINO
   ov_model = ov.convert_model('saved_model')

.. note::

   As of JAX version 0.4.15, it is required to pass the ``native_serialization=False`` parameter
   to ``jax2tf.convert`` to use graph serialization mode. Without this option, the created TensorFlow
   function will contain embedded StableHLO modules that are not handled by the OpenVINO TensorFlow Frontend.

@@ -13,6 +13,7 @@ Convert to OpenVINO IR
Convert from ONNX <convert-model-onnx>
Convert from TensorFlow Lite <convert-model-tensorflow-lite>
Convert from PaddlePaddle <convert-model-paddle>
Convert from JAX/Flax <convert-model-jax>



@@ -572,12 +573,85 @@ used by OpenVINO, typically obtained by converting models of supported framework
For details on the conversion, refer to the
:doc:`article <convert-model-paddle>`.

.. tab-item:: JAX/Flax
:sync: torch

.. tab-set::

.. tab-item:: Python
:sync: py

The ``convert_model()`` method is the only method applicable to JAX/Flax models.

.. dropdown:: List of supported formats:

* **Python objects**:

* ``jax._src.core.ClosedJaxpr``
* ``flax.linen.Module``

* Conversion of ``jax._src.core.ClosedJaxpr`` object

.. code-block:: py
   :force:

   import jax
   import jax.numpy as jnp
   import openvino as ov

   # define a JAX function
   def jax_func(x, y):
       return jax.lax.tanh(jax.lax.max(x, y))

   # use example inputs for creation of ClosedJaxpr object
   x = jnp.array([1.0, 2.0])
   y = jnp.array([-1.0, 10.0])
   jaxpr = jax.make_jaxpr(jax_func)(x, y)

   ov_model = ov.convert_model(jaxpr)
   compiled_model = ov.compile_model(ov_model, "AUTO")

* Conversion of ``flax.linen.Module`` object

.. code-block:: py
   :force:

   import flax.linen as nn
   import jax
   import jax.numpy as jnp
   import openvino as ov

   # define a Flax module
   class SimpleDenseModule(nn.Module):
       features: int

       @nn.compact
       def __call__(self, x):
           return nn.Dense(features=self.features)(x)

   module = SimpleDenseModule(features=4)

   # create example_input used in training
   example_input = jnp.ones((2, 3))

   # prepare parameters to initialize the module
   # they can also be loaded from disk
   # using pickle or flax.serialization for deserialization
   key = jax.random.PRNGKey(0)
   params = module.init(key, example_input)
   module = module.bind(params)

   ov_model = ov.convert_model(module, example_input=example_input)
   compiled_model = ov.compile_model(ov_model, "AUTO")

For more details on conversion, refer to the :doc:`guide <convert-model-jax>`.



These are basic examples. For detailed conversion instructions, see the individual guides on
:doc:`PyTorch <convert-model-pytorch>`, :doc:`ONNX <convert-model-onnx>`,
:doc:`TensorFlow <convert-model-tensorflow>`, :doc:`TensorFlow Lite <convert-model-tensorflow-lite>`,
and :doc:`PaddlePaddle <convert-model-paddle>`.
:doc:`PaddlePaddle <convert-model-paddle>` and :doc:`JAX/Flax <convert-model-jax>`.

Refer to the list of all supported conversion options in :doc:`Conversion Parameters <conversion-parameters>`.

@@ -596,7 +670,7 @@ IR Conversion Benefits
especially useful for large models, like Llama2-7B.
| **Saving to IR to avoid large dependencies in inference code**
| Frameworks such as TensorFlow and PyTorch tend to be large dependencies for applications
| Frameworks such as TensorFlow, PyTorch and JAX/Flax tend to be large dependencies for applications
running inference (multiple gigabytes). Converting models to OpenVINO IR removes this
dependency, as OpenVINO can run its inference with no additional components.
This way, much less disk space is needed, while loading and compiling usually takes less