-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX FE] Support concatenate, integer_pow, reduce ops
This PR finalizes support for these operations and enables inference of the JAX ViT model. Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
- Loading branch information
Showing
10 changed files
with
293 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/frontend/jax/node_context.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "utils.hpp" | ||
|
||
using namespace std; | ||
using namespace ov; | ||
using namespace ov::op; | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace jax { | ||
namespace op { | ||
|
||
// Translates JAX lax.concatenate: joins all node inputs along the axis given
// by the "dimension" named attribute into a single OpenVINO Concat node.
// Fix: dropped the redundant semicolon that followed the function body.
OutputVector translate_concatenate(const NodeContext& context) {
    // lax.concatenate is variadic; require at least one operand.
    num_inputs_check(context, 1);
    auto num_inputs = static_cast<int>(context.get_input_size());
    int64_t axis = context.const_named_param<int64_t>("dimension");

    // Collect every operand in order; Concat preserves operand order.
    OutputVector inputs(num_inputs);
    for (int ind = 0; ind < num_inputs; ++ind) {
        inputs[ind] = context.get_input(ind);
    }
    Output<Node> res = make_shared<v0::Concat>(inputs, axis);

    return {res};
}
|
||
} // namespace op | ||
} // namespace jax | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/frontend/jax/node_context.hpp" | ||
#include "openvino/op/power.hpp" | ||
#include "utils.hpp" | ||
|
||
using namespace std; | ||
using namespace ov; | ||
using namespace ov::op; | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace jax { | ||
namespace op { | ||
|
||
// Translates JAX lax.integer_pow: raises input x to a compile-time integer
// power y, taken from the "y" named attribute.
// Fix: dropped the redundant semicolon that followed the function body.
OutputVector translate_integer_pow(const NodeContext& context) {
    num_inputs_check(context, 1, 1);
    auto x = context.get_input(0);
    int64_t y = context.const_named_param<int64_t>("y");

    // Create a y constant of the same element type as x so that Power
    // receives operands of matching types (no implicit conversion node).
    auto y_const = create_same_type_const_scalar<int64_t>(x, y);
    Output<Node> res = make_shared<v1::Power>(x, y_const);
    return {res};
}
|
||
} // namespace op | ||
} // namespace jax | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/frontend/jax/node_context.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/reduce_max.hpp" | ||
#include "openvino/op/reduce_sum.hpp" | ||
#include "utils.hpp" | ||
|
||
using namespace std; | ||
using namespace ov; | ||
using namespace ov::op; | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace jax { | ||
namespace op { | ||
|
||
template <typename T> | ||
OutputVector translate_reduce_op(const NodeContext& context) { | ||
num_inputs_check(context, 1, 1); | ||
auto axes = context.const_named_param<vector<int64_t>>("axes"); | ||
auto a = context.get_input(0); | ||
auto axes_const = make_shared<v0::Constant>(element::i64, Shape{axes.size()}, axes); | ||
Output<Node> reduce_op = make_shared<T>(a, axes_const, false); | ||
return {reduce_op}; | ||
} | ||
|
||
template OutputVector translate_reduce_op<v1::ReduceMax>(const NodeContext& node); | ||
template OutputVector translate_reduce_op<v1::ReduceSum>(const NodeContext& node); | ||
|
||
} // namespace op | ||
} // namespace jax | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import jax | ||
import numpy as np | ||
import pytest | ||
from jax import lax | ||
from jax import numpy as jnp | ||
|
||
from jax_layer_test_class import JaxLayerTest | ||
|
||
rng = np.random.default_rng(34455) | ||
|
||
|
||
class TestConcatenate(JaxLayerTest):
    """Layer test for jax.lax.concatenate lowered through the JAX frontend."""

    def _prepare_input(self):
        # Generate one random tensor per requested shape, with a value range
        # appropriate for the configured dtype.
        tensors = []
        for shape in self.input_shapes:
            if np.issubdtype(self.input_type, np.floating):
                data = rng.uniform(-5.0, 5.0, shape).astype(self.input_type)
            elif np.issubdtype(self.input_type, np.signedinteger):
                data = rng.integers(-8, 8, shape).astype(self.input_type)
            else:
                data = rng.integers(0, 8, shape).astype(self.input_type)
            tensors.append(jnp.array(data))
        return tensors

    def create_model(self, input_shapes, input_type, dimension):
        # Stash the generation parameters for _prepare_input.
        self.input_shapes = input_shapes
        self.input_type = input_type

        def jax_concatenate(*arrays):
            return jax.lax.concatenate(arrays, dimension)

        return jax_concatenate, None, 'concatenate'

    @pytest.mark.parametrize('input_shapes,dimension', [
        ([[2], [3]], 0),
        ([[2, 3], [2, 4]], 1),
        ([[1, 2, 3], [1, 2, 2], [1, 2, 1]], 2),
    ])
    @pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
                                            np.int32, np.uint32, np.int64, np.uint64,
                                            np.float16, np.float32, np.float64])
    @pytest.mark.nightly
    @pytest.mark.precommit_jax_fe
    def test_concatenate(self, input_shapes, input_type, dimension,
                         ie_device, precision, ir_version):
        self._test(*self.create_model(input_shapes, input_type, dimension),
                   ie_device, precision, ir_version)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import jax | ||
