From a76f8c08cfe1dda03e4ee28143928fc39451a2bb Mon Sep 17 00:00:00 2001 From: serge-sans-paille Date: Sat, 31 Aug 2024 17:21:15 +0200 Subject: [PATCH] teach pythran to use blas for strided array view --- pythran/pythonic/include/numpy/dot.hpp | 50 ++++++++++++++- pythran/pythonic/numpy/dot.hpp | 89 +++++++++++++++++++++++++- pythran/tests/test_numpy_func3.py | 66 +++++++++++++++++++ 3 files changed, 202 insertions(+), 3 deletions(-) diff --git a/pythran/pythonic/include/numpy/dot.hpp b/pythran/pythonic/include/numpy/dot.hpp index 8778456ae..1d7b2a257 100644 --- a/pythran/pythonic/include/numpy/dot.hpp +++ b/pythran/pythonic/include/numpy/dot.hpp @@ -28,13 +28,19 @@ struct is_strided { template struct is_blas_array { - // FIXME: also support gexpr with stride? static constexpr bool value = pythonic::types::is_array::value && is_blas_type::type>::value && !is_strided::value; }; +template +struct is_blas_expr { + static constexpr bool value = + pythonic::types::is_array::value && + is_blas_type::type>::value; +}; + PYTHONIC_NS_BEGIN namespace numpy @@ -50,7 +56,7 @@ namespace numpy typename std::enable_if< types::is_numexpr_arg::value && types::is_numexpr_arg::value && E::value == 1 && F::value == 1 && - (!is_blas_array::value || !is_blas_array::value || + (!is_blas_expr::value || !is_blas_expr::value || !std::is_same::value), typename __combined::type>::type dot(E const &e, F const &f); @@ -91,6 +97,46 @@ namespace numpy std::complex>::type dot(E const &e, F const &f); + template + typename std::enable_if< + E::value == 1 && F::value == 1 && + std::is_same::value && + std::is_same::value && + (is_blas_expr::value && is_blas_expr::value && + !(is_blas_array::value && is_blas_array::value)), + float>::type + dot(E const &e, F const &f); + + template + typename std::enable_if< + E::value == 1 && F::value == 1 && + std::is_same::value && + std::is_same::value && + (is_blas_expr::value && is_blas_expr::value && + !(is_blas_array::value && is_blas_array::value)), + double>::type + dot(E const &e, F const &f); + + template + typename std::enable_if< + E::value == 1 && F::value == 1 && + std::is_same>::value && + std::is_same>::value && + (is_blas_expr::value && is_blas_expr::value && + !(is_blas_array::value && is_blas_array::value)), + std::complex>::type + dot(E const &e, F const &f); + + template + typename std::enable_if< + E::value == 1 && F::value == 1 && + std::is_same>::value && + std::is_same>::value && + (is_blas_expr::value && is_blas_expr::value && + !(is_blas_array::value && is_blas_array::value)), + std::complex>::type + dot(E const &e, F const &f); + /// Matrix / Vector multiplication // We transpose the matrix to reflect our C order diff --git a/pythran/pythonic/numpy/dot.hpp b/pythran/pythonic/numpy/dot.hpp index a0b7edd0c..2c1f65b3d 100644 --- a/pythran/pythonic/numpy/dot.hpp +++ b/pythran/pythonic/numpy/dot.hpp @@ -3,6 +3,7 @@ #include "pythonic/include/numpy/dot.hpp" +#include "pythonic/numpy/asarray.hpp" #include "pythonic/numpy/multiply.hpp" #include "pythonic/numpy/sum.hpp" #include "pythonic/types/ndarray.hpp" @@ -1342,12 +1343,18 @@ namespace numpy return blas_buffer_t{}(e); } + template + typename E::dtype const *blas_buffer(types::numpy_gexpr const &e) + { + return e.data(); + } + template typename std::enable_if< types::is_numexpr_arg::value && types::is_numexpr_arg::value // Arguments are array_like && E::value == 1 && F::value == 1 // It is a two vectors. - && (!is_blas_array::value || !is_blas_array::value || + && (!is_blas_expr::value || !is_blas_expr::value || !std::is_same::value), typename __combined::type>::type dot(E const &e, F const &f) @@ -1411,6 +1418,86 @@ namespace numpy return out; } + template + typename std::enable_if< + E::value == 1 && F::value == 1 && + std::is_same::value && + std::is_same::value && + (is_blas_expr::value && is_blas_expr::value && + !(is_blas_array::value && is_blas_array::value)), + float>::type + dot(E const &e, F const &f) + { + if (e.template strides<0>() >= 1 && f.template strides<0>() >= 1) { + return BLAS_MANGLE(cblas_sdot)(e.size(), blas_buffer(e), + e.template strides<0>(), blas_buffer(f), + f.template strides<0>()); + } else { + return dot(asarray(e), asarray(f)); + } + } + + template + typename std::enable_if< + E::value == 1 && F::value == 1 && + std::is_same::value && + std::is_same::value && + (is_blas_expr::value && is_blas_expr::value && + !(is_blas_array::value && is_blas_array::value)), + double>::type + dot(E const &e, F const &f) + { + if (e.template strides<0>() >= 1 && f.template strides<0>() >= 1) { + return BLAS_MANGLE(cblas_ddot)(e.size(), blas_buffer(e), + e.template strides<0>(), blas_buffer(f), + f.template strides<0>()); + } else { + return dot(asarray(e), asarray(f)); + } + } + + template + typename std::enable_if< + E::value == 1 && F::value == 1 && + std::is_same>::value && + std::is_same>::value && + (is_blas_expr::value && is_blas_expr::value && + !(is_blas_array::value && is_blas_array::value)), + std::complex>::type + dot(E const &e, F const &f) + { + if (e.template strides<0>() >= 1 && f.template strides<0>() >= 1) { + std::complex out; + BLAS_MANGLE(cblas_cdotu_sub) + (e.size(), blas_buffer(e), e.template strides<0>(), blas_buffer(f), + f.template strides<0>(), &out); + return out; + } else { + return dot(asarray(e), asarray(f)); + } + } + + template + typename std::enable_if< + E::value == 1 && F::value == 1 && + std::is_same>::value && + std::is_same>::value && + (is_blas_expr::value && is_blas_expr::value && + !(is_blas_array::value && is_blas_array::value)), + std::complex>::type + dot(E const &e, F const &f) + { + if (e.template strides<0>() >= 1 && f.template strides<0>() >= 1) { + std::complex out; + BLAS_MANGLE(cblas_zdotu_sub) + (e.size(), blas_buffer(e), e.template strides<0>(), blas_buffer(f), + f.template strides<0>(), &out); + return out; + } else { + return dot(asarray(e), asarray(f)); + } + } + /// Matrice / Vector multiplication #define MV_DEF(T, L) \ diff --git a/pythran/tests/test_numpy_func3.py b/pythran/tests/test_numpy_func3.py index 2e288525c..5588c5535 100644 --- a/pythran/tests/test_numpy_func3.py +++ b/pythran/tests/test_numpy_func3.py @@ -303,6 +303,72 @@ def np_dot24(x, y): np_dot24=[NDArray[numpy.float32,:,:,:], NDArray[numpy.float64,:,:,:,:]]) + def test_dot25(self): + ''' 1d x 1d, slice''' + self.run_test(""" + def np_dot25(x, y): + from numpy import dot + return dot(x[1:], y[:-1])""", + numpy.arange(24., dtype=numpy.float32), + numpy.arange(24., dtype=numpy.float32), + np_dot25=[NDArray[numpy.float32,:], + NDArray[numpy.float32,:]]) + + def test_dot26(self): + ''' 1d x 1d, slice''' + self.run_test(""" + def np_dot26(x, y): + from numpy import dot + return dot(x[1:], y[:-1])""", + numpy.arange(24., dtype=numpy.float64), + numpy.arange(24., dtype=numpy.float64), + np_dot26=[NDArray[numpy.float64,:], + NDArray[numpy.float64,:]]) + + def test_dot27(self): + ''' 1d x 1d, slice''' + self.run_test(""" + def np_dot27(x, y): + from numpy import dot + return dot(x[1:], y[:-1])""", + numpy.arange(24., dtype=numpy.complex64), + numpy.arange(24., dtype=numpy.complex64), + np_dot27=[NDArray[numpy.complex64,:], + NDArray[numpy.complex64,:]]) + + def test_dot28(self): + ''' 1d x 1d, slice''' + self.run_test(""" + def np_dot28(x, y): + from numpy import dot + return dot(x[1:], y[:-1])""", + numpy.arange(24., dtype=numpy.complex128), + numpy.arange(24., dtype=numpy.complex128), + np_dot28=[NDArray[numpy.complex128,:], + NDArray[numpy.complex128,:]]) + + def test_dot29(self): + ''' 1d x 1d, slice''' + self.run_test(""" + def np_dot29(x, y): + from numpy import dot + return dot(x[-1:0:-1], y[:-1])""", + numpy.arange(24., dtype=numpy.float32), + numpy.arange(24., dtype=numpy.float32), + np_dot29=[NDArray[numpy.float32,:], + NDArray[numpy.float32,:]]) + + def test_dot30(self): + ''' 1d x 1d, slice''' + self.run_test(""" + def np_dot30(x, y): + from numpy import dot + return dot(x[-1:0:-1], y[-1:0:-1])""", + numpy.arange(24., dtype=numpy.float32), + numpy.arange(24., dtype=numpy.float32), + np_dot30=[NDArray[numpy.float32,:], + NDArray[numpy.float32,:]]) + def test_vdot0(self): self.run_test(""" def np_vdot0(x, y):