Skip to content

Commit

Permalink
Merge pull request #39 from p2p-ld/bugfix-union-dtypes
Browse files Browse the repository at this point in the history
Fix JSON Schema generation for union dtypes
  • Loading branch information
sneakers-the-rat authored Dec 14, 2024
2 parents dc547c4 + a7b8b5f commit d54698f
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 22 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,28 @@ jobs:
matrix:
platform: ["ubuntu-latest", "macos-latest", "windows-latest"]
numpy-version: ["<2.0.0", ">=2.0.0"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
exclude:
- numpy-version: "<2.0.0"
python-version: "3.10"
- numpy-version: "<2.0.0"
python-version: "3.11"
# let's call python 3.12 the last version we're going to support with numpy <2
# they don't provide wheels for <2 in 3.13 and beyond.
- numpy-version: "<2.0.0"
python-version: "3.13"
- platform: "macos-latest"
python-version: "3.10"
- platform: "macos-latest"
python-version: "3.11"
- platform: "macos-latest"
python-version: "3.12"
- platform: "windows-latest"
python-version: "3.10"
- platform: "windows-latest"
python-version: "3.11"
- platform: "windows-latest"
python-version: "3.12"

runs-on: ${{ matrix.platform }}

Expand Down
16 changes: 16 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
# Changelog

## Upcoming

**Bugfix**
- [#38](https://github.com/p2p-ld/numpydantic/issues/38), [#39](https://github.com/p2p-ld/numpydantic/pull/39) -
- JSON Schema generation failed when the `dtype` was embedded from dtypes that lack a `__name__` attribute.
An additional check was added for presence of `__name__` when embedding.
- `NDArray` types were incorrectly cached s.t. pipe-union dtypes were considered equivalent to `Union[]`
dtypes. An additional tuple with the type of the args was added to the cache key to disambiguate them.

**Testing**
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Test that all combinations of shapes, dtypes, and interfaces
can generate JSON schema.
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Add python 3.13 to the testing matrix.
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Add an additional `marks` field to ValidationCase
for finer-grained control over running tests.

## 1.*

### 1.6.*
Expand Down
2 changes: 1 addition & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ markers = [
"numpy: numpy interface",
"video: video interface",
"zarr: zarr interface",
"union: union dtypes",
"pipe_union: union dtypes specified with a pipe",
]

[tool.black]
Expand Down
12 changes: 9 additions & 3 deletions src/numpydantic/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,15 @@ def __get_pydantic_json_schema__(
json_schema = handler(schema["metadata"])
json_schema = handler.resolve_ref_schema(json_schema)

if not isinstance(dtype, tuple) and dtype.__module__ not in (
"builtins",
"typing",
if (
not isinstance(dtype, tuple)
and dtype.__module__
not in (
"builtins",
"typing",
"types",
)
and hasattr(dtype, "__name__")
):
json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])

Expand Down
15 changes: 14 additions & 1 deletion src/numpydantic/testing/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,27 +143,35 @@ class SubClass(BasicModel):
dtype=np.uint32,
passes=True,
id="union-type-uint32",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.float32,
passes=True,
id="union-type-float32",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.uint64,
passes=False,
id="union-type-uint64",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.float64,
passes=False,
id="union-type-float64",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE, dtype=str, passes=False, id="union-type-str"
annotation_dtype=UNION_TYPE,
dtype=str,
passes=False,
id="union-type-str",
marks={"union"},
),
]
"""
Expand All @@ -181,30 +189,35 @@ class SubClass(BasicModel):
dtype=np.uint32,
passes=True,
id="union-pipe-uint32",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=np.float32,
passes=True,
id="union-pipe-float32",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=np.uint64,
passes=False,
id="union-pipe-uint64",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=np.float64,
passes=False,
id="union-pipe-float64",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=str,
passes=False,
id="union-pipe-str",
marks={"union", "pipe_union"},
),
]
)
Expand Down
28 changes: 25 additions & 3 deletions src/numpydantic/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
from itertools import product
from operator import ior
from pathlib import Path
from typing import Generator, List, Literal, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Generator, List, Literal, Optional, Tuple, Type, Union

import numpy as np
from pydantic import BaseModel, ConfigDict, ValidationError, computed_field
from pydantic import BaseModel, ConfigDict, Field, ValidationError, computed_field

from numpydantic import NDArray, Shape
from numpydantic.dtype import Float
from numpydantic.interface import Interface
from numpydantic.types import DtypeType, NDArrayType

if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator


class InterfaceCase(ABC):
"""
Expand Down Expand Up @@ -139,6 +142,8 @@ class ValidationCase(BaseModel):
"""The interface test case to generate and validate the array with"""
path: Optional[Path] = None
"""The path to generate arrays into, if any."""
marks: set[str] = Field(default_factory=set)
"""pytest marks to set for this test case"""

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down Expand Up @@ -179,6 +184,19 @@ class Model(BaseModel):

