diff --git a/src/vampyr/tests/test_projector1d.py b/src/vampyr/tests/test_projector1d.py new file mode 100644 index 00000000..d11ba287 --- /dev/null +++ b/src/vampyr/tests/test_projector1d.py @@ -0,0 +1,17 @@ +import pytest + +from vampyr import vampyr1d as vp + +def test_ScalingProjector(): + def f(x): + return x + + mra = vp.MultiResolutionAnalysis(box=[0, 1], order=7) + P_scaling = vp.ScalingProjector(mra, 2) + P_wavelet = vp.WaveletProjector(mra, 2) + + with pytest.raises(Exception): + P_scaling(f) + + with pytest.raises(Exception): + P_wavelet(f) \ No newline at end of file diff --git a/src/vampyr/tests/test_projector3d.py b/src/vampyr/tests/test_projector3d.py new file mode 100644 index 00000000..e7c949fc --- /dev/null +++ b/src/vampyr/tests/test_projector3d.py @@ -0,0 +1,17 @@ +import pytest + +from vampyr import vampyr3d as vp + +def test_ScalingProjector(): + def f(x): + return x + + mra = vp.MultiResolutionAnalysis(box=[0, 1], order=7) + P_scaling = vp.ScalingProjector(mra, 2) + P_wavelet = vp.WaveletProjector(mra, 2) + + with pytest.raises(Exception): + P_scaling(f) + + with pytest.raises(Exception): + P_wavelet(f) \ No newline at end of file diff --git a/src/vampyr/treebuilders/project.h b/src/vampyr/treebuilders/project.h index 9fe25848..145f75b4 100644 --- a/src/vampyr/treebuilders/project.h +++ b/src/vampyr/treebuilders/project.h @@ -1,7 +1,8 @@ #pragma once -#include +#include +#include #include "PyProjectors.h" namespace vampyr { @@ -18,6 +19,15 @@ template void project(pybind11::module &m) { .def( "__call__", [](PyScalingProjector &P, std::function &r)> func) { + + try { + auto arr = std::array(); + arr.fill(111111.111); // A number which hopefully does not divide by zero + func(arr); + } catch (py::cast_error &e) { + py::print("Error: Invalid definition of analytic function"); + throw; + } auto old_threads = mrcpp_get_num_threads(); set_max_threads(1); auto out = P(func); @@ -33,6 +43,16 @@ template void project(pybind11::module &m) { .def( "__call__", [](PyWaveletProjector &P, std::function &r)> func) { + + try { + auto arr = std::array(); + arr.fill(111111.111); // A number which hopefully does not divide by zero + func(arr); + } catch (py::cast_error &e) { + py::print("Error: Invalid definition of analytic function"); + throw; + } + auto old_threads = mrcpp_get_num_threads(); set_max_threads(1); auto out = P(func);