import numpy as np | ||
import pytest | ||
from jax import lax | ||
from jax import numpy as jnp | ||
|
||
from jax_layer_test_class import JaxLayerTest | ||
|
||
rng = np.random.default_rng(34455) | ||
|
||
|
||
class TestIntegerPow(JaxLayerTest):
    """Layer test for jax.lax.integer_pow lowered through the JAX frontend."""

    def _prepare_input(self):
        # Single random tensor; the value range is kept small so x ** y stays
        # within the representable range of the configured dtype.
        if np.issubdtype(self.input_type, np.floating):
            data = rng.uniform(-3.0, 3.0, self.x_shape).astype(self.input_type)
        elif np.issubdtype(self.input_type, np.signedinteger):
            data = rng.integers(-3, 3, self.x_shape).astype(self.input_type)
        else:
            data = rng.integers(0, 3, self.x_shape).astype(self.input_type)
        return [jnp.array(data)]

    def create_model(self, x_shape, y, input_type):
        # Stash the generation parameters for _prepare_input.
        self.x_shape = x_shape
        self.input_type = input_type

        def jax_integer_pow(x):
            return jax.lax.integer_pow(x, y)

        return jax_integer_pow, None, 'integer_pow'

    @pytest.mark.parametrize('x_shape', [[3], [2, 3], [1, 2, 3]])
    @pytest.mark.parametrize('y', [2, 3, 4])
    @pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
                                            np.int32, np.uint32, np.int64, np.uint64,
                                            np.float16, np.float32, np.float64])
    @pytest.mark.nightly
    @pytest.mark.precommit_jax_fe
    def test_integer_pow(self, x_shape, y, input_type,
                         ie_device, precision, ir_version):
        kwargs = {}
        if input_type == np.float16:
            # float16 accumulates more rounding error; relax the tolerance.
            kwargs["custom_eps"] = 2e-2
        self._test(*self.create_model(x_shape, y, input_type),
                   ie_device, precision, ir_version, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import jax | ||
import numpy as np | ||
import pytest | ||
from jax import lax | ||
from jax import numpy as jnp | ||
|
||
from jax_layer_test_class import JaxLayerTest | ||
|
||
rng = np.random.default_rng(3456834) | ||
|
||
|
||
class TestReduce(JaxLayerTest):
    """Layer test for JAX reduction ops (reduce_max, reduce_sum)."""

    def _prepare_input(self):
        # Single random tensor with a value range appropriate for the dtype.
        if np.issubdtype(self.input_type, np.floating):
            data = rng.uniform(-5.0, 5.0, self.input_shape).astype(self.input_type)
        elif np.issubdtype(self.input_type, np.signedinteger):
            data = rng.integers(-8, 8, self.input_shape).astype(self.input_type)
        else:
            data = rng.integers(0, 8, self.input_shape).astype(self.input_type)
        return [jnp.array(data)]

    def create_model(self, input_shape, axis, op_type, input_type):
        # Map the op name onto its jax.numpy reduction function.
        reduce_map = {
            'reduce_max': jnp.max,
            'reduce_sum': jnp.sum
        }

        # Stash the generation parameters for _prepare_input.
        self.input_shape = input_shape
        self.input_type = input_type

        def jax_reduce(x):
            return reduce_map[op_type](x, axis=axis)

        return jax_reduce, None, op_type

    @pytest.mark.parametrize('input_shape', [[2, 3, 4], [4, 3, 2, 1]])
    @pytest.mark.parametrize('axis', [None, 1, [0, 1], [0, 1, 2]])
    @pytest.mark.parametrize('op_type', ['reduce_max', 'reduce_sum'])
    @pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
                                            np.int32, np.uint32, np.int64, np.uint64,
                                            np.float16, np.float32, np.float64])
    @pytest.mark.nightly
    @pytest.mark.precommit_jax_fe
    def test_reduce(self, input_shape, axis, op_type, input_type,
                    ie_device, precision, ir_version):
        self._test(*self.create_model(input_shape, axis, op_type, input_type),
                   ie_device, precision, ir_version)