return Model

@property
def pytest_marks(self) -> list["MarkDecorator"]:
"""
Instantiated pytest marks from :attr:`.ValidationCase.marks`
plus the interface name.
"""
import pytest

marks = self.marks.copy()
if self.interface is not None:
marks.add(self.interface.interface.name)
return [getattr(pytest.mark, m) for m in marks]

def validate_case(self, path: Optional[Path] = None) -> bool:
"""
Whether the generated array correctly validated against the annotation,
Expand Down Expand Up @@ -246,7 +264,10 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
return args[0]

dumped = [
m.model_dump(exclude_unset=True, exclude={"model", "annotation"}) for m in args
m.model_dump(
exclude_unset=True, exclude={"model", "annotation", "pytest_marks"}
)
for m in args
]

# self_dump = self.model_dump(exclude_unset=True)
Expand All @@ -263,6 +284,7 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
merged = reduce(ior, dumped, {})
merged["passes"] = passes
merged["id"] = ids
merged["marks"] = set().union(*[v.get("marks", set()) for v in dumped])
return ValidationCase.model_construct(**merged)


Expand Down
4 changes: 2 additions & 2 deletions src/numpydantic/vendor/nptyping/base_meta_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class SubscriptableMeta(ABCMeta):
new type is returned for every unique set of arguments.
"""

_all_types: Dict[Tuple[type, Tuple[Any, ...]], type] = {}
_all_types: Dict[Tuple[type, Tuple[Any, ...], tuple[type, ...]], type] = {}
_parameterized: bool = False

@abstractmethod
Expand Down Expand Up @@ -160,7 +160,7 @@ def __getitem__(cls, item: Any) -> type:
def _create_type(
cls, args: Tuple[Any, ...], additional_values: Dict[str, Any]
) -> type:
key = (cls, args)
key = (cls, args, tuple(type(a) for a in args))
if key not in cls._all_types:
cls._all_types[key] = type(
cls.__name__,
Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def pytest_addoption(parser):


@pytest.fixture(
scope="function", params=[pytest.param(c, id=c.id) for c in SHAPE_CASES]
scope="function",
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in SHAPE_CASES],
)
def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy()
Expand All @@ -23,7 +24,8 @@ def shape_cases(request, tmp_output_dir_func) -> ValidationCase:


@pytest.fixture(
scope="function", params=[pytest.param(c, id=c.id) for c in DTYPE_CASES]
scope="function",
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in DTYPE_CASES],
)
def dtype_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy()
Expand Down
12 changes: 3 additions & 9 deletions tests/test_interface/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ def interface_cases(request) -> InterfaceCase:


@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES
)
params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES)
)
def all_cases(interface_cases, request) -> ValidationCase:
"""
Expand All @@ -83,10 +80,7 @@ def all_cases(interface_cases, request) -> ValidationCase:


@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES_PASSING
)
params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES_PASSING)
)
def all_passing_cases(request) -> ValidationCase:
"""
Expand Down Expand Up @@ -132,7 +126,7 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):

@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
pytest.param(p, id=p.id, marks=p.pytest_marks)
for p in DTYPE_AND_INTERFACE_CASES_PASSING
)
)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_interface/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def test_interface_revalidate(all_passing_cases_instance):
_ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array)


@pytest.mark.json_schema
def test_interface_jsonschema(all_passing_cases_instance):
"""
All interfaces should be able to generate json schema
for all combinations of dtype and shape
Note that this does not test for json schema correctness -
see ndarray tests for that
"""
_ = all_passing_cases_instance.model_json_schema()


@pytest.mark.xfail
def test_interface_rematch(interface_cases, tmp_output_dir_func):
"""
Expand Down

0 comments on commit d54698f

Please sign in to comment.