Skip to content

Commit

Permalink
teach pythran to use blas for strided array view
Browse files Browse the repository at this point in the history
  • Loading branch information
serge-sans-paille committed Aug 31, 2024
1 parent a7f4026 commit a76f8c0
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 3 deletions.
50 changes: 48 additions & 2 deletions pythran/pythonic/include/numpy/dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,19 @@ struct is_strided {

template <class E>
struct is_blas_array {
// FIXME: also support gexpr with stride?
static constexpr bool value =
pythonic::types::is_array<E>::value &&
is_blas_type<typename pythonic::types::dtype_of<E>::type>::value &&
!is_strided<E>::value;
};

template <class E>
struct is_blas_expr {
static constexpr bool value =
pythonic::types::is_array<E>::value &&
is_blas_type<typename pythonic::types::dtype_of<E>::type>::value;
};

PYTHONIC_NS_BEGIN

namespace numpy
Expand All @@ -50,7 +56,7 @@ namespace numpy
typename std::enable_if<
types::is_numexpr_arg<E>::value && types::is_numexpr_arg<F>::value &&
E::value == 1 && F::value == 1 &&
(!is_blas_array<E>::value || !is_blas_array<F>::value ||
(!is_blas_expr<E>::value || !is_blas_expr<F>::value ||
!std::is_same<typename E::dtype, typename F::dtype>::value),
typename __combined<typename E::dtype, typename F::dtype>::type>::type
dot(E const &e, F const &f);
Expand Down Expand Up @@ -91,6 +97,46 @@ namespace numpy
std::complex<double>>::type
dot(E const &e, F const &f);

template <class E, class F>
typename std::enable_if<
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, float>::value &&
std::is_same<typename F::dtype, float>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
float>::type
dot(E const &e, F const &f);

template <class E, class F>
typename std::enable_if<
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, double>::value &&
std::is_same<typename F::dtype, double>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
double>::type
dot(E const &e, F const &f);

template <class E, class F>
typename std::enable_if<
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, std::complex<float>>::value &&
std::is_same<typename F::dtype, std::complex<float>>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
std::complex<float>>::type
dot(E const &e, F const &f);

template <class E, class F>
typename std::enable_if<
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, std::complex<double>>::value &&
std::is_same<typename F::dtype, std::complex<double>>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
std::complex<double>>::type
dot(E const &e, F const &f);

/// Matrix / Vector multiplication

// We transpose the matrix to reflect our C order
Expand Down
89 changes: 88 additions & 1 deletion pythran/pythonic/numpy/dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "pythonic/include/numpy/dot.hpp"

#include "pythonic/numpy/asarray.hpp"
#include "pythonic/numpy/multiply.hpp"
#include "pythonic/numpy/sum.hpp"
#include "pythonic/types/ndarray.hpp"
Expand Down Expand Up @@ -1342,12 +1343,18 @@ namespace numpy
return blas_buffer_t<E>{}(e);
}

template <class E, class... S>
typename E::dtype const *blas_buffer(types::numpy_gexpr<E, S...> const &e)
{
return e.data();
}

template <class E, class F>
typename std::enable_if<
types::is_numexpr_arg<E>::value &&
types::is_numexpr_arg<F>::value // Arguments are array_like
&& E::value == 1 && F::value == 1 // It is a two vectors.
&& (!is_blas_array<E>::value || !is_blas_array<F>::value ||
&& (!is_blas_expr<E>::value || !is_blas_expr<F>::value ||
!std::is_same<typename E::dtype, typename F::dtype>::value),
typename __combined<typename E::dtype, typename F::dtype>::type>::type
dot(E const &e, F const &f)
Expand Down Expand Up @@ -1411,6 +1418,86 @@ namespace numpy
return out;
}

template <class E, class F>
typename std::enable_if<
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, float>::value &&
std::is_same<typename F::dtype, float>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
float>::type
dot(E const &e, F const &f)
{
if (e.template strides<0>() >= 1 && f.template strides<0>() >= 1) {
return BLAS_MANGLE(cblas_sdot)(e.size(), blas_buffer(e),
e.template strides<0>(), blas_buffer(f),
f.template strides<0>());
} else {
return dot(asarray(e), asarray(f));
}
}

template <class E, class F>
typename std::enable_if<
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, double>::value &&
std::is_same<typename F::dtype, double>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
double>::type
dot(E const &e, F const &f)
{
if (e.template strides<0>() >= 1 && f.template strides<0>() >= 1) {
return BLAS_MANGLE(cblas_ddot)(e.size(), blas_buffer(e),
e.template strides<0>(), blas_buffer(f),
f.template strides<0>());
} else {
return dot(asarray(e), asarray(f));
}
}

template <class E, class F>
typename std::enable_if<
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, std::complex<float>>::value &&
std::is_same<typename F::dtype, std::complex<float>>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
std::complex<float>>::type
dot(E const &e, F const &f)
{
if (e.template strides<0>() >= 1 && f.template strides<0>() >= 1) {
std::complex<float> out;
BLAS_MANGLE(cblas_cdotu_sub)
(e.size(), blas_buffer(e), e.template strides<0>(), blas_buffer(f),
f.template strides<0>(), &out);
return out;
} else {
return dot(asarray(e), asarray(f));
}
}

template <class E, class F>
typename std::enable_if<
E::value == 1 && F::value == 1 &&
std::is_same<typename E::dtype, std::complex<double>>::value &&
std::is_same<typename F::dtype, std::complex<double>>::value &&
(is_blas_expr<E>::value && is_blas_expr<F>::value &&
!(is_blas_array<E>::value && is_blas_array<F>::value)),
std::complex<double>>::type
dot(E const &e, F const &f)
{
if (e.template strides<0>() >= 1 && f.template strides<0>() >= 1) {
std::complex<double> out;
BLAS_MANGLE(cblas_zdotu_sub)
(e.size(), blas_buffer(e), e.template strides<0>(), blas_buffer(f),
f.template strides<0>(), &out);
return out;
} else {
return dot(asarray(e), asarray(f));
}
}

/// Matrice / Vector multiplication

#define MV_DEF(T, L) \
Expand Down
66 changes: 66 additions & 0 deletions pythran/tests/test_numpy_func3.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,72 @@ def np_dot24(x, y):
np_dot24=[NDArray[numpy.float32,:,:,:],
NDArray[numpy.float64,:,:,:,:]])

