diff --git a/README.md b/README.md index a9824db..403acef 100644 --- a/README.md +++ b/README.md @@ -48,30 +48,38 @@ array([[ 0., 0., 0., 1.], [ 0., 0., 2., 0.], [ 0., 4., 0., -2.], [ 8., 0., -12., 0.]]) - >>> wavefunction_smod(0, 1.0) +>>> wave_smod = wavefunction(s_mode = True, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128) +>>> wave_smmd = wavefunction(s_mode = True, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128) +>>> wave_mmod = wavefunction(s_mode = False, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128) +>>> wave_mmmd = wavefunction(s_mode = False, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128) +>>> c_wave_smod = wavefunction(s_mode = True, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128) +>>> c_wave_smmd = wavefunction(s_mode = True, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128) +>>> c_wave_mmod = wavefunction(s_mode = False, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128) +>>> c_wave_mmmd = wavefunction(s_mode = False, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128) +>>> wave_smod(0, 1.0) 0.45558067201133257 ->>> wavefunction_smod(61, 1.0) +>>> wave_smod(61, 1.0) -0.2393049199171131 ->>> c_wavefunction_smod(0,1.0+2.0j) +>>> c_wave_smod(0,1.0+2.0j) (-1.4008797330262455-3.0609780602975003j) ->>> c_wavefunction_smod(61,1.0+2.0j) +>>> c_wave_smod(61,1.0+2.0j) (-511062135.47555304+131445997.75753704j) ->>> wavefunction_smmd(0,(1.0,2.0)) +>>> wave_smmd(0,np.array([1.0,2.0])) array([0.45558067, 0.10165379]) ->>> wavefunction_smmd(61,(1.0,2.0)) +>>> wave_smmd(61,np.array([1.0,2.0])) array([-0.23930492, -0.01677378]) ->>> c_wavefunction_smmd(0,(1.0 + 1.0j, 2.0 + 2.0j)) +>>> c_wave_smmd(0,np.array([1.0 + 1.0j, 2.0 + 2.0j])) array([ 0.40583486-0.63205035j, -0.49096842+0.56845369j]) ->>> c_wavefunction_smmd(61,(1.0 + 1.0j, 2.0 + 2.0j)) +>>> c_wave_smmd(61,np.array([1.0 + 1.0j, 2.0 + 2.0j])) array([-7.56548941e+03+9.21498621e+02j, -1.64189542e+08-3.70892077e+08j]) ->>> wavefunction_mmod(1,1.0) +>>> wave_mmod(1,1.0) array([0.45558067, 0.64428837]) ->>> c_wavefunction_mmod(1,1.0 +2.0j) +>>> c_wave_mmod(1,1.0 +2.0j) array([-1.40087973-3.06097806j, 6.67661026-8.29116292j]) ->>> wavefunction_mmmd(1,(1.0 ,2.0)) +>>> wave_mmmd(1,np.array([1.0 ,2.0])) array([[0.45558067, 0.10165379], [0.64428837, 0.28752033]]) ->>> c_wavefunction_mmmd(1,(1.0 + 1.0j,2.0 + 2.0j)) +>>> c_wave_mmmd(1,np.array([1.0 + 1.0j,2.0 + 2.0j])) array([[ 0.40583486-0.63205035j, -0.49096842+0.56845369j], [ 1.46779135-0.31991701j, -2.99649822+0.21916143j]]) ``` diff --git a/setup.cfg b/setup.cfg index 65afad3..cc76b24 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = fast_wave -version = 1.0.4 +version = 1.1.0 description = Package for the calculation of the time-independent wavefunction. author = Matheus Gomes Cordeiro author_email = matheusgomescord@gmail.com diff --git a/setup.py b/setup.py index e2eb8dd..db744f9 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ long_description = fh.read() name = "fast_wave" -version = "1.0.4" +version = "1.1.0" description = "Package for the calculation of the time-independent wavefunction." author_email = "matheusgomescord@gmail.com" url = "https://github.com/pikachu123deimos/fast-wave" diff --git a/src/fast_wave/wavefunction.py b/src/fast_wave/wavefunction.py index c3a1fd7..2a633f6 100644 --- a/src/fast_wave/wavefunction.py +++ b/src/fast_wave/wavefunction.py @@ -522,7 +522,7 @@ def c_wavefunction_mmmd(n: np.uint64, x: np.ndarray[np.complex128]) -> np.ndarra return result -def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_boll: bool = False, cache: bool = False, cache_size: np.uint64 = 128) -> nb.core.registry.CPUDispatcher: +def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_bool: bool = False, cache: bool = False, cache_size: np.uint64 = 128) -> nb.core.registry.CPUDispatcher: """ Computes the wavefunction of a quantum harmonic oscillator .This function dispatches to different implementations of the wavefunction depending on the specified parameters. @@ -548,7 +548,7 @@ def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_boll: if(s_mode): if(o_dimensional): - if(not(complex_boll)): + if(not(complex_bool)): if(not(cache)): return wavefunction_smod else: @@ -559,7 +559,7 @@ def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_boll: else: return lru_cache(maxsize=cache_size)(c_wavefunction_smod) else: - if(not(complex_boll)): + if(not(complex_bool)): if(not(cache)): return wavefunction_smmd else: @@ -571,7 +571,7 @@ def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_boll: return lru_cache(maxsize=cache_size)(c_wavefunction_smmd) else: if(o_dimensional): - if(not(complex_boll)): + if(not(complex_bool)): if(not(cache)): return wavefunction_mmod else: @@ -582,7 +582,7 @@ def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_boll: else: return lru_cache(maxsize=cache_size)(c_wavefunction_mmod) else: - if(not(complex_boll)): + if(not(complex_bool)): if(not(cache)): return wavefunction_mmmd else: diff --git a/tests/test_wavefunction.py b/tests/test_wavefunction.py index 6cc635d..b69d565 100644 --- a/tests/test_wavefunction.py +++ b/tests/test_wavefunction.py @@ -43,29 +43,39 @@ def test_wavefunction_computation(): """ Tests the basic functionality of all wavefunction functions. """ + + wave_smod = wavefunction(s_mode = True, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128) + wave_smmd = wavefunction(s_mode = True, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128) + wave_mmod = wavefunction(s_mode = False, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128) + wave_mmmd = wavefunction(s_mode = False, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128) + c_wave_smod = wavefunction(s_mode = True, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128) + c_wave_smmd = wavefunction(s_mode = True, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128) + c_wave_mmod = wavefunction(s_mode = False, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128) + c_wave_mmmd = wavefunction(s_mode = False, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128) + # Testing basic functionality - test_output_udsm = wavefunction_smod(2, 10.0) - assert isinstance(test_output_udsm, float) + test_output_odsm = wave_smod(2, 10.0) + assert isinstance(test_output_odsm, float) - test_output_udmm = wavefunction_mmod(2, 10.0) - assert isinstance(test_output_udmm, np.ndarray) + test_output_odmm = wave_mmod(2, 10.0) + assert isinstance(test_output_odmm, np.ndarray) - test_output_mdsm = wavefunction_smmd(2, (10.0, 4.5)) + test_output_mdsm = wave_smmd(2, np.array([10.0, 4.5])) assert isinstance(test_output_mdsm, np.ndarray) - test_output_mdmm = wavefunction_mmmd(2, (10.0, 4.5)) + test_output_mdmm = wave_mmmd(2, np.array([10.0, 4.5])) assert isinstance(test_output_mdmm, np.ndarray) - test_output_c_udsm = c_wavefunction_smod(2, 10.0 + 0.0j) - assert isinstance(test_output_c_udsm, complex) + test_output_c_odsm = c_wave_smod(2, 10.0 + 0.0j) + assert isinstance(test_output_c_odsm, complex) - test_output_c_udmm = c_wavefunction_mmod(2, 10.0 + 0.0j) - assert isinstance(test_output_c_udmm, np.ndarray) + test_output_c_odmm = c_wave_mmod(2, 10.0 + 0.0j) + assert isinstance(test_output_c_odmm, np.ndarray) - test_output_c_mdsm = c_wavefunction_smmd(2, (10.0 + 0.0j, 4.5 + 0.0j)) + test_output_c_mdsm = c_wave_smmd(2, np.array([10.0 + 0.0j, 4.5 + 0.0j])) assert isinstance(test_output_c_mdsm, np.ndarray) - test_output_c_mdmm = c_wavefunction_mmmd(2, (10.0 + 0.0j, 4.5 + 0.0j)) + test_output_c_mdmm = c_wave_mmmd(2, np.array([10.0 + 0.0j, 4.5 + 0.0j])) assert isinstance(test_output_c_mdmm, np.ndarray) print("All functionality tests passed.")