Skip to content

Commit

Permalink
update repeat tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alifahrri committed Nov 5, 2023
1 parent 086ec48 commit 71c4cef
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
2 changes: 1 addition & 1 deletion cmake/toolchains/clang-werror.cmake
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
set(CMAKE_C_COMPILER clang)
set(CMAKE_CXX_COMPILER clang++)

add_compile_options(-W -Wall -Werror -Wextra -Wno-gnu-string-literal-operator-template)
add_compile_options(-W -Wall -Werror -Wextra -Wno-gnu-string-literal-operator-template -Wno-deprecated-declarations)
6 changes: 3 additions & 3 deletions include/nmtools/array/eval/opencl/kernels/repeat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ kernel void nmtools_cl_kernel_name(out_type,inp_type) \
, const unsigned int axis \
) \
{ \
auto repeats = na::create_vector(repeats_ptr,repeats_size); \
auto input = na::create_array(inp_ptr,inp_shape_ptr,inp_dim); \
auto output = na::create_mutable_array(out_ptr,out_shape_ptr,out_dim); \
auto repeats = na::create_vector(repeats_ptr,repeats_size); \
auto input = na::create_array(inp_ptr,inp_shape_ptr,inp_dim); \
auto output = na::create_mutable_array(out_ptr,out_shape_ptr,out_dim); \
auto repeated = view::repeat(input,repeats,axis); \
opencl::assign_array(output,repeated); \
}
Expand Down
22 changes: 20 additions & 2 deletions tests/opencl/kernels/repeat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ TEST_CASE("repeat(case1)" * doctest::test_suite("opencl::repeat"))
{
auto shape = nmtools_array<uint32_t,2>{1,64};
auto input = na::reshape(na::arange(ix::product(shape)),shape);
auto repeats = nmtools_array<uint32_t,2>{1,2};
auto repeats = nmtools_array<uint32_t,1>{2};
auto axis = 0;
OPENCL_TEST(repeat,input,repeats,axis);
}
Expand All @@ -34,9 +34,27 @@ TEST_CASE("repeat(case2)" * doctest::test_suite("opencl::repeat"))
{
auto shape = nmtools_array<uint32_t,2>{8,8};
auto input = na::reshape(na::arange(ix::product(shape)),shape);
auto repeats = nmtools_array<uint32_t,2>{1,2};
auto repeats = nmtools_array<uint32_t,8>{2,1,1,1,1,1,1,1};
auto axis = 0;
OPENCL_TEST(repeat,input,repeats,axis);
}

TEST_CASE("repeat(case3)" * doctest::test_suite("opencl::repeat"))
{
auto shape = nmtools_array<uint32_t,2>{64,1};
auto input = na::reshape(na::arange(ix::product(shape)),shape);
auto repeats = nmtools_array<uint32_t,1>{3};
auto axis = 1;
OPENCL_TEST(repeat,input,repeats,axis);
}

TEST_CASE("repeat(case4)" * doctest::test_suite("opencl::repeat"))
{
auto shape = nmtools_array<uint32_t,2>{8,8};
auto input = na::reshape(na::arange(ix::product(shape)),shape);
auto repeats = nmtools_array<uint32_t,8>{1,2,3,4,1,1,1,1};
auto axis = 1;
OPENCL_TEST(repeat,input,repeats,axis);
}

#endif // NMTOOLS_OPENCL_BUILD_KERNELS

0 comments on commit 71c4cef

Please sign in to comment.