From fe89fa54e3ce99b64439b34d58923e2647879104 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 10 Sep 2024 11:51:26 +0400 Subject: [PATCH] [JAX FE] Document Support of JAX/Flax models by OpenVINO Signed-off-by: Kazantsev, Roman --- .../openvino-workflow/model-preparation.rst | 5 +- .../conversion-parameters.rst | 6 +- .../model-preparation/convert-model-jax.rst | 177 ++++++++++++++++++ .../model-preparation/convert-model-to-ir.rst | 78 +++++++- 4 files changed, 259 insertions(+), 7 deletions(-) create mode 100644 docs/articles_en/openvino-workflow/model-preparation/convert-model-jax.rst diff --git a/docs/articles_en/openvino-workflow/model-preparation.rst b/docs/articles_en/openvino-workflow/model-preparation.rst index bea0fcdba5311b..c23540874e9b7a 100644 --- a/docs/articles_en/openvino-workflow/model-preparation.rst +++ b/docs/articles_en/openvino-workflow/model-preparation.rst @@ -22,6 +22,7 @@ OpenVINO supports the following model formats: * TensorFlow Lite, * ONNX, * PaddlePaddle, +* JAX/Flax (experimental feature), * OpenVINO IR. The easiest way to obtain a model is to download it from an online database, such as @@ -61,7 +62,7 @@ The easiest way to obtain a model is to download it from an online database, suc CLI tool. For more details, see the :doc:`Model Conversion API Transition Guide <../documentation/legacy-features/transition-legacy-conversion-api>`. - For PyTorch models, `Python API <#convert-a-model-with-python-convert-model>`__ is the only + For PyTorch and JAX/Flax models, `Python API <#convert-a-model-with-python-convert-model>`__ is the only conversion option. Different model representations @@ -280,7 +281,7 @@ formats to which can then be read, compiled, and run by the final inference application. .. note:: - PyTorch models cannot be converted with ``ovc``, use ``openvino.convert_model`` instead. + PyTorch and JAX/Flax models cannot be converted with ``ovc``; use ``openvino.convert_model`` instead.
Additional Resources #################### diff --git a/docs/articles_en/openvino-workflow/model-preparation/conversion-parameters.rst b/docs/articles_en/openvino-workflow/model-preparation/conversion-parameters.rst index 6a7023fc16afc3..74847bf9f1f884 100644 --- a/docs/articles_en/openvino-workflow/model-preparation/conversion-parameters.rst +++ b/docs/articles_en/openvino-workflow/model-preparation/conversion-parameters.rst @@ -49,9 +49,9 @@ which are described below. - ``example_input`` parameter available in Python ``openvino.convert_model`` only is intended to trace the model to obtain its graph representation. This parameter is crucial - for converting PyTorch models and may sometimes be required for TensorFlow models. - For more details, refer to the :doc:`PyTorch Model Conversion <convert-model-pytorch>` - or :doc:`TensorFlow Model Conversion <convert-model-tensorflow>`. + for converting PyTorch and JAX/Flax models and may sometimes be required for TensorFlow models. + For more details, refer to the :doc:`PyTorch Model Conversion <convert-model-pytorch>`, + :doc:`TensorFlow Model Conversion <convert-model-tensorflow>` or :doc:`JAX/Flax Model Conversion <convert-model-jax>`. - ``input`` parameter to set or override shapes for model inputs. It configures dynamic and static dimensions in model inputs depending on your inference requirements. For more diff --git a/docs/articles_en/openvino-workflow/model-preparation/convert-model-jax.rst b/docs/articles_en/openvino-workflow/model-preparation/convert-model-jax.rst new file mode 100644 index 00000000000000..f18f5147fc52a7 --- /dev/null +++ b/docs/articles_en/openvino-workflow/model-preparation/convert-model-jax.rst @@ -0,0 +1,177 @@ +(Experimental) Converting a JAX/Flax Model +========================================== + + +.. meta:: + :description: Learn how to convert a model from the + JAX/Flax format to the OpenVINO Model.
+ + +The ``openvino.convert_model`` function supports the following JAX/Flax model object types: + +* ``jax._src.core.ClosedJaxpr`` +* ``flax.linen.Module`` + +A ``jax._src.core.ClosedJaxpr`` object is created by tracing a Python function using the ``jax.make_jaxpr`` function. +Here is an example of ``jax._src.core.ClosedJaxpr`` object creation and conversion to an OpenVINO model: + +.. code-block:: py + :force: + + import jax + import jax.numpy as jnp + import openvino as ov + + # let us have some JAX function + def jax_func(x, y): + return jax.lax.tanh(jax.lax.max(x, y)) + + # 1. Create ClosedJaxpr object + x = jnp.array([1.0, 2.0]) + y = jnp.array([-1.0, 10.0]) + jaxpr = jax.make_jaxpr(jax_func)(x, y) + # 2. Convert to OpenVINO + ov_model = ov.convert_model(jaxpr) + + +Here is an example of the simplest ``flax.linen.Module`` model conversion: + +.. code-block:: py + :force: + + import flax.linen as nn + import jax + import jax.numpy as jnp + import openvino as ov + + # let user have some Flax module + class SimpleModule(nn.Module): + features: int + + @nn.compact + def __call__(self, x): + return nn.Dense(features=self.features)(x) + + module = SimpleModule(features=4) + + # create example_input used for training + example_input = jnp.ones((2, 3)) + + # prepare parameters to initialize the module + # they can be also loaded using pickle, flax.serialization + key = jax.random.PRNGKey(0) + params = module.init(key, example_input) + module = module.bind(params) + + ov_model = ov.convert_model(module, example_input=example_input) + + +When using ``flax.linen.Module`` as an input model, ``openvino.convert_model`` requires the +``example_input`` parameter to be specified. Internally, it triggers the model tracing during +the model conversion process, using the capabilities of the ``jax.make_jaxpr`` function. + +The ``__call__`` method of a ``flax.linen.Module`` object can also have extra custom flags +(like ``training``) in the input signature.
In this case, it is required to create a helper function +which has an input signature without any extra custom flags and parameters not related to input data. +Here is an example of handling such a case: + +.. code-block:: py + :force: + + import jax + import jax.numpy as jnp + import openvino as ov + from flax import linen as nn + from flax.core import freeze, unfreeze + + class SimpleModuleWithExtraFlag(nn.Module): + features: int + + @nn.compact + def __call__(self, x, training): + x = nn.Dense(self.features)(x) + x = nn.BatchNorm(use_running_average=not training)(x) + return x + + + # 1. Initialize the model + module = SimpleModuleWithExtraFlag(features=10) + key = jax.random.PRNGKey(0) + input_data = jnp.ones((4, 5)) # Batch of 4 samples, each with 5 features + params = module.init(key, input_data, training=False) + + # 2. Create helper function with only input data parameter + def helper_function(x): + return module.apply(params, x, training=False) + + # 3. Trace the helper function + jaxpr = jax.make_jaxpr(helper_function)(input_data) + + # 4. Convert to OpenVINO + ov_model = ov.convert_model(jaxpr) + + +.. note:: + + In the examples above the ``openvino.save_model`` function is not used because there are no + JAX-specific details regarding the usage of this function. In all examples, the converted + OpenVINO model can be saved to IR by calling ``ov.save_model(ov_model, 'model.xml')`` as usual. + +Exporting a JAX/Flax Model to TensorFlow SavedModel Format +########################################################## + +An alternative method of converting JAX/Flax models is exporting a JAX/Flax model to TensorFlow SavedModel format +with ``jax.experimental.jax2tf.convert`` first and then converting the resulting SavedModel directory to OpenVINO Model +with ``openvino.convert_model``. It can be considered as a backup solution if a model cannot be +converted directly from JAX/Flax to OpenVINO as described in the above chapters. + +1.
Refer to the `JAX and TensorFlow interoperation <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md>`__ + guide to learn how to export models from JAX to SavedModel format. +2. Follow the :doc:`Convert a TensorFlow model <convert-model-tensorflow>` chapter to produce an OpenVINO model. + +Here is an illustration of using these two steps together: + +.. code-block:: py + :force: + + import flax.linen as nn + import jax + import jax.experimental.jax2tf as jax2tf + import jax.numpy as jnp + import openvino as ov + import openvino as ov + import tensorflow as tf + + # let user have some Flax module + class SimpleModule(nn.Module): + features: int + + @nn.compact + def __call__(self, x): + return nn.Dense(features=self.features)(x) + + flax_module = SimpleModule(features=4) + + # prepare parameters to initialize the module + # they can be also loaded using pickle, flax.serialization + example_input = jnp.ones((2, 3)) + key = jax.random.PRNGKey(0) + params = flax_module.init(key, example_input) + module = flax_module.bind(params) + + # 1. Export to SavedModel + # create TF function and wrap it into TF Module + tf_function = tf.function(jax2tf.convert(flax_module, native_serialization=False), autograph=False, + input_signature=[tf.TensorSpec(shape=[2, 3], dtype=tf.float32)]) + tf_module = tf.Module() + tf_module.f = tf_function + tf.saved_model.save(tf_module, 'saved_model') + + # 2. Convert to OpenVINO + ov_model = ov.convert_model('saved_model') + +.. note:: + + As of version 0.4.15, it is required to pass ``native_serialization=False`` parameter + into ``jax2tf.convert`` for graph serialization mode. Without this option, the created TensorFlow + function will contain the embedded StableHLO modules that are not handled by OpenVINO TensorFlow Frontend.
diff --git a/docs/articles_en/openvino-workflow/model-preparation/convert-model-to-ir.rst b/docs/articles_en/openvino-workflow/model-preparation/convert-model-to-ir.rst index 171422f932ea5b..dcaafd24033746 100644 --- a/docs/articles_en/openvino-workflow/model-preparation/convert-model-to-ir.rst +++ b/docs/articles_en/openvino-workflow/model-preparation/convert-model-to-ir.rst @@ -13,6 +13,7 @@ Convert to OpenVINO IR Convert from ONNX <convert-model-onnx> Convert from TensorFlow Lite <convert-model-tensorflow-lite> Convert from PaddlePaddle <convert-model-paddle> + Convert from JAX/Flax <convert-model-jax> @@ -572,12 +573,85 @@ used by OpenVINO, typically obtained by converting models of supported framework For details on the conversion, refer to the :doc:`article <convert-model-paddle>`. + .. tab-item:: JAX/Flax + :sync: jax + + .. tab-set:: + + .. tab-item:: Python + :sync: py + + The ``convert_model()`` method is the only method applicable to JAX/Flax models. + + .. dropdown:: List of supported formats: + + * **Python objects**: + + * ``jax._src.core.ClosedJaxpr`` + * ``flax.linen.Module`` + + * Conversion of ``jax._src.core.ClosedJaxpr`` object + + .. code-block:: py + :force: + + import jax + import jax.numpy as jnp + import openvino as ov + + # let user have some JAX function + def jax_func(x, y): + return jax.lax.tanh(jax.lax.max(x, y)) + + # use example inputs for creation of ClosedJaxpr object + x = jnp.array([1.0, 2.0]) + y = jnp.array([-1.0, 10.0]) + jaxpr = jax.make_jaxpr(jax_func)(x, y) + + ov_model = ov.convert_model(jaxpr) + compiled_model = ov.compile_model(ov_model, "AUTO") + + * Conversion of ``flax.linen.Module`` object + + ..
code-block:: py + :force: + + import flax.linen as nn + import jax + import jax.numpy as jnp + import openvino as ov + + # let user have some Flax module + class SimpleDenseModule(nn.Module): + features: int + + @nn.compact + def __call__(self, x): + return nn.Dense(features=self.features)(x) + + module = SimpleDenseModule(features=4) + + # create example_input used in training + example_input = jnp.ones((2, 3)) + + # prepare parameters to initialize the module + # they can be also loaded from a disk + # using pickle, flax.serialization for deserialization + key = jax.random.PRNGKey(0) + params = module.init(key, example_input) + module = module.bind(params) + + ov_model = ov.convert_model(module, example_input=example_input) + compiled_model = ov.compile_model(ov_model, "AUTO") + + For more details on conversion, refer to the :doc:`guide <convert-model-jax>`. + These are basic examples, for detailed conversion instructions, see the individual guides on :doc:`PyTorch <convert-model-pytorch>`, :doc:`ONNX <convert-model-onnx>`, :doc:`TensorFlow <convert-model-tensorflow>`, :doc:`TensorFlow Lite <convert-model-tensorflow-lite>`, -and :doc:`PaddlePaddle <convert-model-paddle>`. +:doc:`PaddlePaddle <convert-model-paddle>` and :doc:`JAX/Flax <convert-model-jax>`. Refer to the list of all supported conversion options in :doc:`Conversion Parameters <conversion-parameters>`. @@ -596,7 +670,7 @@ IR Conversion Benefits especially useful for large models, like Llama2-7B. | **Saving to IR to avoid large dependencies in inference code** -| Frameworks such as TensorFlow and PyTorch tend to be large dependencies for applications +| Frameworks such as TensorFlow, PyTorch and JAX/Flax tend to be large dependencies for applications running inference (multiple gigabytes). Converting models to OpenVINO IR removes this dependency, as OpenVINO can run its inference with no additional components. This way, much less disk space is needed, while loading and compiling usually takes less