From da954f911607a4cb010a6906857a40f174d7b226 Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk
Date: Wed, 21 Aug 2024 17:49:32 -0500
Subject: [PATCH] Added special case for _full_usm_ndarray
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Bitwise zero values, and 1-byte wide types now use memset,
instead of using fill.

For 16-byte wide types the fill value is treated as bitwise zero only
when both 64-bit halves of its bit pattern are zero; any other value
goes through the regular fill kernel.

```
In [1]: import dpctl.tensor as dpt, dpctl.tensor._tensor_impl as ti

In [2]: res = dpt.empty(10**6, dtype="i8")

In [3]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait()
243 µs ± 22.6 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [4]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait()
229 µs ± 14 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [5]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait()
227 µs ± 23 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [6]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait()
233 µs ± 25.9 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [7]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait()
301 µs ± 54.1 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [8]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait()
236 µs ± 17.2 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [9]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait()
240 µs ± 35.2 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [10]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(1, dst=res, sycl_queue=res.sycl_queue)[0].wait()
243 µs ± 17.6 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [11]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(1, dst=res, sycl_queue=res.sycl_queue)[0].wait()
263 µs ± 39.9 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [12]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait()
239 µs ± 26.4 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [13]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait()
224 µs ± 18.1 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)
```
---
 dpctl/tensor/libtensor/source/full_ctor.cpp | 62 +++++++++++++++++++--
 1 file changed, 58 insertions(+), 4 deletions(-)

diff --git a/dpctl/tensor/libtensor/source/full_ctor.cpp b/dpctl/tensor/libtensor/source/full_ctor.cpp
index 41b3093652..fe5af748eb 100644
--- a/dpctl/tensor/libtensor/source/full_ctor.cpp
+++ b/dpctl/tensor/libtensor/source/full_ctor.cpp
@@ -80,10 +80,65 @@ sycl::event full_contig_impl(sycl::queue &exec_q,
 {
     dstTy fill_v = py::cast<dstTy>(py_value);
 
-    using dpctl::tensor::kernels::constructors::full_contig_impl;
+    sycl::event fill_ev;
 
-    sycl::event fill_ev =
-        full_contig_impl(exec_q, nelems, fill_v, dst_p, depends);
+    if constexpr (sizeof(dstTy) == sizeof(char)) {
+        const auto memset_val = sycl::bit_cast<unsigned char>(fill_v);
+        fill_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            cgh.memset(reinterpret_cast<void *>(dst_p), memset_val,
+                       nelems * sizeof(dstTy));
+        });
+    }
+    else {
+        bool is_zero = false;
+        if constexpr (sizeof(dstTy) == 1) {
+            is_zero = (std::uint8_t{0} == sycl::bit_cast<std::uint8_t>(fill_v));
+        }
+        else if constexpr (sizeof(dstTy) == 2) {
+            is_zero =
+                (std::uint16_t{0} == sycl::bit_cast<std::uint16_t>(fill_v));
+        }
+        else if constexpr (sizeof(dstTy) == 4) {
+            is_zero =
+                (std::uint32_t{0} == sycl::bit_cast<std::uint32_t>(fill_v));
+        }
+        else if constexpr (sizeof(dstTy) == 8) {
+            is_zero =
+                (std::uint64_t{0} == sycl::bit_cast<std::uint64_t>(fill_v));
+        }
+        else if constexpr (sizeof(dstTy) == 16) {
+            struct UInt128
+            {
+
+                constexpr UInt128() : v1{}, v2{} {}
+                UInt128(const UInt128 &) = default;
+
+                bool all_bits_zero() const { return v1 == 0 && v2 == 0; }
+
+                std::uint64_t v1;
+                std::uint64_t v2;
+            };
+            is_zero = sycl::bit_cast<UInt128>(fill_v).all_bits_zero();
+        }
+
+        if (is_zero) {
+            constexpr int memset_val = 0;
+            fill_ev = exec_q.submit([&](sycl::handler &cgh) {
+                cgh.depends_on(depends);
+
+                cgh.memset(reinterpret_cast<void *>(dst_p), memset_val,
+                           nelems * sizeof(dstTy));
+            });
+        }
+        else {
+            using dpctl::tensor::kernels::constructors::full_contig_impl;
+
+            fill_ev =
+                full_contig_impl(exec_q, nelems, fill_v, dst_p, depends);
+        }
+    }
 
     return fill_ev;
 }
@@ -126,7 +181,6 @@ usm_ndarray_full(const py::object &py_value,
     int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
 
     char *dst_data = dst.get_data();
-    sycl::event full_event;
 
     if (dst_nelems == 1 || dst.is_c_contiguous() || dst.is_f_contiguous()) {
         auto fn = full_contig_dispatch_vector[dst_typeid];