diff --git a/src/bindings/python/src/openvino/frontend/jax/utils.py b/src/bindings/python/src/openvino/frontend/jax/utils.py
index 936c86c2e28b9e..aa57c6f6836424 100644
--- a/src/bindings/python/src/openvino/frontend/jax/utils.py
+++ b/src/bindings/python/src/openvino/frontend/jax/utils.py
@@ -13,7 +13,7 @@
 numpy_to_ov_type_map = {
     np.float32: OVType.f32,
     bool: OVType.boolean,
-    jax.dtypes.bfloat16: OVType.bf16, # TODO: check this
+    jax.dtypes.bfloat16: OVType.bf16,  # TODO: check this
     np.float16: OVType.f16,
     np.float32: OVType.f32,
     np.float64: OVType.f64,
@@ -29,7 +29,7 @@
 
 jax_to_ov_type_map = {
     jnp.float32: OVType.f32,
-    jnp.bfloat16: OVType.bf16, # TODO: check this
+    jnp.bfloat16: OVType.bf16,  # TODO: check this
     jnp.float16: OVType.f16,
     jnp.float64: OVType.f64,
     jnp.uint8: OVType.u8,
@@ -62,6 +62,9 @@
     OVType.f16: 5,
     OVType.f32: 6,
     OVType.f64: 7,
+    OVType.u16: 8,
+    OVType.u32: 9,
+    OVType.u64: 10,
     OVType.boolean: 11,
     OVType.bf16: 15,
 }
diff --git a/src/frontends/jax/src/op/concatenate.cpp b/src/frontends/jax/src/op/concatenate.cpp
new file mode 100644
index 00000000000000..762bd6ce1024d8
--- /dev/null
+++ b/src/frontends/jax/src/op/concatenate.cpp
@@ -0,0 +1,35 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/jax/node_context.hpp"
+#include "openvino/op/concat.hpp"
+#include "utils.hpp"
+
+using namespace std;
+using namespace ov;
+using namespace ov::op;
+
+namespace ov {
+namespace frontend {
+namespace jax {
+namespace op {
+
+OutputVector translate_concatenate(const NodeContext& context) {
+    num_inputs_check(context, 1);
+    auto num_inputs = static_cast<int>(context.get_input_size());
+    int64_t axis = context.const_named_param<int64_t>("dimension");
+
+    OutputVector inputs(num_inputs);
+    for (int ind = 0; ind < num_inputs; ++ind) {
+        inputs[ind] = context.get_input(ind);
+    }
+    Output<Node> res = make_shared<v0::Concat>(inputs, axis);
+
+    return {res};
+};
+
+}  // namespace op
+}  // namespace jax
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/jax/src/op/integer_pow.cpp b/src/frontends/jax/src/op/integer_pow.cpp
new file mode 100644
index 00000000000000..059d3fe94ef30f
--- /dev/null
+++ b/src/frontends/jax/src/op/integer_pow.cpp
@@ -0,0 +1,32 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/jax/node_context.hpp"
+#include "openvino/op/power.hpp"
+#include "utils.hpp"
+
+using namespace std;
+using namespace ov;
+using namespace ov::op;
+
+namespace ov {
+namespace frontend {
+namespace jax {
+namespace op {
+
+OutputVector translate_integer_pow(const NodeContext& context) {
+    num_inputs_check(context, 1, 1);
+    auto x = context.get_input(0);
+    int64_t y = context.const_named_param<int64_t>("y");
+
+    // create y const of the same type as x
+    auto y_const = create_same_type_const_scalar<int64_t>(x, y);
+    Output<Node> res = make_shared<v1::Power>(x, y_const);
+    return {res};
+};
+
+}  // namespace op
+}  // namespace jax
+}  // namespace frontend
+}  // namespace ov
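The two translators above read compile-time attributes from the jaxpr rather than runtime inputs: `concatenate` takes its axis from the primitive's `dimension` parameter, and `integer_pow` takes the exponent `y`. A quick way to see those parameters is to trace a function that uses both primitives; this snippet is illustrative only (not part of the change), and the printed jaxpr may vary slightly across JAX versions.

```python
import jax
import jax.numpy as jnp

# The bracketed static parameters are what const_named_param("dimension")
# and const_named_param("y") read on the C++ side.
def f(a, b):
    c = jax.lax.concatenate([a, b], dimension=0)
    return jax.lax.integer_pow(c, 3)

print(jax.make_jaxpr(f)(jnp.ones((2, 3)), jnp.ones((4, 3))))
# { lambda ; a:f32[2,3] b:f32[4,3]. let
#     c:f32[6,3] = concatenate[dimension=0] a b
#     d:f32[6,3] = integer_pow[y=3] c
#   in (d,) }
```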
diff --git a/src/frontends/jax/src/op/reduce.cpp b/src/frontends/jax/src/op/reduce.cpp
new file mode 100644
index 00000000000000..704b8a46332e3f
--- /dev/null
+++ b/src/frontends/jax/src/op/reduce.cpp
@@ -0,0 +1,36 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/jax/node_context.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/reduce_max.hpp"
+#include "openvino/op/reduce_sum.hpp"
+#include "utils.hpp"
+
+using namespace std;
+using namespace ov;
+using namespace ov::op;
+
+namespace ov {
+namespace frontend {
+namespace jax {
+namespace op {
+
+template <typename T>
+OutputVector translate_reduce_op(const NodeContext& context) {
+    num_inputs_check(context, 1, 1);
+    auto axes = context.const_named_param<std::vector<int64_t>>("axes");
+    auto a = context.get_input(0);
+    auto axes_const = make_shared<v0::Constant>(element::i64, Shape{axes.size()}, axes);
+    Output<Node> reduce_op = make_shared<T>(a, axes_const, false);
+    return {reduce_op};
+}
+
+template OutputVector translate_reduce_op<v1::ReduceMax>(const NodeContext& node);
+template OutputVector translate_reduce_op<v1::ReduceSum>(const NodeContext& node);
+
+}  // namespace op
+}  // namespace jax
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/jax/src/op_table.cpp b/src/frontends/jax/src/op_table.cpp
index 00e40b1dfa5f7b..762d216e81e81c 100644
--- a/src/frontends/jax/src/op_table.cpp
+++ b/src/frontends/jax/src/op_table.cpp
@@ -9,6 +9,8 @@
 #include "openvino/op/exp.hpp"
 #include "openvino/op/maximum.hpp"
 #include "openvino/op/multiply.hpp"
+#include "openvino/op/reduce_max.hpp"
+#include "openvino/op/reduce_sum.hpp"
 #include "openvino/op/sqrt.hpp"
 #include "openvino/op/subtract.hpp"
 #include "openvino/op/tanh.hpp"
@@ -22,13 +24,19 @@ namespace jax {
 namespace op {
 
 #define OP_CONVERTER(op) OutputVector op(const NodeContext& node)
+#define OP_T_CONVERTER(op) \
+    template <typename T> \
+    OutputVector op(const ov::frontend::jax::NodeContext& node)
 
 OP_CONVERTER(translate_broadcast_in_dim);
+OP_CONVERTER(translate_concatenate);
 OP_CONVERTER(translate_constant);
 OP_CONVERTER(translate_convert);
 OP_CONVERTER(translate_convolution);
 OP_CONVERTER(translate_copy);
 OP_CONVERTER(translate_dot_general);
+OP_CONVERTER(translate_integer_pow);
+OP_T_CONVERTER(translate_reduce_op);
 OP_CONVERTER(translate_reduce_window_max);
 OP_CONVERTER(translate_reduce_window_sum);
 OP_CONVERTER(translate_reshape);
@@ -43,6 +51,7 @@ OP_CONVERTER(translate_transpose);
 const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
     return {{"add", op::translate_1to1_match_2_inputs<v1::Add>},
             {"broadcast_in_dim", op::translate_broadcast_in_dim},
+            {"concatenate", op::translate_concatenate},
             {"constant", op::translate_constant},
             {"convert_element_type", op::translate_convert},
             {"conv_general_dilated", op::translate_convolution},
@@ -51,8 +60,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
             {"div", op::translate_1to1_match_2_inputs<v1::Divide>},
             {"dot_general", op::translate_dot_general},
             {"exp", op::translate_1to1_match_1_input<v0::Exp>},
+            {"integer_pow", op::translate_integer_pow},
             {"max", op::translate_1to1_match_2_inputs<v1::Maximum>},
             {"mul", op::translate_1to1_match_2_inputs<v1::Multiply>},
+            {"reduce_max", op::translate_reduce_op<v1::ReduceMax>},
+            {"reduce_sum", op::translate_reduce_op<v1::ReduceSum>},
             {"reduce_window_max", op::translate_reduce_window_max},
             {"reduce_window_sum", op::translate_reduce_window_sum},
             {"transpose", op::translate_transpose},
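`reduce_max` and `reduce_sum` are registered through the new `OP_T_CONVERTER` template: both primitives carry a static `axes` parameter, which `translate_reduce_op` turns into an i64 `Constant` feeding `v1::ReduceMax` / `v1::ReduceSum` with `keep_dims = false`. The sketch below (illustrative, not part of the change) shows how `jnp.sum` lowers to exactly this primitive:

```python
import jax
import jax.numpy as jnp

# The static "axes" parameter of the reduce_sum primitive is what the
# templated translator reads via const_named_param("axes").
print(jax.make_jaxpr(lambda x: jnp.sum(x, axis=(0, 1)))(jnp.ones((2, 3, 4))))
# { lambda ; a:f32[2,3,4]. let b:f32[4] = reduce_sum[axes=(0, 1)] a in (b,) }
```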
diff --git a/src/frontends/jax/src/utils.cpp b/src/frontends/jax/src/utils.cpp
index d1092a37dcf9dd..d47abfbba56188 100644
--- a/src/frontends/jax/src/utils.cpp
+++ b/src/frontends/jax/src/utils.cpp
@@ -18,6 +18,11 @@ void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_
     FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= min_inputs, "Got less inputs than expected");
 }
 
+void num_inputs_check(const NodeContext& context, size_t min_inputs) {
+    auto inputs = context.inputs();
+    FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= min_inputs, "Got less inputs than expected");
+}
+
 const std::string& get_jax_prefix() {
     return jax_prefix;
 }
@@ -48,6 +53,9 @@ const std::unordered_map<int64_t, element::Type> JAX_TO_OV_TYPE{
    {5, element::f16},
    {6, element::f32},
    {7, element::f64},
+   {8, element::u16},
+   {9, element::u32},
+   {10, element::u64},
    {11, element::boolean},
    {12, element::i8},  // quantized i8
    {13, element::u8},  // quantized u8
diff --git a/src/frontends/jax/src/utils.hpp b/src/frontends/jax/src/utils.hpp
index ab806bfc61da7c..f5a109b8602e9d 100644
--- a/src/frontends/jax/src/utils.hpp
+++ b/src/frontends/jax/src/utils.hpp
@@ -8,6 +8,7 @@
 
 #include "openvino/frontend/jax/node_context.hpp"
 #include "openvino/op/constant.hpp"
+#include "openvino/op/convert_like.hpp"
 
 namespace ov {
 
@@ -37,6 +38,8 @@ const std::string& get_jax_prefix();
 
 void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs);
 
+void num_inputs_check(const NodeContext& context, size_t min_inputs);
+
 void add_exception_to_fw_node(std::shared_ptr<JaxFrameworkNode> node, const std::string& msg);
 
 bool is_python_scalar_input(const NodeContext& context, size_t index);
@@ -62,6 +65,18 @@ inline OutputVector skip_node(const NodeContext& context) {
     return {context.get_input(0)};
 }
 
+template <typename T>
+ov::Output<ov::Node> create_same_type_const_scalar(const ov::Output<ov::Node>& same_type_output, const T& value) {
+    if (same_type_output.get_element_type().is_static()) {
+        return std::make_shared<ov::op::v0::Constant>(same_type_output.get_element_type(), ov::Shape{}, value);
+    } else {
+        ov::Output<ov::Node> const_res =
+            std::make_shared<ov::op::v0::Constant>(ov::element::from<T>(), ov::Shape{}, value);
+        const_res = std::make_shared<ov::op::v1::ConvertLike>(const_res, same_type_output);
+        return const_res;
+    }
+}
+
 }  // namespace op
 
 }  // namespace jax
diff --git a/tests/layer_tests/jax_tests/test_concatenate.py b/tests/layer_tests/jax_tests/test_concatenate.py
new file mode 100644
index 00000000000000..a476a29d9bc449
--- /dev/null
+++ b/tests/layer_tests/jax_tests/test_concatenate.py
@@ -0,0 +1,51 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import jax
+import numpy as np
+import pytest
+from jax import lax
+from jax import numpy as jnp
+
+from jax_layer_test_class import JaxLayerTest
+
+rng = np.random.default_rng(34455)
+
+
+class TestConcatenate(JaxLayerTest):
+    def _prepare_input(self):
+        inputs = []
+        for input_shape in self.input_shapes:
+            if np.issubdtype(self.input_type, np.floating):
+                x = rng.uniform(-5.0, 5.0, input_shape).astype(self.input_type)
+            elif np.issubdtype(self.input_type, np.signedinteger):
+                x = rng.integers(-8, 8, input_shape).astype(self.input_type)
+            else:
+                x = rng.integers(0, 8, input_shape).astype(self.input_type)
+            x = jnp.array(x)
+            inputs.append(x)
+        return inputs
+
+    def create_model(self, input_shapes, input_type, dimension):
+        self.input_shapes = input_shapes
+        self.input_type = input_type
+
+        def jax_concatenate(*arrays):
+            return jax.lax.concatenate(arrays, dimension)
+
+        return jax_concatenate, None, 'concatenate'
+
+    @pytest.mark.parametrize('input_shapes,dimension', [
+        ([[2], [3]], 0),
+        ([[2, 3], [2, 4]], 1),
+        ([[1, 2, 3], [1, 2, 2], [1, 2, 1]], 2),
+    ])
+    @pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
+                                            np.int32, np.uint32, np.int64, np.uint64,
+                                            np.float16, np.float32, np.float64])
+    @pytest.mark.nightly
+    @pytest.mark.precommit_jax_fe
+    def test_concatenate(self, input_shapes, input_type, dimension,
+                         ie_device, precision, ir_version):
+        self._test(*self.create_model(input_shapes, input_type, dimension),
+                   ie_device, precision, ir_version)
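The type-map and helper additions above feed the three new test suites. End to end, the suites exercise roughly the flow sketched below; this is a minimal approximation assuming a build with the JAX frontend enabled and `openvino.convert_model` accepting a traced `ClosedJaxpr`, while the real harness in `jax_layer_test_class` additionally compares the result against the JAX reference output.

```python
import jax
import jax.numpy as jnp
import numpy as np
import openvino as ov

x = np.ones((2, 3, 4), dtype=np.float32)
# Trace to a ClosedJaxpr, convert through the JAX frontend, run on CPU.
jaxpr = jax.make_jaxpr(lambda t: jnp.max(t, axis=1))(x)
compiled = ov.compile_model(ov.convert_model(jaxpr), 'CPU')
print(compiled(x)[0])  # expected to match jnp.max(x, axis=1)
```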
diff --git a/tests/layer_tests/jax_tests/test_integer_pow.py b/tests/layer_tests/jax_tests/test_integer_pow.py
new file mode 100644
index 00000000000000..dc83f4fa1ddc5a
--- /dev/null
+++ b/tests/layer_tests/jax_tests/test_integer_pow.py
@@ -0,0 +1,48 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import jax
+import numpy as np
+import pytest
+from jax import lax
+from jax import numpy as jnp
+
+from jax_layer_test_class import JaxLayerTest
+
+rng = np.random.default_rng(34455)
+
+
+class TestIntegerPow(JaxLayerTest):
+    def _prepare_input(self):
+        if np.issubdtype(self.input_type, np.floating):
+            x = rng.uniform(-3.0, 3.0, self.x_shape).astype(self.input_type)
+        elif np.issubdtype(self.input_type, np.signedinteger):
+            x = rng.integers(-3, 3, self.x_shape).astype(self.input_type)
+        else:
+            x = rng.integers(0, 3, self.x_shape).astype(self.input_type)
+        x = jnp.array(x)
+        return [x]
+
+    def create_model(self, x_shape, y, input_type):
+        self.x_shape = x_shape
+        self.input_type = input_type
+
+        def jax_integer_pow(x):
+            return jax.lax.integer_pow(x, y)
+
+        return jax_integer_pow, None, 'integer_pow'
+
+    @pytest.mark.parametrize('x_shape', [[3], [2, 3], [1, 2, 3]])
+    @pytest.mark.parametrize('y', [2, 3, 4])
+    @pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
+                                            np.int32, np.uint32, np.int64, np.uint64,
+                                            np.float16, np.float32, np.float64])
+    @pytest.mark.nightly
+    @pytest.mark.precommit_jax_fe
+    def test_integer_pow(self, x_shape, y, input_type,
+                         ie_device, precision, ir_version):
+        kwargs = {}
+        if input_type == np.float16:
+            kwargs["custom_eps"] = 2e-2
+        self._test(*self.create_model(x_shape, y, input_type),
+                   ie_device, precision, ir_version, **kwargs)
diff --git a/tests/layer_tests/jax_tests/test_reduce.py b/tests/layer_tests/jax_tests/test_reduce.py
new file mode 100644
index 00000000000000..b9a097213cb0da
--- /dev/null
+++ b/tests/layer_tests/jax_tests/test_reduce.py
@@ -0,0 +1,51 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import jax
+import numpy as np
+import pytest
+from jax import lax
+from jax import numpy as jnp
+
+from jax_layer_test_class import JaxLayerTest
+
+rng = np.random.default_rng(3456834)
+
+
+class TestReduce(JaxLayerTest):
+    def _prepare_input(self):
+        if np.issubdtype(self.input_type, np.floating):
+            x = rng.uniform(-5.0, 5.0, self.input_shape).astype(self.input_type)
+        elif np.issubdtype(self.input_type, np.signedinteger):
+            x = rng.integers(-8, 8, self.input_shape).astype(self.input_type)
+        else:
+            x = rng.integers(0, 8, self.input_shape).astype(self.input_type)
+        x = jnp.array(x)
+        return [x]
+
+    def create_model(self, input_shape, axis, op_type, input_type):
+        reduce_map = {
+            'reduce_max': jnp.max,
+            'reduce_sum': jnp.sum
+        }
+
+        self.input_shape = input_shape
+        self.input_type = input_type
+
+        def jax_reduce(x):
+            return reduce_map[op_type](x, axis=axis)
+
+        return jax_reduce, None, op_type
+
+    @pytest.mark.parametrize('input_shape', [[2, 3, 4], [4, 3, 2, 1]])
+    @pytest.mark.parametrize('axis', [None, 1, [0, 1], [0, 1, 2]])
+    @pytest.mark.parametrize('op_type', ['reduce_max', 'reduce_sum'])
+    @pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
+                                            np.int32, np.uint32, np.int64, np.uint64,
+                                            np.float16, np.float32, np.float64])
+    @pytest.mark.nightly
+    @pytest.mark.precommit_jax_fe
+    def test_reduce(self, input_shape, axis, op_type, input_type,
+                    ie_device, precision, ir_version):
+        self._test(*self.create_model(input_shape, axis, op_type, input_type),
+                   ie_device, precision, ir_version)
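For running the new suites locally, an invocation along these lines is assumed, following the existing layer-test setup (device selection via the TEST_DEVICE environment variable; the `nightly` / `precommit_jax_fe` markers come from the tests above):

```bash
export TEST_DEVICE=CPU
python -m pytest tests/layer_tests/jax_tests/test_concatenate.py \
    tests/layer_tests/jax_tests/test_integer_pow.py \
    tests/layer_tests/jax_tests/test_reduce.py -m precommit_jax_fe
```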