Skip to content

Commit

Permalink
Add ability to create null tables with attributes and easily fetch ma…
Browse files Browse the repository at this point in the history
…sk for null rows in a table
  • Loading branch information
akoumjian committed Jan 10, 2025
1 parent 80cbc82 commit 33cfa19
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
42 changes: 40 additions & 2 deletions quivr/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,12 @@ def as_column(
return columns.SubTableColumn(cls, nullable=nullable, metadata=metadata)

@classmethod
def from_kwargs(cls, validate: bool = True, **kwargs: Union[DataSourceType, AttributeValueType]) -> Self:
def from_kwargs(
cls,
validate: bool = True,
permit_nulls: bool = False,
**kwargs: Union[DataSourceType, AttributeValueType],
) -> Self:
"""Create a Table instance from keyword arguments.
Each keyword argument corresponds to a column in the Table.
Expand Down Expand Up @@ -304,7 +309,12 @@ def from_kwargs(cls, validate: bool = True, **kwargs: Union[DataSourceType, Attr

pyarrow_table = cls._build_arrow_table(arrays, metadata)
attrib_kwargs = cls._attribute_kwargs_from_kwargs(kwargs)
return cls.from_pyarrow(table=pyarrow_table, validate=validate, permit_nulls=False, **attrib_kwargs)
return cls.from_pyarrow(
table=pyarrow_table,
validate=validate,
permit_nulls=permit_nulls,
**attrib_kwargs,
)

@classmethod
def _build_arrow_table(cls, arrays: List[pa.Array], metadata: dict[bytes, bytes]) -> pa.Table:
Expand Down Expand Up @@ -992,6 +1002,15 @@ def separate_invalid(self) -> Tuple[Self, Self]:
valid = self.apply_mask(pyarrow.compute.invert(failure_indices))
invalid = self.apply_mask(failure_indices)
return valid, invalid

def null_mask(self) -> pa.Array:
"""Return a boolean mask indicating which rows of the entire table are null."""
# Get the null mask for each column
flattened_table = self.flattened_table()
mask = pa.repeat(True, len(flattened_table ))
for name in flattened_table.column_names:
mask = pc.and_(mask, pc.is_null(flattened_table.column(name)))
return pa.array(mask, type=pa.bool_())

@classmethod
def empty(cls, **kwargs: AttributeValueType) -> Self:
Expand All @@ -1003,6 +1022,25 @@ def empty(cls, **kwargs: AttributeValueType) -> Self:
empty_table = pa.table(data, schema=cls.schema)
return cls.from_pyarrow(table=empty_table, validate=False, permit_nulls=False, **kwargs)


@classmethod
def nulls(cls, size: int, **kwargs: AttributeValueType) -> Self:
"""Create a table with nulls.
:param size: The number of rows to create.
:param \\**kwargs: Additional keyword arguments to set the Table's attributes.
Even for tables which do not permit nulls in their columns, it is possible
to create a table with nulls in the context of SubTableColumn. So for
both of these cases, we need a method to populate tables with null values
for all columns, while also setting the Table's attributes.
"""
null_array = pa.repeat(None, size)
data = [null_array for _ in range(len(cls.schema))] # type: ignore
null_table = pa.table(data, schema=cls.schema)
return cls.from_pyarrow(table=null_table, validate=False, permit_nulls=True, **kwargs)

def attributes(self) -> dict[str, Any]:
"""Return a dictionary of the table's attributes."""
return {name: getattr(self, name) for name in self._quivr_attributes}
Expand Down
23 changes: 23 additions & 0 deletions test/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,6 +1424,29 @@ class DoublyNested(qv.Table):
assert len(dn.column("inner.pair.y")) == 0


def test_nulls():
t = Pair.nulls(3)
assert t.x.equals(pa.array([None, None, None], pa.int64()))
assert t.y.equals(pa.array([None, None, None], pa.int64()))


def test_null_mask():
t = Pair.nulls(3)
assert t.null_mask().equals(pa.array([True, True, True], pa.bool_()))

t = Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])
assert t.null_mask().equals(pa.array([False, False, False], pa.bool_()))

t = Pair.from_kwargs(x=[1, None, 3], y=[4, 5, 6], permit_nulls=True)
assert t.null_mask().equals(pa.array([False, False, False], pa.bool_()))

t = Pair.from_kwargs(x=[1, None, 3], y=[None, 5, 6], permit_nulls=True)
assert t.null_mask().equals(pa.array([False, False, False], pa.bool_()))

t = Pair.from_kwargs(x=[1, None, 3], y=[4, None, 6], permit_nulls=True)
assert t.null_mask().equals(pa.array([False, True, False], pa.bool_()))


def test_column_invalid_name():
t = Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])
with pytest.raises(KeyError):
Expand Down

0 comments on commit 33cfa19

Please sign in to comment.