Skip to content

Commit

Permalink
Merge pull request #63 from B612-Asteroid-Institute/ak/validation-extras
Browse files Browse the repository at this point in the history
Allows concatenation to ignore validation and adds method on table to…
  • Loading branch information
akoumjian authored Sep 17, 2024
2 parents f7c92ae + b003dfb commit 8aafd94
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 561 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@

This file documents notable changes between versions of quivr.

## [0.7.4] - 2024-09-17

### Removed

- `experimental.shmem` has been removed. Users are encouraged to use tools
such as `ray` for parallel processing and shared memory of pyarrow types.

### Added

- `Table.invalid_mask` and `Table.separate_invalid` have been added to allow users to select rows
that fail validation checks.
- `concatenate` now supports passing the `validate` argument, if you want to postpone automatic
validation of a table until after concatenation.


## [0.7.3] - 2024-05-20

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ python = ["3.10", "3.11"]

[tool.hatch.envs.test.scripts]
all = [
"ruff ./quivr ./test",
"ruff check ./quivr ./test",
"black --check ./quivr ./test",
"isort --check-only ./quivr ./test",
"mypy --strict ./quivr ./examples ./test/typing_tests",
Expand Down
6 changes: 4 additions & 2 deletions quivr/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from . import defragment, errors, tables


def concatenate(values: Iterable[tables.AnyTable], defrag: bool = True) -> tables.AnyTable:
def concatenate(
values: Iterable[tables.AnyTable], defrag: bool = True, validate: bool = True
) -> tables.AnyTable:
"""Concatenate a collection of Tables into a single Table.
All input Tables must be of the same class, and have the same attribute
Expand Down Expand Up @@ -59,7 +61,7 @@ def concatenate(values: Iterable[tables.AnyTable], defrag: bool = True) -> table
return first_cls.empty()

table = pa.Table.from_batches(batches)
result = first_cls.from_pyarrow(table=table)
result = first_cls.from_pyarrow(table=table, validate=validate)
if defrag:
result = defragment.defragment(result)
return result
Empty file removed quivr/experimental/__init__.py
Empty file.
256 changes: 0 additions & 256 deletions quivr/experimental/shmem.py

This file was deleted.

26 changes: 25 additions & 1 deletion quivr/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Literal,
Optional,
Protocol,
Tuple,
Type,
TypeAlias,
TypeVar,
Expand Down Expand Up @@ -960,6 +961,29 @@ def validate(self) -> None:
except errors.ValidationError as e:
raise errors.ValidationError(f"Column {name} failed validation: {str(e)}", e.failures) from e

def invalid_mask(self) -> pa.Array:
    """Build a boolean mask flagging the rows that fail validation.

    Every column validator registered on the table is asked for its
    failures on the matching column; the positions it reports are set to
    True in the mask. A row is therefore flagged as soon as any one of
    its validated columns reports it.

    Returns:
        pa.Array: A boolean array with one entry per row of the table,
        True where the row was reported as a failure.
    """
    # Start with every row considered valid, then switch on the failures.
    flagged = np.zeros(self.table.num_rows, dtype=bool)
    for column_name, column_validator in self._column_validators.items():
        # Only the first element of the returned pair (the failing
        # positions) is needed here.
        failed_positions, _ = column_validator.failures(self.table.column(column_name))
        flagged[failed_positions.to_numpy()] = True
    return pa.array(flagged, type=pa.bool_())

def separate_invalid(self) -> Tuple[Self, Self]:
    """
    Split the table into the rows that pass validation and those that do not.

    Returns:
        Tuple[Self, Self]: Two Tables. The first holds the rows that were
        not flagged by ``invalid_mask``; the second holds the rows that
        were flagged.
    """
    # The mask is True for failing rows; its inverse selects the passing ones.
    failing_rows = self.invalid_mask()
    passing_rows = pyarrow.compute.invert(failing_rows)
    passed = self.apply_mask(passing_rows)
    failed = self.apply_mask(failing_rows)
    return passed, failed

@classmethod
def empty(cls, **kwargs: AttributeValueType) -> Self:
"""Create an empty instance of the table.
Expand Down Expand Up @@ -1116,7 +1140,7 @@ def _encode_attr_dict(cls, attrs: dict[str, Any]) -> dict[bytes, bytes]:
result[k.encode("utf8")] = descriptor.to_bytes(pytyped)
return result

def apply_mask(self, mask: pa.BooleanArray | np.ndarray[bool, Any] | list[bool]) -> Self:
def apply_mask(self, mask: pa.BooleanArray | npt.NDArray[np.bool_] | list[bool]) -> Self:
"""
Return a new table with rows filtered to match a boolean mask.
Expand Down
14 changes: 14 additions & 0 deletions test/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ def test_concatenate_empty_tables():
assert len(have) == 0


def test_concatenate_no_validate():
    """Concatenation validates by default and skips it with validate=False."""

    class ValidationTable(qv.Table):
        x = qv.Int64Column(validator=qv.ge(0))

    # Build both inputs without validation so the negative value is accepted.
    negative = ValidationTable.from_kwargs(x=[-1], validate=False)
    positive = ValidationTable.from_kwargs(x=[1], validate=False)
    inputs = [negative, positive]

    # Default behaviour: the combined table is validated and rejected.
    with pytest.raises(qv.ValidationError, match="Column x failed validation"):
        qv.concatenate(inputs)

    # Opting out of validation keeps both rows.
    merged = qv.concatenate(inputs, validate=False)
    assert len(merged) == 2


@pytest.mark.benchmark(group="ops")
def test_benchmark_concatenate_100(benchmark):
xs1 = pa.array([1, 2, 3], pa.int64())
Expand Down
Loading

0 comments on commit 8aafd94

Please sign in to comment.