Skip to content

Commit

Permalink
Added special case for _full_usm_ndarray
Browse files Browse the repository at this point in the history
Bitwise zero values, and 1-byte wide types now use memset, instead
of using fill.

```
In [1]: import dpctl.tensor as dpt, dpctl.tensor._tensor_impl as ti

In [2]: res = dpt.empty(10**6, dtype="i8")

In [3]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait()
243 µs ± 22.6 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [4]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait()
229 µs ± 14 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [5]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait()
227 µs ± 23 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [6]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait()
233 µs ± 25.9 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [7]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait()
301 µs ± 54.1 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [8]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait()
236 µs ± 17.2 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [9]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait()
240 µs ± 35.2 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [10]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(1, dst=res, sycl_queue=res.sycl_queue)[0].wait()
243 µs ± 17.6 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [11]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(1, dst=res, sycl_queue=res.sycl_queue)[0].wait()
263 µs ± 39.9 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [12]: %timeit -n 2000 -r 11 ti._full_usm_ndarray(0, dst=res, sycl_queue=res.sycl_queue)[0].wait()
239 µs ± 26.4 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)

In [13]: %timeit -n 2000 -r 11 ti._zeros_usm_ndarray(dst=res, sycl_queue=res.sycl_queue)[0].wait()
224 µs ± 18.1 µs per loop (mean ± std. dev. of 11 runs, 2,000 loops each)
```
  • Loading branch information
oleksandr-pavlyk committed Aug 21, 2024
1 parent bec95f9 commit da954f9
Showing 1 changed file with 58 additions and 4 deletions.
62 changes: 58 additions & 4 deletions dpctl/tensor/libtensor/source/full_ctor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,65 @@ sycl::event full_contig_impl(sycl::queue &exec_q,
{
dstTy fill_v = py::cast<dstTy>(py_value);

using dpctl::tensor::kernels::constructors::full_contig_impl;
sycl::event fill_ev;

sycl::event fill_ev =
full_contig_impl<dstTy>(exec_q, nelems, fill_v, dst_p, depends);
if constexpr (sizeof(dstTy) == sizeof(char)) {
const auto memset_val = sycl::bit_cast<unsigned char>(fill_v);
fill_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

cgh.memset(reinterpret_cast<void *>(dst_p), memset_val,
nelems * sizeof(dstTy));
});
}
else {
bool is_zero = false;
if constexpr (sizeof(dstTy) == 1) {
is_zero = (std::uint8_t{0} == sycl::bit_cast<std::uint8_t>(fill_v));
}
else if constexpr (sizeof(dstTy) == 2) {
is_zero =
(std::uint16_t{0} == sycl::bit_cast<std::uint16_t>(fill_v));
}
else if constexpr (sizeof(dstTy) == 4) {
is_zero =
(std::uint32_t{0} == sycl::bit_cast<std::uint32_t>(fill_v));
}
else if constexpr (sizeof(dstTy) == 8) {
is_zero =
(std::uint64_t{0} == sycl::bit_cast<std::uint64_t>(fill_v));
}
else if constexpr (sizeof(dstTy) == 16) {
struct UInt128
{

constexpr UInt128() : v1{}, v2{} {}
UInt128(const UInt128 &) = default;

operator bool() const { return bool(v1) && bool(v2); }

std::uint64_t v1;
std::uint64_t v2;
};
is_zero = static_cast<bool>(sycl::bit_cast<UInt128>(fill_v));
}

if (is_zero) {
constexpr int memset_val = 0;
fill_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

cgh.memset(reinterpret_cast<void *>(dst_p), memset_val,
nelems * sizeof(dstTy));
});
}
else {
using dpctl::tensor::kernels::constructors::full_contig_impl;

fill_ev =
full_contig_impl<dstTy>(exec_q, nelems, fill_v, dst_p, depends);
}
}

return fill_ev;
}
Expand Down Expand Up @@ -126,7 +181,6 @@ usm_ndarray_full(const py::object &py_value,
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

char *dst_data = dst.get_data();
sycl::event full_event;

if (dst_nelems == 1 || dst.is_c_contiguous() || dst.is_f_contiguous()) {
auto fn = full_contig_dispatch_vector[dst_typeid];
Expand Down

0 comments on commit da954f9

Please sign in to comment.