Skip to content

Commit

Permalink
Implement dedicated strided full kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Aug 23, 2024
1 parent cfba263 commit 4346510
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 7 deletions.
69 changes: 69 additions & 0 deletions dpctl/tensor/libtensor/include/kernels/constructors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ namespace constructors

template <typename Ty> class linear_sequence_step_kernel;
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;
template <typename Ty> class full_strided_kernel;
template <typename Ty> class eye_kernel;

using namespace dpctl::tensor::offset_utils;
Expand Down Expand Up @@ -252,6 +253,74 @@ sycl::event full_contig_impl(sycl::queue &q,
return fill_ev;
}

template <typename Ty, typename IndexerT> class FullStridedFunctor
{
private:
Ty *p = nullptr;
const Ty fill_v;
const IndexerT indexer;

public:
FullStridedFunctor(Ty *p_, const Ty &fill_v_, const IndexerT &indexer_)
: p(p_), fill_v(fill_v_), indexer(indexer_)
{
}

void operator()(sycl::id<1> id) const
{
auto offset = indexer(id.get(0));
p[offset] = fill_v;
}
};

/*!
 * @brief Function to submit kernel to fill given strided memory allocation
 * with specified value.
 *
 * @param exec_q Sycl queue to which kernel is submitted for execution.
 * @param nd Array dimensionality
 * @param nelems Length of the sequence
 * @param shape_strides Kernel accessible USM pointer to packed shape and
 * strides of array.
 * @param offset Displacement of first element of dst relative to dst_p in
 * elements
 * @param fill_v Value to fill the array with
 * @param dst_p Kernel accessible USM pointer to the start of array to be
 * populated.
 * @param depends List of events to wait for before starting computations, if
 * any.
 *
 * @return Event to wait on to ensure that computation completes.
 * @defgroup CtorKernels
 */
template <typename dstTy>
sycl::event full_strided_impl(sycl::queue &exec_q,
                              int nd,
                              size_t nelems,
                              const ssize_t *shape_strides,
                              const ssize_t offset,
                              dstTy fill_v,
                              char *dst_p,
                              const std::vector<sycl::event> &depends)
{
    dpctl::tensor::type_utils::validate_type_for_device<dstTy>(exec_q);

    dstTy *dst_tp = reinterpret_cast<dstTy *>(dst_p);

    using dpctl::tensor::offset_utils::StridedIndexer;
    // Maps a flat index in [0, nelems) onto the (possibly non-contiguous)
    // memory layout described by nd, offset and packed shape_strides.
    const StridedIndexer strided_indexer(nd, offset, shape_strides);

    sycl::event fill_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.parallel_for<full_strided_kernel<dstTy>>(
            sycl::range<1>{nelems},
            FullStridedFunctor<dstTy, decltype(strided_indexer)>(
                dst_tp, fill_v, strided_indexer));
    });

    return fill_ev;
}

/* ================ Eye ================== */

typedef sycl::event (*eye_fn_ptr_t)(sycl::queue &,
Expand Down
138 changes: 131 additions & 7 deletions dpctl/tensor/libtensor/source/full_ctor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "utils/type_utils.hpp"

#include "full_ctor.hpp"
#include "simplify_iteration_space.hpp"

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;
Expand All @@ -61,9 +62,9 @@ typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue &,
*
* @param exec_q Sycl queue to which kernel is submitted for execution.
* @param nelems Length of the sequence
 * @param py_value Python object representing the value to fill the array with.
 * Must be convertible to `dstTy`.
 * @param dst_p Kernel accessible USM pointer to the start of array to be
* populated.
* @param depends List of events to wait for before starting computations, if
* any.
Expand Down Expand Up @@ -152,7 +153,66 @@ template <typename fnT, typename Ty> struct FullContigFactory
}
};

typedef sycl::event (*full_strided_fn_ptr_t)(sycl::queue &,
int,
size_t,
py::ssize_t *,
py::ssize_t,
const py::object &,
char *,
const std::vector<sycl::event> &);

/*!
 * @brief Function to submit kernel to fill given strided memory allocation
 * with specified value.
 *
 * @param exec_q Sycl queue to which kernel is submitted for execution.
 * @param nd Array dimensionality
 * @param nelems Length of the sequence
 * @param shape_strides Kernel accessible USM pointer to packed shape and
 * strides of array.
 * @param dst_offset Displacement of first element of dst relative to dst_p in
 * elements
 * @param py_value Python object representing the value to fill the array with.
 * Must be convertible to `dstTy`.
 * @param dst_p Kernel accessible USM pointer to the start of array to be
 * populated.
 * @param depends List of events to wait for before starting computations, if
 * any.
 *
 * @return Event to wait on to ensure that computation completes.
 * @defgroup CtorKernels
 */
