Skip to content

Commit

Permalink
Fix test issues
Browse files Browse the repository at this point in the history
  • Loading branch information
praasz authored and beleiuandrei committed Jan 15, 2024
1 parent c72808c commit 2121c38
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/bindings/python/src/pyopenvino/core/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype() {
{ov::element::boolean, py::dtype("bool")}, {ov::element::u1, py::dtype("uint8")},
{ov::element::u4, py::dtype("uint8")}, {ov::element::nf4, py::dtype("uint8")},
{ov::element::i4, py::dtype("int8")}, {ov::element::f8e4m3, py::dtype("uint8")},
{ov::element::f8e5m2, py::dtype("uint8")},
{ov::element::f8e5m2, py::dtype("uint8")}, {ov::element::string, py::dtype("bytes_")},
};
return ov_type_to_dtype_mapping;
}
Expand Down
27 changes: 4 additions & 23 deletions src/bindings/python/tests/test_graph/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,25 +338,6 @@ def test_memory_sharing(shared_flag):
assert not np.shares_memory(arr, ov_const.data)


@pytest.mark.parametrize(
("ov_type", "src_dtype"),
[
(Type.u1, np.uint8),
(Type.u4, np.uint8),
(Type.i4, np.int8),
(Type.nf4, np.uint8),
(Type.bf16, np.float16),
],
)
def test_raise_for_packed_types(ov_type, src_dtype):
data = np.ones((2, 4, 16)).astype(src_dtype)

with pytest.raises(RuntimeError) as err:
_ = ops.constant(data, dtype=ov_type)

assert "All values must be equal to 0 to initialize Constant with type of" in str(err.value)


@pytest.mark.parametrize(("ov_type", "numpy_dtype"), [
(Type.f32, np.float32),
(Type.f16, np.float16),
Expand Down Expand Up @@ -411,10 +392,10 @@ def test_float_to_f8e4m3_constant(ov_type, numpy_dtype):
target = [5.0, 4.5, -5.0, 0.0, 0.1015625, 0.203125, 0.3125,
0.40625, 0.5, 0.625, 0.6875, 0.8125, 0.875, 1,
-0, -0.1015625, -0.203125, -0.3125, -0.40625, -0.5, -0.625,
-0.6875, -0.8125, -0.875, -1, 448, 448]
-0.6875, -0.8125, -0.875, -1, 448, np.nan]
target = np.array(target, dtype=numpy_dtype)

assert np.allclose(result, target)
assert np.allclose(result, target, equal_nan=True)


@pytest.mark.parametrize(("ov_type", "numpy_dtype"), [
Expand Down Expand Up @@ -473,7 +454,7 @@ def test_float_to_f8e4m3_convert(ov_type, numpy_dtype):
target = [5.0, 4.5, -5.0, 0.0, 0.1015625, 0.203125, 0.3125,
0.40625, 0.5, 0.625, 0.6875, 0.8125, 0.875, 1,
-0, -0.1015625, -0.203125, -0.3125, -0.40625, -0.5, -0.625,
-0.6875, -0.8125, -0.875, -1, 448, 448]
-0.6875, -0.8125, -0.875, -1, 448, np.nan]
target = np.array(target, dtype=numpy_dtype)

assert np.allclose(result, target)
assert np.allclose(result, target, equal_nan=True)

0 comments on commit 2121c38

Please sign in to comment.