Skip to content

Commit

Permalink
Merge pull request #15 from arjunsavel/diff_fit_axis
Browse files Browse the repository at this point in the history
test linalg error for bad cube
  • Loading branch information
arjunsavel committed Feb 9, 2024
2 parents d41c6a9 + f91a886 commit bd44a2b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/cortecs/fit/fit_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def do_pca(cube, nc=3):
try:
xMat, s, vh, u = do_svd(standardized_cube, nc, nx)

except np.linalg.LinAlgError:
except np.linalg.LinAlgError as e:
print("SVD did not converge.")
return
raise e

return xMat, standardized_cube, s, vh, u

Expand Down
47 changes: 47 additions & 0 deletions src/cortecs/tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from cortecs.fit.fit_pca import *
import os
import sys
import numpy as np

sys.path.insert(0, os.path.abspath("."))

Expand Down Expand Up @@ -56,3 +57,49 @@ def test_available_methods_assigned_func(self):

fitter = Fitter(self.opac, method="pca")
self.assertEqual(fitter.fit_func, fit_pca)


class TestFitUtils(unittest.TestCase):
"""
Test the fitter object itself
"""

T_filename = os.path.abspath(".") + "/src/cortecs/tests/temperatures.npy"
P_filename = os.path.abspath(".") + "/src/cortecs/tests/pressures.npy"
wl_filename = os.path.abspath(".") + "/src/cortecs/tests/wavelengths.npy"
cross_sec_filename = (
os.path.abspath(".") + "/src/cortecs/tests/absorb_coeffs_C2H4.npy"
)

load_kwargs = {
"T_filename": T_filename,
"P_filename": P_filename,
"wl_filename": wl_filename,
}
opac = Opac(cross_sec_filename, loader="platon", load_kwargs=load_kwargs)

def test_nan_pca_cube_errors(self):
"""
if i pass nans, should fail.
:return:
"""
bad_cube = np.zeros((3, 3)) * np.nan

self.assertRaises(
ValueError,
do_pca,
bad_cube,
)

def test_nan_pca_cube_errors(self):
"""
i want to make a linalg errror!
:return:
"""
bad_cube = np.zeros((3, 3)) * np.inf

self.assertRaises(
ValueError,
do_pca,
bad_cube,
)

0 comments on commit bd44a2b

Please sign in to comment.