Skip to content

Commit

Permalink
Add is_defined static boolean to output type tables and uses them i…
Browse files Browse the repository at this point in the history
…n type dispatching
  • Loading branch information
ndgrigorian committed Aug 29, 2024
1 parent 0aa5321 commit 2ffedad
Show file tree
Hide file tree
Showing 68 changed files with 310 additions and 502 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ template <typename T> struct AbsOutputType
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -139,9 +141,7 @@ template <typename fnT, typename T> struct AbsContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AbsOutputType<T>::value_type,
void>)
{
if constexpr (!AbsOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -190,9 +190,7 @@ template <typename fnT, typename T> struct AbsStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AbsOutputType<T>::value_type,
void>)
{
if constexpr (!AbsOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ template <typename T> struct AcosOutputType
td_ns::TypeMapResultEntry<T, std::complex<float>>,
td_ns::TypeMapResultEntry<T, std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -173,9 +175,7 @@ template <typename fnT, typename T> struct AcosContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AcosOutputType<T>::value_type,
void>)
{
if constexpr (!AcosOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -221,9 +221,7 @@ template <typename fnT, typename T> struct AcosStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AcosOutputType<T>::value_type,
void>)
{
if constexpr (!AcosOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ template <typename T> struct AcoshOutputType
td_ns::TypeMapResultEntry<T, std::complex<float>>,
td_ns::TypeMapResultEntry<T, std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -200,9 +202,7 @@ template <typename fnT, typename T> struct AcoshContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AcoshOutputType<T>::value_type,
void>)
{
if constexpr (!AcoshOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -248,9 +248,7 @@ template <typename fnT, typename T> struct AcoshStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AcoshOutputType<T>::value_type,
void>)
{
if constexpr (!AcoshOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ template <typename T1, typename T2> struct AddOutputType
std::complex<double>,
std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename argT1,
Expand Down Expand Up @@ -222,9 +224,7 @@ template <typename fnT, typename T1, typename T2> struct AddContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!AddOutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -272,9 +272,7 @@ template <typename fnT, typename T1, typename T2> struct AddStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!AddOutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -323,12 +321,12 @@ struct AddContigMatrixContigRowBroadcastFactory
{
fnT get()
{
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (std::is_same_v<resT, void>) {
if constexpr (!AddOutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
else {
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
dpctl::tensor::type_utils::is_complex<T2>::value ||
dpctl::tensor::type_utils::is_complex<resT>::value)
Expand Down Expand Up @@ -370,12 +368,12 @@ struct AddContigRowContigMatrixBroadcastFactory
{
fnT get()
{
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (std::is_same_v<resT, void>) {
if constexpr (!AddOutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
else {
using resT = typename AddOutputType<T1, T2>::value_type;
if constexpr (dpctl::tensor::type_utils::is_complex<T1>::value ||
dpctl::tensor::type_utils::is_complex<T2>::value ||
dpctl::tensor::type_utils::is_complex<resT>::value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ template <typename T> struct AngleOutputType
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -116,9 +118,7 @@ template <typename fnT, typename T> struct AngleContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AngleOutputType<T>::value_type,
void>)
{
if constexpr (!AngleOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -164,9 +164,7 @@ template <typename fnT, typename T> struct AngleStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AngleOutputType<T>::value_type,
void>)
{
if constexpr (!AngleOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ template <typename T> struct AsinOutputType
td_ns::TypeMapResultEntry<T, std::complex<float>>,
td_ns::TypeMapResultEntry<T, std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -193,9 +195,7 @@ template <typename fnT, typename T> struct AsinContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AsinOutputType<T>::value_type,
void>)
{
if constexpr (!AsinOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -241,9 +241,7 @@ template <typename fnT, typename T> struct AsinStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AsinOutputType<T>::value_type,
void>)
{
if constexpr (!AsinOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ template <typename T> struct AsinhOutputType
td_ns::TypeMapResultEntry<T, std::complex<float>>,
td_ns::TypeMapResultEntry<T, std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -176,9 +178,7 @@ template <typename fnT, typename T> struct AsinhContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AsinhOutputType<T>::value_type,
void>)
{
if constexpr (!AsinhOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -224,9 +224,7 @@ template <typename fnT, typename T> struct AsinhStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AsinhOutputType<T>::value_type,
void>)
{
if constexpr (!AsinhOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ template <typename T> struct AtanOutputType
td_ns::TypeMapResultEntry<T, std::complex<float>>,
td_ns::TypeMapResultEntry<T, std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -183,9 +185,7 @@ template <typename fnT, typename T> struct AtanContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AtanOutputType<T>::value_type,
void>)
{
if constexpr (!AtanOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -231,9 +231,7 @@ template <typename fnT, typename T> struct AtanStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AtanOutputType<T>::value_type,
void>)
{
if constexpr (!AtanOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ template <typename T1, typename T2> struct Atan2OutputType
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename argT1,
Expand Down Expand Up @@ -129,9 +131,7 @@ template <typename fnT, typename T1, typename T2> struct Atan2ContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename Atan2OutputType<T1, T2>::value_type, void>)
{
if constexpr (!Atan2OutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand All @@ -148,7 +148,6 @@ template <typename fnT, typename T1, typename T2> struct Atan2TypeMapFactory
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
using rT = typename Atan2OutputType<T1, T2>::value_type;
;
return td_ns::GetTypeid<rT>{}.get();
}
};
Expand Down Expand Up @@ -182,9 +181,7 @@ template <typename fnT, typename T1, typename T2> struct Atan2StridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename Atan2OutputType<T1, T2>::value_type, void>)
{
if constexpr (!Atan2OutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ template <typename T> struct AtanhOutputType
td_ns::TypeMapResultEntry<T, std::complex<float>>,
td_ns::TypeMapResultEntry<T, std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
Expand All @@ -177,9 +179,7 @@ template <typename fnT, typename T> struct AtanhContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AtanhOutputType<T>::value_type,
void>)
{
if constexpr (!AtanhOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -225,9 +225,7 @@ template <typename fnT, typename T> struct AtanhStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AtanhOutputType<T>::value_type,
void>)
{
if constexpr (!AtanhOutputType<T>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ template <typename T1, typename T2> struct BitwiseAndOutputType
std::int64_t,
std::int64_t>,
td_ns::DefaultResultEntry<void>>::result_type;

static constexpr bool is_defined = !std::is_same_v<value_type, void>;
};

template <typename argT1,
Expand Down Expand Up @@ -187,10 +189,7 @@ template <typename fnT, typename T1, typename T2> struct BitwiseAndContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename BitwiseAndOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!BitwiseAndOutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down Expand Up @@ -243,10 +242,7 @@ struct BitwiseAndStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename BitwiseAndOutputType<T1, T2>::value_type,
void>)
{
if constexpr (!BitwiseAndOutputType<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
Expand Down
Loading

0 comments on commit 2ffedad

Please sign in to comment.