Skip to content

Commit

Permalink
Make sure blas function argument actually have an associated buffer
Browse files Browse the repository at this point in the history
Otherwise call a fallback. Fix #2248
  • Loading branch information
serge-sans-paille committed Nov 24, 2024
1 parent bd488d2 commit ef5de41
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 13 deletions.
16 changes: 8 additions & 8 deletions pythran/pythonic/include/numpy/dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ struct is_strided {
template <class E>
struct is_blas_array {
static constexpr bool value =
pythonic::types::is_array<E>::value &&
pythonic::types::has_buffer<E>::value &&
is_blas_type<typename pythonic::types::dtype_of<E>::type>::value &&
!is_strided<E>::value;
};

template <class E>
struct is_blas_expr {
struct is_blas_view {
static constexpr bool value =
pythonic::types::is_array<E>::value &&
pythonic::types::has_buffer<E>::value &&
is_blas_type<typename pythonic::types::dtype_of<E>::type>::value;
};

Expand All @@ -56,7 +56,7 @@ namespace numpy
typename std::enable_if<
types::is_numexpr_arg<E>::value && types::is_numexpr_arg<F>::value &&
E::value == 1 && F::value == 1 &&
(!is_blas_expr<E>::value || !is_blas_expr<F>::value ||
(!is_blas_view<E>::value || !is_blas_view<F>::value ||
!std::is_same<typename E::dtype, typename F::dtype>::value),
typename __combined<typename E::dtype, typename F::dtype>::type>::type
dot(E const &e, F const &f);
Expand Down Expand Up @@ -102,7 +102,7 @@ namespace numpy
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, float>::value &&
std::is_same<typename F::dtype, float>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
(is_blas_view<E>::value && is_blas_view<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
float>::type
dot(E const &e, F const &f);
Expand All @@ -112,7 +112,7 @@ namespace numpy
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, double>::value &&
std::is_same<typename F::dtype, double>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
(is_blas_view<E>::value && is_blas_view<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
double>::type
dot(E const &e, F const &f);
Expand All @@ -122,7 +122,7 @@ namespace numpy
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, std::complex<float>>::value &&
std::is_same<typename F::dtype, std::complex<float>>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
(is_blas_view<E>::value && is_blas_view<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
std::complex<float>>::type
dot(E const &e, F const &f);
Expand All @@ -132,7 +132,7 @@ namespace numpy
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, std::complex<double>>::value &&
std::is_same<typename F::dtype, std::complex<double>>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
(is_blas_view<E>::value && is_blas_view<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
std::complex<double>>::type
dot(E const &e, F const &f);
Expand Down
19 changes: 19 additions & 0 deletions pythran/pythonic/include/utils/numpy_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,25 @@ namespace types
static T get(...);
using type = decltype(get<E>(nullptr));
};

template <class T>
struct has_buffer {
static constexpr bool value = false;
};

template <class T, class pS>
struct has_buffer<ndarray<T, pS>> {
static constexpr bool value = true;
};

template <class A>
struct has_buffer<numpy_iexpr<A>> : has_buffer<A>{
};

template <class A, class... S>
struct has_buffer<numpy_gexpr<A, S...>> : has_buffer<A> {
};

} // namespace types
PYTHONIC_NS_END

Expand Down
10 changes: 5 additions & 5 deletions pythran/pythonic/numpy/dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ namespace numpy
types::is_numexpr_arg<E>::value &&
types::is_numexpr_arg<F>::value // Arguments are array_like
&& E::value == 1 && F::value == 1 // It is a two vectors.
&& (!is_blas_expr<E>::value || !is_blas_expr<F>::value ||
&& (!is_blas_view<E>::value || !is_blas_view<F>::value ||
!std::is_same<typename E::dtype, typename F::dtype>::value),
typename __combined<typename E::dtype, typename F::dtype>::type>::type
dot(E const &e, F const &f)
Expand Down Expand Up @@ -1423,7 +1423,7 @@ namespace numpy
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, float>::value &&
std::is_same<typename F::dtype, float>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
(is_blas_view<E>::value && is_blas_view<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
float>::type
dot(E const &e, F const &f)
Expand All @@ -1442,7 +1442,7 @@ namespace numpy
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, double>::value &&
std::is_same<typename F::dtype, double>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
(is_blas_view<E>::value && is_blas_view<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
double>::type
dot(E const &e, F const &f)
Expand All @@ -1461,7 +1461,7 @@ namespace numpy
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, std::complex<float>>::value &&
std::is_same<typename F::dtype, std::complex<float>>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
(is_blas_view<E>::value && is_blas_view<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
std::complex<float>>::type
dot(E const &e, F const &f)
Expand All @@ -1482,7 +1482,7 @@ namespace numpy
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, std::complex<double>>::value &&
std::is_same<typename F::dtype, std::complex<double>>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
(is_blas_view<E>::value && is_blas_view<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
std::complex<double>>::type
dot(E const &e, F const &f)
Expand Down
12 changes: 12 additions & 0 deletions pythran/tests/test_numpy_func3.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,18 @@ def np_dot30(x, y):
np_dot30=[NDArray[numpy.float32,:],
NDArray[numpy.float32,:]])

def test_dot31(self):
''' 1d x 1d, expr'''
self.run_test("""
def np_dot31(x):
import numpy as np
func = np.exp(-2j * np.pi * np.arange(len(x)))
normFunc = np.sqrt(np.real(np.dot(np.conjugate(func), func)))
return normFunc
""",
numpy.arange(24., dtype=numpy.float32),
np_dot31=[NDArray[numpy.float32,:]])

def test_vdot0(self):
self.run_test("""
def np_vdot0(x, y):
Expand Down

0 comments on commit ef5de41

Please sign in to comment.