[JAX FE] Support concatenate, integer_pow, reduce ops
This PR adds the remaining conversion rules needed to infer the JAX ViT model.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
rkazants committed Aug 28, 2024
1 parent 20afc50 commit f37d666
Showing 10 changed files with 293 additions and 2 deletions.
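With concatenate, integer_pow, reduce_max, and reduce_sum covered, a ViT forward pass traces to a jaxpr whose primitives the frontend can now convert end to end. A minimal sketch of exercising the new conversions, assuming the frontend consumes a closed jaxpr produced by jax.make_jaxpr and that ov.convert_model is the entry point (an assumption; check the OpenVINO JAX frontend documentation for the exact API):

import jax
import jax.numpy as jnp
import openvino as ov

def f(x):
    y = jax.lax.concatenate([x, x], dimension=1)  # -> concatenate
    y = jax.lax.integer_pow(y, 2)                 # -> integer_pow
    return jnp.sum(y, axis=1)                     # -> reduce_sum

# Assumption: convert_model accepts the closed jaxpr directly; the exact
# conversion entry point may differ between OpenVINO releases.
jaxpr = jax.make_jaxpr(f)(jnp.zeros((2, 3), dtype=jnp.float32))
ov_model = ov.convert_model(jaxpr)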
7 changes: 5 additions & 2 deletions src/bindings/python/src/openvino/frontend/jax/utils.py
@@ -13,7 +13,7 @@
numpy_to_ov_type_map = {
np.float32: OVType.f32,
bool: OVType.boolean,
jax.dtypes.bfloat16: OVType.bf16, # TODO: check this
np.float16: OVType.f16,
np.float32: OVType.f32,
np.float64: OVType.f64,
@@ -29,7 +29,7 @@

jax_to_ov_type_map = {
jnp.float32: OVType.f32,
jnp.bfloat16: OVType.bf16, # TODO: check this
jnp.float16: OVType.f16,
jnp.float64: OVType.f64,
jnp.uint8: OVType.u8,
@@ -62,6 +62,9 @@
OVType.f16: 5,
OVType.f32: 6,
OVType.f64: 7,
OVType.u16: 8,
OVType.u32: 9,
OVType.u64: 10,
OVType.boolean: 11,
OVType.bf16: 15,
}
35 changes: 35 additions & 0 deletions src/frontends/jax/src/op/concatenate.cpp
@@ -0,0 +1,35 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/jax/node_context.hpp"
#include "openvino/op/concat.hpp"
#include "utils.hpp"

using namespace std;
using namespace ov;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace jax {
namespace op {

OutputVector translate_concatenate(const NodeContext& context) {
num_inputs_check(context, 1);
auto num_inputs = static_cast<int>(context.get_input_size());
int64_t axis = context.const_named_param<int64_t>("dimension");

OutputVector inputs(num_inputs);
for (int ind = 0; ind < num_inputs; ++ind) {
inputs[ind] = context.get_input(ind);
}
Output<Node> res = make_shared<v0::Concat>(inputs, axis);

return {res};
}

} // namespace op
} // namespace jax
} // namespace frontend
} // namespace ov
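The converter maps JAX's variadic lax.concatenate primitive to ov::op::v0::Concat; because the op takes any number of operands, num_inputs_check(context, 1) enforces only a lower bound (see the new single-minimum overload added to utils.cpp below). For reference, the primitive being translated:

import jax
import jax.numpy as jnp

# lax.concatenate(operands, dimension) joins arrays along a single axis.
out = jax.lax.concatenate([jnp.ones((2, 3)), jnp.zeros((2, 4))], dimension=1)
print(out.shape)  # (2, 7)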
32 changes: 32 additions & 0 deletions src/frontends/jax/src/op/integer_pow.cpp
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/jax/node_context.hpp"
#include "openvino/op/power.hpp"
#include "utils.hpp"

using namespace std;
using namespace ov;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace jax {
namespace op {

OutputVector translate_integer_pow(const NodeContext& context) {
num_inputs_check(context, 1, 1);
auto x = context.get_input(0);
int64_t y = context.const_named_param<int64_t>("y");

// create y const of the same type as x
auto y_const = create_same_type_const_scalar<int64_t>(x, y);
Output<Node> res = make_shared<v1::Power>(x, y_const);
return {res};
}

} // namespace op
} // namespace jax
} // namespace frontend
} // namespace ov
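lax.integer_pow carries its exponent as a static parameter y rather than a second tensor input, so the converter reads it with const_named_param and materializes a scalar constant matching x's element type via create_same_type_const_scalar (added to utils.hpp below). On the JAX side:

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
# y is a Python int recorded as the primitive's "y" parameter, not an input.
print(jax.lax.integer_pow(x, 3))  # [ 1.  8. 27.]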
36 changes: 36 additions & 0 deletions src/frontends/jax/src/op/reduce.cpp
@@ -0,0 +1,36 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/jax/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reduce_max.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "utils.hpp"

using namespace std;
using namespace ov;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace jax {
namespace op {

template <typename T>
OutputVector translate_reduce_op(const NodeContext& context) {
num_inputs_check(context, 1, 1);
auto axes = context.const_named_param<vector<int64_t>>("axes");
auto a = context.get_input(0);
auto axes_const = make_shared<v0::Constant>(element::i64, Shape{axes.size()}, axes);
Output<Node> reduce_op = make_shared<T>(a, axes_const, false);
return {reduce_op};
}

template OutputVector translate_reduce_op<v1::ReduceMax>(const NodeContext& node);
template OutputVector translate_reduce_op<v1::ReduceSum>(const NodeContext& node);

} // namespace op
} // namespace jax
} // namespace frontend
} // namespace ov
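A single templated body serves both reductions; the explicit instantiations at the bottom are what op_table.cpp references below. The converter passes keep_dims = false, matching JAX's default of dropping the reduced axes:

import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)
print(jnp.sum(x, axis=(0, 1)).shape)  # (4,)   -> reduce_sum with axes=(0, 1)
print(jnp.max(x, axis=1).shape)       # (2, 4) -> reduce_max with axes=(1,)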
12 changes: 12 additions & 0 deletions src/frontends/jax/src/op_table.cpp
@@ -9,6 +9,8 @@
#include "openvino/op/exp.hpp"
#include "openvino/op/maximum.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reduce_max.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/tanh.hpp"
@@ -22,13 +24,19 @@ namespace jax {
namespace op {

#define OP_CONVERTER(op) OutputVector op(const NodeContext& node)
#define OP_T_CONVERTER(op) \
template <class T> \
OutputVector op(const ov::frontend::jax::NodeContext& node)

OP_CONVERTER(translate_broadcast_in_dim);
OP_CONVERTER(translate_concatenate);
OP_CONVERTER(translate_constant);
OP_CONVERTER(translate_convert);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_copy);
OP_CONVERTER(translate_dot_general);
OP_CONVERTER(translate_integer_pow);
OP_T_CONVERTER(translate_reduce_op);
OP_CONVERTER(translate_reduce_window_max);
OP_CONVERTER(translate_reduce_window_sum);
OP_CONVERTER(translate_reshape);
@@ -43,6 +51,7 @@ OP_CONVERTER(translate_transpose);
const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
return {{"add", op::translate_1to1_match_2_inputs<v1::Add>},
{"broadcast_in_dim", op::translate_broadcast_in_dim},
{"concatenate", op::translate_concatenate},
{"constant", op::translate_constant},
{"convert_element_type", op::translate_convert},
{"conv_general_dilated", op::translate_convolution},
@@ -51,8 +60,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
{"div", op::translate_1to1_match_2_inputs<v1::Divide>},
{"dot_general", op::translate_dot_general},
{"exp", op::translate_1to1_match_1_input<v0::Exp>},
{"integer_pow", op::translate_integer_pow},
{"max", op::translate_1to1_match_2_inputs<v1::Maximum>},
{"mul", op::translate_1to1_match_2_inputs<v1::Multiply>},
{"reduce_max", op::translate_reduce_op<v1::ReduceMax>},
{"reduce_sum", op::translate_reduce_op<v1::ReduceSum>},
{"reduce_window_max", op::translate_reduce_window_max},
{"reduce_window_sum", op::translate_reduce_window_sum},
{"transpose", op::translate_transpose},
8 changes: 8 additions & 0 deletions src/frontends/jax/src/utils.cpp
@@ -18,6 +18,11 @@ void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= min_inputs, "Got less inputs than expected");
}

void num_inputs_check(const NodeContext& context, size_t min_inputs) {
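// Overload for variadic ops (e.g. concatenate) that only bound the input count from below.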
auto inputs = context.inputs();
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= min_inputs, "Got less inputs than expected");
}

const std::string& get_jax_prefix() {
return jax_prefix;
}
@@ -48,6 +53,9 @@ const std::unordered_map<int64_t, element::Type> JAX_TO_OV_TYPE{
{5, element::f16},
{6, element::f32},
{7, element::f64},
{8, element::u16},
{9, element::u32},
{10, element::u64},
{11, element::boolean},
{12, element::i8}, // quantized i8
{13, element::u8}, // quantized u8
15 changes: 15 additions & 0 deletions src/frontends/jax/src/utils.hpp
@@ -8,6 +8,7 @@

#include "openvino/frontend/jax/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"

namespace ov {

@@ -37,6 +38,8 @@ const std::string& get_jax_prefix();

void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs);

void num_inputs_check(const NodeContext& context, size_t min_inputs);

void add_exception_to_fw_node(std::shared_ptr<Node> node, const std::string& msg);

bool is_python_scalar_input(const NodeContext& context, size_t index);
@@ -62,6 +65,18 @@ inline OutputVector skip_node(const NodeContext& context) {
return {context.get_input(0)};
}

template <typename T>
ov::Output<ov::Node> create_same_type_const_scalar(const ov::Output<ov::Node>& same_type_output, const T& value) {
if (same_type_output.get_element_type().is_static()) {
return std::make_shared<ov::op::v0::Constant>(same_type_output.get_element_type(), ov::Shape{}, value);
} else {
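// The element type is not known statically: create the constant in T's
// native type and let ConvertLike align it with the input's type at runtime.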
ov::Output<ov::Node> const_res =
std::make_shared<ov::op::v0::Constant>(ov::element::from<T>(), ov::Shape{}, value);
const_res = std::make_shared<ov::op::v1::ConvertLike>(const_res, same_type_output);
return const_res;
}
}

} // namespace op

} // namespace jax
51 changes: 51 additions & 0 deletions tests/layer_tests/jax_tests/test_concatenate.py
@@ -0,0 +1,51 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import jax
import numpy as np
import pytest
from jax import lax
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(34455)


class TestConcatenate(JaxLayerTest):
def _prepare_input(self):
inputs = []
for input_shape in self.input_shapes:
if np.issubdtype(self.input_type, np.floating):
x = rng.uniform(-5.0, 5.0, input_shape).astype(self.input_type)
elif np.issubdtype(self.input_type, np.signedinteger):
x = rng.integers(-8, 8, input_shape).astype(self.input_type)
else:
x = rng.integers(0, 8, input_shape).astype(self.input_type)
x = jnp.array(x)
inputs.append(x)
return inputs

def create_model(self, input_shapes, input_type, dimension):
self.input_shapes = input_shapes
self.input_type = input_type

def jax_concatenate(*arrays):
return jax.lax.concatenate(arrays, dimension)

return jax_concatenate, None, 'concatenate'

@pytest.mark.parametrize('input_shapes,dimension', [
([[2], [3]], 0),
([[2, 3], [2, 4]], 1),
([[1, 2, 3], [1, 2, 2], [1, 2, 1]], 2),
])
@pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
np.int32, np.uint32, np.int64, np.uint64,
np.float16, np.float32, np.float64])
@pytest.mark.nightly
@pytest.mark.precommit_jax_fe
def test_concatenate(self, input_shapes, input_type, dimension,
ie_device, precision, ir_version):
self._test(*self.create_model(input_shapes, input_type, dimension),
ie_device, precision, ir_version)
48 changes: 48 additions & 0 deletions tests/layer_tests/jax_tests/test_integer_pow.py
@@ -0,0 +1,48 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import jax
import numpy as np
import pytest
from jax import lax
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(34455)


class TestIntegerPow(JaxLayerTest):
def _prepare_input(self):
if np.issubdtype(self.input_type, np.floating):
x = rng.uniform(-3.0, 3.0, self.x_shape).astype(self.input_type)
elif np.issubdtype(self.input_type, np.signedinteger):
x = rng.integers(-3, 3, self.x_shape).astype(self.input_type)
else:
x = rng.integers(0, 3, self.x_shape).astype(self.input_type)
x = jnp.array(x)
return [x]

def create_model(self, x_shape, y, input_type):
self.x_shape = x_shape
self.input_type = input_type

def jax_integer_pow(x):
return jax.lax.integer_pow(x, y)

return jax_integer_pow, None, 'integer_pow'

@pytest.mark.parametrize('x_shape', [[3], [2, 3], [1, 2, 3]])
@pytest.mark.parametrize('y', [2, 3, 4])
@pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
np.int32, np.uint32, np.int64, np.uint64,
np.float16, np.float32, np.float64])
@pytest.mark.nightly
@pytest.mark.precommit_jax_fe
def test_integer_pow(self, x_shape, y, input_type,
ie_device, precision, ir_version):
kwargs = {}
if input_type == np.float16:
kwargs["custom_eps"] = 2e-2
self._test(*self.create_model(x_shape, y, input_type),
ie_device, precision, ir_version, **kwargs)
51 changes: 51 additions & 0 deletions tests/layer_tests/jax_tests/test_reduce.py
@@ -0,0 +1,51 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import jax
import numpy as np
import pytest
from jax import lax
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(3456834)


class TestReduce(JaxLayerTest):
def _prepare_input(self):
if np.issubdtype(self.input_type, np.floating):
x = rng.uniform(-5.0, 5.0, self.input_shape).astype(self.input_type)
elif np.issubdtype(self.input_type, np.signedinteger):
x = rng.integers(-8, 8, self.input_shape).astype(self.input_type)
else:
x = rng.integers(0, 8, self.input_shape).astype(self.input_type)
x = jnp.array(x)
return [x]

def create_model(self, input_shape, axis, op_type, input_type):
reduce_map = {
'reduce_max': jnp.max,
'reduce_sum': jnp.sum
}

self.input_shape = input_shape
self.input_type = input_type

def jax_reduce(x):
return reduce_map[op_type](x, axis=axis)

return jax_reduce, None, op_type

@pytest.mark.parametrize('input_shape', [[2, 3, 4], [4, 3, 2, 1]])
@pytest.mark.parametrize('axis', [None, 1, [0, 1], [0, 1, 2]])
@pytest.mark.parametrize('op_type', ['reduce_max', 'reduce_sum'])
@pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
np.int32, np.uint32, np.int64, np.uint64,
np.float16, np.float32, np.float64])
@pytest.mark.nightly
@pytest.mark.precommit_jax_fe
def test_reduce(self, input_shape, axis, op_type, input_type,
ie_device, precision, ir_version):
self._test(*self.create_model(input_shape, axis, op_type, input_type),
ie_device, precision, ir_version)
