Skip to content

Commit

Permalink
Merge pull request #1816 from IntelPython/dtype-matrices-for-in-place…
Browse files Browse the repository at this point in the history
…-element-wise-ops

Introduce dedicated type support matrices for in-place element-wise operations
  • Loading branch information
ndgrigorian authored Aug 30, 2024
2 parents e2c7425 + 2ffedad commit 46dc288
Show file tree
Hide file tree
Showing 92 changed files with 1,082 additions and 935 deletions.
8 changes: 3 additions & 5 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,9 @@ def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):


def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
# if the kind of result is different from
# the kind of input, use the default data
# we use default dtype for the resulting kind.
# This guarantees alignment of reciprocal and
# divide output types.
# if the kind of result is different from the kind of input, we use the
# default floating-point dtype for the resulting kind. This guarantees
# alignment of reciprocal and divide output types.
if buf_dt.kind != arg_dtype.kind:
default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
if res_dt == default_dt:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ using AbsContigFunctor =

template <typename T> struct AbsOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
using value_type = typename std::disjunction<
td_ns::TypeMapResultEntry<T, bool>,
td_ns::TypeMapResultEntry<T, std::uint8_t>,
td_ns::TypeMapResultEntry<T, std::uint16_t>,
Expand All @@ -119,6 +118,8 @@ template <typename T> struct AbsOutputType
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -140,9 +141,7 @@ template <typename fnT, typename T> struct AbsContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AbsOutputType<T>::value_type,
void>)
{
if constexpr (!AbsOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -191,9 +190,7 @@ template <typename fnT, typename T> struct AbsStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AbsOutputType<T>::value_type,
void>)
{
if constexpr (!AbsOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,15 @@ using AcosStridedFunctor = elementwise_common::

template <typename T> struct AcosOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
using value_type = typename std::disjunction<
td_ns::TypeMapResultEntry<T, sycl::half>,
td_ns::TypeMapResultEntry<T, float>,
td_ns::TypeMapResultEntry<T, double>,
td_ns::TypeMapResultEntry<T, std::complex<float>>,
td_ns::TypeMapResultEntry<T, std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -174,9 +175,7 @@ template <typename fnT, typename T> struct AcosContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AcosOutputType<T>::value_type,
void>)
{
if constexpr (!AcosOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -222,9 +221,7 @@ template <typename fnT, typename T> struct AcosStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AcosOutputType<T>::value_type,
void>)
{
if constexpr (!AcosOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,15 @@ using AcoshStridedFunctor = elementwise_common::

template <typename T> struct AcoshOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
using value_type = typename std::disjunction<
td_ns::TypeMapResultEntry<T, sycl::half>,
td_ns::TypeMapResultEntry<T, float>,
td_ns::TypeMapResultEntry<T, double>,
td_ns::TypeMapResultEntry<T, std::complex<float>>,
td_ns::TypeMapResultEntry<T, std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -201,9 +202,7 @@ template <typename fnT, typename T> struct AcoshContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AcoshOutputType<T>::value_type,
void>)
{
if constexpr (!AcoshOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -249,9 +248,7 @@ template <typename fnT, typename T> struct AcoshStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AcoshOutputType<T>::value_type,
void>)
{
if constexpr (!AcoshOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ using AddStridedFunctor =

template <typename T1, typename T2> struct AddOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
using value_type = typename std::disjunction<
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
td_ns::BinaryTypeMapResultEntry<T1,
std::uint8_t,
Expand Down Expand Up @@ -193,6 +192,8 @@ template <typename T1, typename T2> struct AddOutputType
std::complex<double>,
std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename argT1,
Expand Down Expand Up @@ -223,9 +224,7 @@ template <typename fnT, typename T1, typename T2> struct AddContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!AddOutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -273,9 +272,7 @@ template <typename fnT, typename T1, typename T2> struct AddStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!AddOutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -324,12 +321,12 @@ struct AddContigMatrixContigRowBroadcastFactory
{
fnT get()
{
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (std::is_same_v<resT, void>) {
if constexpr (!AddOutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
else {
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
dpctl::tensor::type_utils::is_complex<T2>::value ||
dpctl::tensor::type_utils::is_complex<resT>::value)
Expand Down Expand Up @@ -371,12 +368,12 @@ struct AddContigRowContigMatrixBroadcastFactory
{
fnT get()
{
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (std::is_same_v<resT, void>) {
if constexpr (!AddOutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
else {
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
dpctl::tensor::type_utils::is_complex<T2>::value ||
dpctl::tensor::type_utils::is_complex<resT>::value)
Expand Down Expand Up @@ -438,6 +435,50 @@ template <typename argT,
unsigned int n_vecs>
class add_inplace_contig_kernel;

/* @brief Types supported by in-place add */
template <typename argTy, typename resTy> struct AddInplaceTypePairSupport
{
/* value if true a kernel for <argTy, resTy> must be instantiated */
static constexpr bool is_defined = std::disjunction<
td_ns::TypePairDefinedEntry<argTy, bool, resTy, bool>,
td_ns::TypePairDefinedEntry<argTy, std::int8_t, resTy, std::int8_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, resTy, std::uint8_t>,
td_ns::TypePairDefinedEntry<argTy, std::int16_t, resTy, std::int16_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, resTy, std::uint16_t>,
td_ns::TypePairDefinedEntry<argTy, std::int32_t, resTy, std::int32_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, resTy, std::uint32_t>,
td_ns::TypePairDefinedEntry<argTy, std::int64_t, resTy, std::int64_t>,
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, resTy, std::uint64_t>,
td_ns::TypePairDefinedEntry<argTy, sycl::half, resTy, sycl::half>,
td_ns::TypePairDefinedEntry<argTy, float, resTy, float>,
td_ns::TypePairDefinedEntry<argTy, double, resTy, double>,
td_ns::TypePairDefinedEntry<argTy,
std::complex<float>,
resTy,
std::complex<float>>,
td_ns::TypePairDefinedEntry<argTy,
std::complex<double>,
resTy,
std::complex<double>>,
// fall-through
td_ns::NotDefinedEntry>::is_defined;
};

template <typename fnT, typename argT, typename resT>
struct AddInplaceTypeMapFactory
{
/*! @brief get typeid for output type of x += y */
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
if constexpr (AddInplaceTypePairSupport<argT, resT>::is_defined) {
return td_ns::GetTypeid<resT>{}.get();
}
else {
return td_ns::GetTypeid<void>{}.get();
}
}
};

template <typename argTy, typename resTy>
sycl::event
add_inplace_contig_impl(sycl::queue &exec_q,
Expand All @@ -457,9 +498,7 @@ template <typename fnT, typename T1, typename T2> struct AddInplaceContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -497,9 +536,7 @@ struct AddInplaceStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -544,8 +581,7 @@ struct AddInplaceRowMatrixBroadcastFactory
{
fnT get()
{
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (!std::is_same_v<resT, T2>) {
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,12 @@ using AngleStridedFunctor = elementwise_common::

template <typename T> struct AngleOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
using value_type = typename std::disjunction<
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -117,9 +118,7 @@ template <typename fnT, typename T> struct AngleContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AngleOutputType<T>::value_type,
void>)
{
if constexpr (!AngleOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -165,9 +164,7 @@ template <typename fnT, typename T> struct AngleStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AngleOutputType<T>::value_type,
void>)
{
if constexpr (!AngleOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,15 @@ using AsinStridedFunctor = elementwise_common::

template <typename T> struct AsinOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
using value_type = typename std::disjunction<
td_ns::TypeMapResultEntry<T, sycl::half>,
td_ns::TypeMapResultEntry<T, float>,
td_ns::TypeMapResultEntry<T, double>,
td_ns::TypeMapResultEntry<T, std::complex<float>>,
td_ns::TypeMapResultEntry<T, std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -194,9 +195,7 @@ template <typename fnT, typename T> struct AsinContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AsinOutputType<T>::value_type,
void>)
{
if constexpr (!AsinOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -242,9 +241,7 @@ template <typename fnT, typename T> struct AsinStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AsinOutputType<T>::value_type,
void>)
{
if constexpr (!AsinOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Loading

0 comments on commit 46dc288

Please sign in to comment.