Skip to content

Commit

Permalink
add test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
bhazelton committed Oct 16, 2023
1 parent 688fb28 commit 5bf939a
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/pyradiosky/skymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,10 +1144,11 @@ def add_extra_columns(self, *, names, values, dtype=None):
values = [values]
if len(names) != len(values):
raise ValueError("Must provide the same number of names and values.")
if dtype is not None and not isinstance(dtype, (list, tuple, np.ndarray)):
if len(names) == 1:
dtype = [dtype]
if len(names) != len(values):
if dtype is not None:
if not isinstance(dtype, (list, tuple, np.ndarray)):
if len(names) == 1:
dtype = [dtype]
elif len(dtype) != len(names):
raise ValueError(
"If dtype is set, it must be the same length as `name`."
)
Expand Down Expand Up @@ -2736,15 +2737,15 @@ def concat(
"combine objects."
)
if set(this.extra_columns.dtype.names) != set(
this.extra_columns.dtype.names
other.extra_columns.dtype.names
):
raise ValueError(
"Both objects have extra_columns but the column names do not "
"match. Cannot combine objects."
)
for name in this.extra_columns.dtype.names:
this_dtype = this.extra_columns.dtype[name].type
other_dtype = this.extra_columns.dtype[name].type
other_dtype = other.extra_columns.dtype[name].type
if this_dtype != other_dtype:
raise ValueError(
"Both objects have extra_columns but the dtypes for column "
Expand Down
90 changes: 90 additions & 0 deletions tests/test_skymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,9 @@ def test_healpix_to_point_loop(
skyobj.add_extra_columns(
names="foo", values=np.arange(skyobj.Ncomponents, dtype=float)
)
skyobj.add_extra_columns(
names="foo2", values=np.arange(skyobj.Ncomponents, dtype=float), dtype=float
)
skyobj.add_extra_columns(
names="bar", values=np.arange(skyobj.Ncomponents, dtype=int)
)
Expand Down Expand Up @@ -922,6 +925,34 @@ def test_healpix_to_point_loop(
sky.check()


def test_extra_columns_errors():
skyobj = SkyModel.from_file(GLEAM_vot, with_error=True)

with pytest.raises(
ValueError, match="Must provide the same number of names and values."
):
skyobj.add_extra_columns(
names=["foo", "bar"], values=np.arange(skyobj.Ncomponents, dtype=float)
)

with pytest.raises(
ValueError, match="If dtype is set, it must be the same length as `name`."
):
skyobj.add_extra_columns(
names="foo",
values=np.arange(skyobj.Ncomponents, dtype=float),
dtype=[float, int],
)

with pytest.raises(
ValueError,
match=re.escape("value array(s) must be 1D, Ncomponents length array(s)"),
):
skyobj.add_extra_columns(
names="foo", values=np.arange(skyobj.Ncomponents - 1, dtype=float)
)


def test_healpix_to_point_loop_ordering(healpix_disk_new):
skyobj = healpix_disk_new

Expand Down Expand Up @@ -1869,6 +1900,65 @@ def test_concat_compatibility_errors(healpix_disk_new, time_location):
skyobj1.history = skyobj_hpx_disk.history
assert skyobj1 == skyobj_hpx_disk

skyobj1 = skyobj_gleam_subband.select(
component_inds=np.arange(skyobj_gleam_subband.Ncomponents // 2), inplace=False
)
skyobj2 = skyobj_gleam_subband.select(
component_inds=np.arange(
skyobj_gleam_subband.Ncomponents // 2, skyobj_gleam_subband.Ncomponents
),
inplace=False,
)
skyobj1.add_extra_columns(
names=["foo", "bar", "gah"],
values=[
np.arange(skyobj1.Ncomponents, dtype=float),
np.arange(skyobj1.Ncomponents, dtype=int),
np.array(["gah " + str(ind) for ind in range(skyobj1.Ncomponents)]),
],
)
with pytest.raises(
ValueError,
match="One object has extra_columns and the other does not. Cannot combine "
"objects.",
):
skyobj1.concat(skyobj2)

skyobj2.add_extra_columns(
names=["blech", "bar", "gah"],
values=[
np.arange(skyobj2.Ncomponents, dtype=float),
np.arange(skyobj2.Ncomponents, dtype=int),
np.array(["gah " + str(ind) for ind in range(skyobj2.Ncomponents)]),
],
)
with pytest.raises(
ValueError,
match="Both objects have extra_columns but the column names do not match. "
"Cannot combine objects.",
):
skyobj1.concat(skyobj2)
skyobj2 = skyobj_gleam_subband.select(
component_inds=np.arange(
skyobj_gleam_subband.Ncomponents // 2, skyobj_gleam_subband.Ncomponents
),
inplace=False,
)
skyobj2.add_extra_columns(
names=["foo", "bar", "gah"],
values=[
np.arange(skyobj2.Ncomponents, dtype=int),
np.arange(skyobj2.Ncomponents, dtype=int),
np.array(["gah " + str(ind) for ind in range(skyobj2.Ncomponents)]),
],
)
with pytest.raises(
ValueError,
match="Both objects have extra_columns but the dtypes for column "
"foo do not match. Cannot combine objects.",
):
skyobj1.concat(skyobj2)


def test_healpix_import_err(zenith_skymodel):
try:
Expand Down

0 comments on commit 5bf939a

Please sign in to comment.