Skip to content

Commit

Permalink
Big wavefunction 2
Browse files Browse the repository at this point in the history
  • Loading branch information
fobos123deimos committed Jul 25, 2024
1 parent c1354fa commit 230aa5a
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 31 deletions.
32 changes: 20 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,30 +48,38 @@ array([[ 0., 0., 0., 1.],
[ 0., 0., 2., 0.],
[ 0., 4., 0., -2.],
[ 8., 0., -12., 0.]])
>>> wavefunction_smod(0, 1.0)
>>> wave_smod = wavefunction(s_mode = True, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128)
>>> wave_smmd = wavefunction(s_mode = True, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128)
>>> wave_mmod = wavefunction(s_mode = False, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128)
>>> wave_mmmd = wavefunction(s_mode = False, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128)
>>> c_wave_smod = wavefunction(s_mode = True, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128)
>>> c_wave_smmd = wavefunction(s_mode = True, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128)
>>> c_wave_mmod = wavefunction(s_mode = False, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128)
>>> c_wave_mmmd = wavefunction(s_mode = False, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128)
>>> wave_smod(0, 1.0)
0.45558067201133257
>>> wavefunction_smod(61, 1.0)
>>> wave_smod(61, 1.0)
-0.2393049199171131
>>> c_wavefunction_smod(0,1.0+2.0j)
>>> c_wave_smod(0,1.0+2.0j)
(-1.4008797330262455-3.0609780602975003j)
>>> c_wavefunction_smod(61,1.0+2.0j)
>>> c_wave_smod(61,1.0+2.0j)
(-511062135.47555304+131445997.75753704j)
>>> wavefunction_smmd(0,(1.0,2.0))
>>> wave_smmd(0,np.array([1.0,2.0]))
array([0.45558067, 0.10165379])
>>> wavefunction_smmd(61,(1.0,2.0))
>>> wave_smmd(61,np.array([1.0,2.0]))
array([-0.23930492, -0.01677378])
>>> c_wavefunction_smmd(0,(1.0 + 1.0j, 2.0 + 2.0j))
>>> c_wave_smmd(0,np.array([1.0 + 1.0j, 2.0 + 2.0j]))
array([ 0.40583486-0.63205035j, -0.49096842+0.56845369j])
>>> c_wavefunction_smmd(61,(1.0 + 1.0j, 2.0 + 2.0j))
>>> c_wave_smmd(61,np.array([1.0 + 1.0j, 2.0 + 2.0j]))
array([-7.56548941e+03+9.21498621e+02j, -1.64189542e+08-3.70892077e+08j])
>>> wavefunction_mmod(1,1.0)
>>> wave_mmod(1,1.0)
array([0.45558067, 0.64428837])
>>> c_wavefunction_mmod(1,1.0 +2.0j)
>>> c_wave_mmod(1,1.0 +2.0j)
array([-1.40087973-3.06097806j, 6.67661026-8.29116292j])
>>> wavefunction_mmmd(1,(1.0 ,2.0))
>>> wave_mmmd(1,np.array([1.0 ,2.0]))
array([[0.45558067, 0.10165379],
[0.64428837, 0.28752033]])
>>> c_wavefunction_mmmd(1,(1.0 + 1.0j,2.0 + 2.0j))
>>> c_wave_mmmd(1,np.array([1.0 + 1.0j,2.0 + 2.0j]))
array([[ 0.40583486-0.63205035j, -0.49096842+0.56845369j],
[ 1.46779135-0.31991701j, -2.99649822+0.21916143j]])
```
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = fast_wave
version = 1.0.4
version = 1.1.0
description = Package for the calculation of the time-independent wavefunction.
author = Matheus Gomes Cordeiro
author_email = matheusgomescord@gmail.com
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
long_description = fh.read()

name = "fast_wave"
version = "1.0.4"
version = "1.1.0"
description = "Package for the calculation of the time-independent wavefunction."
author_email = "matheusgomescord@gmail.com"
url = "https://github.com/pikachu123deimos/fast-wave"
Expand Down
10 changes: 5 additions & 5 deletions src/fast_wave/wavefunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def c_wavefunction_mmmd(n: np.uint64, x: np.ndarray[np.complex128]) -> np.ndarra

return result

def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_boll: bool = False, cache: bool = False, cache_size: np.uint64 = 128) -> nb.core.registry.CPUDispatcher:
def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_bool: bool = False, cache: bool = False, cache_size: np.uint64 = 128) -> nb.core.registry.CPUDispatcher:

"""
Computes the wavefunction of a quantum harmonic oscillator .This function dispatches to different implementations of the wavefunction depending on the specified parameters.
Expand All @@ -548,7 +548,7 @@ def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_boll:

if(s_mode):
if(o_dimensional):
if(not(complex_boll)):
if(not(complex_bool)):
if(not(cache)):
return wavefunction_smod
else:
Expand All @@ -559,7 +559,7 @@ def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_boll:
else:
return lru_cache(maxsize=cache_size)(c_wavefunction_smod)
else:
if(not(complex_boll)):
if(not(complex_bool)):
if(not(cache)):
return wavefunction_smmd
else:
Expand All @@ -571,7 +571,7 @@ def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_boll:
return lru_cache(maxsize=cache_size)(c_wavefunction_smmd)
else:
if(o_dimensional):
if(not(complex_boll)):
if(not(complex_bool)):
if(not(cache)):
return wavefunction_mmod
else:
Expand All @@ -582,7 +582,7 @@ def wavefunction(s_mode: bool = True, o_dimensional: bool = True, complex_boll:
else:
return lru_cache(maxsize=cache_size)(c_wavefunction_mmod)
else:
if(not(complex_boll)):
if(not(complex_bool)):
if(not(cache)):
return wavefunction_mmmd
else:
Expand Down
34 changes: 22 additions & 12 deletions tests/test_wavefunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,39 @@ def test_wavefunction_computation():
"""
Tests the basic functionality of all wavefunction functions.
"""

wave_smod = wavefunction(s_mode = True, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128)
wave_smmd = wavefunction(s_mode = True, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128)
wave_mmod = wavefunction(s_mode = False, o_dimensional = True, complex_bool = False, cache = False, cache_size = 128)
wave_mmmd = wavefunction(s_mode = False, o_dimensional = False, complex_bool = False, cache = False, cache_size = 128)
c_wave_smod = wavefunction(s_mode = True, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128)
c_wave_smmd = wavefunction(s_mode = True, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128)
c_wave_mmod = wavefunction(s_mode = False, o_dimensional = True, complex_bool = True, cache = False, cache_size = 128)
c_wave_mmmd = wavefunction(s_mode = False, o_dimensional = False, complex_bool = True, cache = False, cache_size = 128)

# Testing basic functionality
test_output_udsm = wavefunction_smod(2, 10.0)
assert isinstance(test_output_udsm, float)
test_output_odsm = wave_smod(2, 10.0)
assert isinstance(test_output_odsm, float)

test_output_udmm = wavefunction_mmod(2, 10.0)
assert isinstance(test_output_udmm, np.ndarray)
test_output_odmm = wave_mmod(2, 10.0)
assert isinstance(test_output_odmm, np.ndarray)

test_output_mdsm = wavefunction_smmd(2, (10.0, 4.5))
test_output_mdsm = wave_smmd(2, np.array([10.0, 4.5]))
assert isinstance(test_output_mdsm, np.ndarray)

test_output_mdmm = wavefunction_mmmd(2, (10.0, 4.5))
test_output_mdmm = wave_mmmd(2, np.array([10.0, 4.5]))
assert isinstance(test_output_mdmm, np.ndarray)

test_output_c_udsm = c_wavefunction_smod(2, 10.0 + 0.0j)
assert isinstance(test_output_c_udsm, complex)
test_output_c_odsm = c_wave_smod(2, 10.0 + 0.0j)
assert isinstance(test_output_c_odsm, complex)

test_output_c_udmm = c_wavefunction_mmod(2, 10.0 + 0.0j)
assert isinstance(test_output_c_udmm, np.ndarray)
test_output_c_odmm = c_wave_mmod(2, 10.0 + 0.0j)
assert isinstance(test_output_c_odmm, np.ndarray)

test_output_c_mdsm = c_wavefunction_smmd(2, (10.0 + 0.0j, 4.5 + 0.0j))
test_output_c_mdsm = c_wave_smmd(2, np.array([10.0 + 0.0j, 4.5 + 0.0j]))
assert isinstance(test_output_c_mdsm, np.ndarray)

test_output_c_mdmm = c_wavefunction_mmmd(2, (10.0 + 0.0j, 4.5 + 0.0j))
test_output_c_mdmm = c_wave_mmmd(2, np.array([10.0 + 0.0j, 4.5 + 0.0j]))
assert isinstance(test_output_c_mdmm, np.ndarray)

print("All functionality tests passed.")
Expand Down

0 comments on commit 230aa5a

Please sign in to comment.