Skip to content

Commit

Permalink
new normalized_hermite_coefficients_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
fobos123deimos committed Sep 14, 2024
1 parent 04067eb commit 4d063fb
Showing 1 changed file with 40 additions and 9 deletions.
49 changes: 40 additions & 9 deletions src/fast_wave/wavefunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import pickle
import math
from functools import lru_cache
from sympy import symbols, diff, exp, Poly

import numpy as np
import numba as nb
Expand All @@ -43,6 +44,34 @@
c_s_matrix = None
compilation_test = None

def hermite_sympy(n: np.uint64) -> Poly:
"""
Compute the nth Hermite polynomial using symbolic differentiation.
Parameters
----------
n : np.uint64
Order of the Hermite polynomial.
Returns
-------
Poly
The nth Hermite polynomial as a sympy expression.
Examples
--------
```
>>> hermite_sympy(2)
4*x**2 - 2
```
References
----------
- Wikipedia contributors. (2021). Hermite polynomials. In Wikipedia, The Free Encyclopedia. Retrieved from https://en.wikipedia.org/wiki/Hermite_polynomials
"""
x = symbols("x")
return 1 if n == 0 else ((-1) ** n) * exp(x ** 2) * diff(exp(-x ** 2), x, n)

@nb.jit(nopython=True, looplift=True, nogil=True, boundscheck=False, cache=True)
def create_normalized_hermite_coefficients_matrix(n_max: np.uint64) -> np.ndarray:
"""
Expand Down Expand Up @@ -74,17 +103,19 @@ def create_normalized_hermite_coefficients_matrix(n_max: np.uint64) -> np.ndarra
- NIST Digital Library of Mathematical Functions. https://dlmf.nist.gov/, Release 1.0.28 of 2020-09-15.
- Sympy Documentation: https://docs.sympy.org/latest/modules/polys/index.html
"""
C_s = np.zeros((n_max + 1, n_max + 1), dtype=np.float64)
x = symbols("x")
C = np.zeros((n_max + 1, n_max + 1), dtype=np.float64)
C[0, n_max] = 1

for i in range(n_max+1):
for j in range(n_max+1):
if((j>=(n_max-i)) and ((n_max-i+j)%2 == 0)):
C_s[i,j] = ( ((-1)**((j-n_max+i)/2)) * (2**(n_max-j-(i*0.5))) ) / ( math.gamma(((j-n_max+i)/2) + 1) * math.gamma(n_max-j + 1) )
else:
C_s[i,j] = 0.0
C_s[i] *= math.gamma(i+1)**0.5
for n in range(1, n_max + 1):
c = Poly(hermite_sympy(n), x).all_coeffs()
for index in range(n, -1, -1):
C[n, (n_max + 1) - index - 1] = float(c[n - index])

for i in range(n_max + 1):
C[i] /= (np.pi**0.50 * (2**i) * math.gamma(i+1))**0.5

return C_s/(np.pi**0.25)
return C



Expand Down

0 comments on commit 4d063fb

Please sign in to comment.