def test_dot25(self):
''' 1d x 1d, slice'''
self.run_test("""
def np_dot25(x, y):
from numpy import dot
return dot(x[1:], y[:-1])""",
numpy.arange(24., dtype=numpy.float32),
numpy.arange(24., dtype=numpy.float32),
np_dot25=[NDArray[numpy.float32,:],
NDArray[numpy.float32,:]])

def test_dot26(self):
''' 1d x 1d, slice'''
self.run_test("""
def np_dot26(x, y):
from numpy import dot
return dot(x[1:], y[:-1])""",
numpy.arange(24., dtype=numpy.float64),
numpy.arange(24., dtype=numpy.float64),
np_dot26=[NDArray[numpy.float64,:],
NDArray[numpy.float64,:]])

def test_dot27(self):
''' 1d x 1d, slice'''
self.run_test("""
def np_dot27(x, y):
from numpy import dot
return dot(x[1:], y[:-1])""",
numpy.arange(24., dtype=numpy.complex64),
numpy.arange(24., dtype=numpy.complex64),
np_dot27=[NDArray[numpy.complex64,:],
NDArray[numpy.complex64,:]])

def test_dot28(self):
''' 1d x 1d, slice'''
self.run_test("""
def np_dot28(x, y):
from numpy import dot
return dot(x[1:], y[:-1])""",
numpy.arange(24., dtype=numpy.complex128),
numpy.arange(24., dtype=numpy.complex128),
np_dot28=[NDArray[numpy.complex128,:],
NDArray[numpy.complex128,:]])

def test_dot29(self):
''' 1d x 1d, slice'''
self.run_test("""
def np_dot29(x, y):
from numpy import dot
return dot(x[-1:0:-1], y[:-1])""",
numpy.arange(24., dtype=numpy.float32),
numpy.arange(24., dtype=numpy.float32),
np_dot29=[NDArray[numpy.float32,:],
NDArray[numpy.float32,:]])

def test_dot30(self):
''' 1d x 1d, slice'''
self.run_test("""
def np_dot30(x, y):
from numpy import dot
return dot(x[-1:0:-1], y[-1:0:-1])""",
numpy.arange(24., dtype=numpy.float32),
numpy.arange(24., dtype=numpy.float32),
np_dot30=[NDArray[numpy.float32,:],
NDArray[numpy.float32,:]])

def test_vdot0(self):
self.run_test("""
def np_vdot0(x, y):
Expand Down

0 comments on commit a76f8c0

Please sign in to comment.