template <typename dstTy>
sycl::event full_strided_impl(sycl::queue &exec_q,
                              int nd,
                              size_t nelems,
                              py::ssize_t *shape_strides,
                              py::ssize_t dst_offset,
                              const py::object &py_value,
                              char *dst_p,
                              const std::vector<sycl::event> &depends)
{
    // Convert the Python object to the destination element type once on the
    // host, then delegate to the typed kernel implementation.
    const dstTy fill_v = py::cast<dstTy>(py_value);

    using dpctl::tensor::kernels::constructors::full_strided_impl;
    return full_strided_impl<dstTy>(exec_q, nd, nelems, shape_strides,
                                    dst_offset, fill_v, dst_p, depends);
}

/*! @brief Factory yielding the strided fill implementation for element
 *  type Ty; used to populate the type-dispatch vector.
 */
template <typename fnT, typename Ty> struct FullStridedFactory
{
    fnT get()
    {
        return full_strided_impl<Ty>;
    }
};

static full_contig_fn_ptr_t full_contig_dispatch_vector[td_ns::num_types];
static full_strided_fn_ptr_t full_strided_dispatch_vector[td_ns::num_types];

std::pair<sycl::event, sycl::event>
usm_ndarray_full(const py::object &py_value,
Expand Down Expand Up @@ -194,8 +254,70 @@ usm_ndarray_full(const py::object &py_value,
full_contig_event);
}
else {
throw std::runtime_error(
"Only population of contiguous usm_ndarray objects is supported.");
using dpctl::tensor::py_internal::simplify_iteration_space_1;

int nd = dst.get_ndim();
const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
auto const &dst_strides = dst.get_strides_vector();

using shT = std::vector<py::ssize_t>;
shT simplified_dst_shape;
shT simplified_dst_strides;
py::ssize_t dst_offset(0);

simplify_iteration_space_1(nd, dst_shape_ptr, dst_strides,
// output
simplified_dst_shape, simplified_dst_strides,
dst_offset);

// it's possible that this branch will never be taken
// need to look carefully at `simplify_iteration_space_1`
// to find cases
if (nd == 1 && simplified_dst_strides[0] == 1) {
auto fn = full_contig_dispatch_vector[dst_typeid];

const sycl::event &full_contig_event =
fn(exec_q, static_cast<size_t>(dst_nelems), py_value,
dst_data + dst_offset, depends);

return std::make_pair(
keep_args_alive(exec_q, {dst}, {full_contig_event}),
full_contig_event);
}

auto fn = full_strided_dispatch_vector[dst_typeid];

std::vector<sycl::event> host_task_events;
host_task_events.reserve(2);
using dpctl::tensor::offset_utils::device_allocate_and_pack;
const auto &ptr_size_event_tuple =
device_allocate_and_pack<py::ssize_t>(exec_q, host_task_events,
simplified_dst_shape,
simplified_dst_strides);
py::ssize_t *shape_strides = std::get<0>(ptr_size_event_tuple);
if (shape_strides == nullptr) {
throw std::runtime_error("Unable to allocate device memory");
}
const sycl::event &copy_shape_ev = std::get<2>(ptr_size_event_tuple);

const sycl::event &full_strided_ev =
fn(exec_q, nd, dst_nelems, shape_strides, dst_offset, py_value,
dst_data, {copy_shape_ev});

// free shape_strides
const auto &ctx = exec_q.get_context();
const auto &temporaries_cleanup_ev =
exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(full_strided_ev);
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task([ctx, shape_strides]() {
sycl_free_noexcept(shape_strides, ctx);
});
});
host_task_events.push_back(temporaries_cleanup_ev);

return std::make_pair(keep_args_alive(exec_q, {dst}, host_task_events),
full_strided_ev);
}
}

Expand All @@ -204,10 +326,12 @@ void init_full_ctor_dispatch_vectors(void)
using namespace td_ns;

DispatchVectorBuilder<full_contig_fn_ptr_t, FullContigFactory, num_types>
dvb;
dvb.populate_dispatch_vector(full_contig_dispatch_vector);
dvb1;
dvb1.populate_dispatch_vector(full_contig_dispatch_vector);

return;
DispatchVectorBuilder<full_strided_fn_ptr_t, FullStridedFactory, num_types>
dvb2;
dvb2.populate_dispatch_vector(full_strided_dispatch_vector);
}

} // namespace py_internal
Expand Down

0 comments on commit 4346510

Please sign in to comment.