Skip to content

Commit

Permalink
add test coverage, simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
bhazelton committed Dec 7, 2023
1 parent d90ba48 commit 54bb252
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/pyradiosky/skymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,13 +1144,14 @@ def add_extra_columns(self, *, names, values, dtype=None):
values = [values]
if len(names) != len(values):
raise ValueError("Must provide the same number of names and values.")
if dtype is not None and not isinstance(dtype, (list, tuple, np.ndarray)):
if len(names) == 1:
if dtype is not None:
if isinstance(dtype, (list, tuple, np.ndarray)):
if len(dtype) != len(names):
raise ValueError(
"If dtype is set, it must be the same length as `name`."
)
else:
dtype = [dtype]
if len(names) != len(values):
raise ValueError(
"If dtype is set, it must be the same length as `name`."
)
for val in values:
if val.shape != (self.Ncomponents,):
raise ValueError(
Expand Down Expand Up @@ -2736,15 +2737,15 @@ def concat(
"combine objects."
)
if set(this.extra_columns.dtype.names) != set(
this.extra_columns.dtype.names
other.extra_columns.dtype.names
):
raise ValueError(
"Both objects have extra_columns but the column names do not "
"match. Cannot combine objects."
)
for name in this.extra_columns.dtype.names:
this_dtype = this.extra_columns.dtype[name].type
other_dtype = this.extra_columns.dtype[name].type
other_dtype = other.extra_columns.dtype[name].type
if this_dtype != other_dtype:
raise ValueError(
"Both objects have extra_columns but the dtypes for column "
Expand Down
101 changes: 101 additions & 0 deletions tests/test_skymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,9 @@ def test_healpix_to_point_loop(
skyobj.add_extra_columns(
names="foo", values=np.arange(skyobj.Ncomponents, dtype=float)
)
skyobj.add_extra_columns(
names="foo2", values=np.arange(skyobj.Ncomponents, dtype=float), dtype=float
)
skyobj.add_extra_columns(
names="bar", values=np.arange(skyobj.Ncomponents, dtype=int)
)
Expand Down Expand Up @@ -922,6 +925,44 @@ def test_healpix_to_point_loop(
sky.check()


def test_extra_columns_errors():
skyobj = SkyModel.from_file(GLEAM_vot, with_error=True)

with pytest.raises(
ValueError, match="Must provide the same number of names and values."
):
skyobj.add_extra_columns(
names=["foo", "bar"], values=np.arange(skyobj.Ncomponents, dtype=float)
)

with pytest.raises(
ValueError, match="If dtype is set, it must be the same length as `name`."
):
skyobj.add_extra_columns(
names="foo",
values=np.arange(skyobj.Ncomponents, dtype=float),
dtype=[float, int],
)

with pytest.raises(
ValueError,
match=re.escape("value array(s) must be 1D, Ncomponents length array(s)"),
):
skyobj.add_extra_columns(
names="foo", values=np.arange(skyobj.Ncomponents - 1, dtype=float)
)

with pytest.raises(TypeError, match="data type 'foo' not understood"):
skyobj.add_extra_columns(
names=["foo", "bar"],
values=[
np.arange(skyobj.Ncomponents, dtype=float),
np.arange(skyobj.Ncomponents, dtype=int),
],
dtype="foo",
)


def test_healpix_to_point_loop_ordering(healpix_disk_new):
skyobj = healpix_disk_new

Expand Down Expand Up @@ -1556,6 +1597,7 @@ def test_concat(comp_type, spec_type, healpix_disk_new):
np.arange(skyobj_full.Ncomponents, dtype=int),
np.array(["gah " + str(ind) for ind in range(skyobj_full.Ncomponents)]),
],
dtype=[float, int, str],
)
skyobj1 = skyobj_full.select(
component_inds=np.arange(skyobj_full.Ncomponents // 2), inplace=False
Expand Down Expand Up @@ -1869,6 +1911,65 @@ def test_concat_compatibility_errors(healpix_disk_new, time_location):
skyobj1.history = skyobj_hpx_disk.history
assert skyobj1 == skyobj_hpx_disk

skyobj1 = skyobj_gleam_subband.select(
component_inds=np.arange(skyobj_gleam_subband.Ncomponents // 2), inplace=False
)
skyobj2 = skyobj_gleam_subband.select(
component_inds=np.arange(
skyobj_gleam_subband.Ncomponents // 2, skyobj_gleam_subband.Ncomponents
),
inplace=False,
)
skyobj1.add_extra_columns(
names=["foo", "bar", "gah"],
values=[
np.arange(skyobj1.Ncomponents, dtype=float),
np.arange(skyobj1.Ncomponents, dtype=int),
np.array(["gah " + str(ind) for ind in range(skyobj1.Ncomponents)]),
],
)
with pytest.raises(
ValueError,
match="One object has extra_columns and the other does not. Cannot combine "
"objects.",
):
skyobj1.concat(skyobj2)

skyobj2.add_extra_columns(
names=["blech", "bar", "gah"],
values=[
np.arange(skyobj2.Ncomponents, dtype=float),
np.arange(skyobj2.Ncomponents, dtype=int),
np.array(["gah " + str(ind) for ind in range(skyobj2.Ncomponents)]),
],
)
with pytest.raises(
ValueError,
match="Both objects have extra_columns but the column names do not match. "
"Cannot combine objects.",
):
skyobj1.concat(skyobj2)
skyobj2 = skyobj_gleam_subband.select(
component_inds=np.arange(
skyobj_gleam_subband.Ncomponents // 2, skyobj_gleam_subband.Ncomponents
),
inplace=False,
)
skyobj2.add_extra_columns(
names=["foo", "bar", "gah"],
values=[
np.arange(skyobj2.Ncomponents, dtype=int),
np.arange(skyobj2.Ncomponents, dtype=int),
np.array(["gah " + str(ind) for ind in range(skyobj2.Ncomponents)]),
],
)
with pytest.raises(
ValueError,
match="Both objects have extra_columns but the dtypes for column "
"foo do not match. Cannot combine objects.",
):
skyobj1.concat(skyobj2)


def test_healpix_import_err(zenith_skymodel):
try:
Expand Down

0 comments on commit 54bb252

Please sign in to comment.