Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bhazelton authored and mkolopanis committed Dec 7, 2023
1 parent 35e5b6d commit 64f8b4b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
24 changes: 15 additions & 9 deletions src/pyradiosky/skymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,15 +1152,14 @@ def add_extra_columns(self, *, names, values, dtype=None):
)
else:
dtype = [dtype]
for val in values:
for index, val in enumerate(values):
if val.shape != (self.Ncomponents,):
raise ValueError(
"value array(s) must be 1D, Ncomponents length array(s)"
"value array(s) must be 1D, Ncomponents length array(s). The value "
f"array in index {index} is not the right shape."
)
if dtype is None:
dtype = []
for val in values:
dtype.append(val.dtype)
dtype = [val.dtype for val in values]
dtype_obj = np.dtype(list(zip(names, dtype)))
new_recarray = np.rec.fromarrays(values, dtype=dtype_obj)
if self.extra_columns is None:
Expand All @@ -1170,9 +1169,10 @@ def add_extra_columns(self, *, names, values, dtype=None):
(self.extra_columns, new_recarray), asrecarray=True, flatten=True
)
self.extra_columns = combined_recarray
expected_dtype = []
for name in self.extra_columns.dtype.names:
expected_dtype.append(self.extra_columns.dtype[name].type)
expected_dtype = [
self.extra_columns.dtype[name].type
for name in self.extra_columns.dtype.names
]

self._extra_columns.expected_type = expected_dtype

Expand Down Expand Up @@ -2739,9 +2739,15 @@ def concat(
if set(this.extra_columns.dtype.names) != set(
other.extra_columns.dtype.names
):
set_diff = set(this.extra_columns.dtype.names) - set(
other.extra_columns.dtype.names
)
raise ValueError(
"Both objects have extra_columns but the column names do not "
"match. Cannot combine objects."
"match. Cannot combine objects. Left object columns are: "
f"{this.extra_columns.dtype.names}. Right object columns are: "
f"{other.extra_columns.dtype.names}. Unmatched columns are "
f"{set_diff}"
)
for name in this.extra_columns.dtype.names:
this_dtype = this.extra_columns.dtype[name].type
Expand Down
13 changes: 10 additions & 3 deletions tests/test_skymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,10 @@ def test_extra_columns_errors():

with pytest.raises(
ValueError,
match=re.escape("value array(s) must be 1D, Ncomponents length array(s)"),
match=re.escape(
"value array(s) must be 1D, Ncomponents length array(s). The "
"value array in index 0 is not the right shape."
),
):
skyobj.add_extra_columns(
names="foo", values=np.arange(skyobj.Ncomponents - 1, dtype=float)
Expand Down Expand Up @@ -1945,8 +1948,12 @@ def test_concat_compatibility_errors(healpix_disk_new, time_location):
)
with pytest.raises(
ValueError,
match="Both objects have extra_columns but the column names do not match. "
"Cannot combine objects.",
match=re.escape(
"Both objects have extra_columns but the column names do not match. Cannot "
"combine objects. Left object columns are: ('foo', 'bar', 'gah'). Right "
"object columns are: ('blech', 'bar', 'gah'). Unmatched columns are "
"{'foo'}"
),
):
skyobj1.concat(skyobj2)
skyobj2 = skyobj_gleam_subband.select(
Expand Down

0 comments on commit 64f8b4b

Please sign in to comment.