Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JSON Schema generation for union dtypes #39

Merged
merged 7 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,28 @@ jobs:
matrix:
platform: ["ubuntu-latest", "macos-latest", "windows-latest"]
numpy-version: ["<2.0.0", ">=2.0.0"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
exclude:
- numpy-version: "<2.0.0"
python-version: "3.10"
- numpy-version: "<2.0.0"
python-version: "3.11"
# let's call python 3.12 the last version we're going to support with numpy <2
# they don't provide wheels for <2 in 3.13 and beyond.
- numpy-version: "<2.0.0"
python-version: "3.13"
- platform: "macos-latest"
python-version: "3.10"
- platform: "macos-latest"
python-version: "3.11"
- platform: "macos-latest"
python-version: "3.12"
- platform: "windows-latest"
python-version: "3.10"
- platform: "windows-latest"
python-version: "3.11"
- platform: "windows-latest"
python-version: "3.12"

runs-on: ${{ matrix.platform }}

Expand Down
16 changes: 16 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
# Changelog

## Upcoming

**Bugfix**
- [#38](https://github.com/p2p-ld/numpydantic/issues/38), [#39](https://github.com/p2p-ld/numpydantic/pull/39) -
- JSON Schema generation failed when the `dtype` was embedded from dtypes that lack a `__name__` attribute.
An additional check was added for presence of `__name__` when embedding.
- `NDArray` types were incorrectly cached s.t. pipe-union dtypes were considered equivalent to `Union[]`
dtypes. An additional tuple with the type of the args was added to the cache key to disambiguate them.

**Testing**
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Test that all combinations of shapes, dtypes, and interfaces
can generate JSON schema.
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Add python 3.13 to the testing matrix.
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Add an additional `marks` field to ValidationCase
for finer-grained control over running tests.

## 1.*

### 1.6.*
Expand Down
2 changes: 1 addition & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ markers = [
"numpy: numpy interface",
"video: video interface",
"zarr: zarr interface",
"union: union dtypes",
"pipe_union: union dtypes specified with a pipe",
]

[tool.black]
Expand Down
12 changes: 9 additions & 3 deletions src/numpydantic/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,15 @@ def __get_pydantic_json_schema__(
json_schema = handler(schema["metadata"])
json_schema = handler.resolve_ref_schema(json_schema)

if not isinstance(dtype, tuple) and dtype.__module__ not in (
"builtins",
"typing",
if (
not isinstance(dtype, tuple)
and dtype.__module__
not in (
"builtins",
"typing",
"types",
)
and hasattr(dtype, "__name__")
):
json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])

Expand Down
15 changes: 14 additions & 1 deletion src/numpydantic/testing/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,27 +143,35 @@ class SubClass(BasicModel):
dtype=np.uint32,
passes=True,
id="union-type-uint32",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.float32,
passes=True,
id="union-type-float32",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.uint64,
passes=False,
id="union-type-uint64",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.float64,
passes=False,
id="union-type-float64",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE, dtype=str, passes=False, id="union-type-str"
annotation_dtype=UNION_TYPE,
dtype=str,
passes=False,
id="union-type-str",
marks={"union"},
),
]
"""
Expand All @@ -181,30 +189,35 @@ class SubClass(BasicModel):
dtype=np.uint32,
passes=True,
id="union-pipe-uint32",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=np.float32,
passes=True,
id="union-pipe-float32",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=np.uint64,
passes=False,
id="union-pipe-uint64",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=np.float64,
passes=False,
id="union-pipe-float64",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=str,
passes=False,
id="union-pipe-str",
marks={"union", "pipe_union"},
),
]
)
Expand Down
28 changes: 25 additions & 3 deletions src/numpydantic/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
from itertools import product
from operator import ior
from pathlib import Path
from typing import Generator, List, Literal, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Generator, List, Literal, Optional, Tuple, Type, Union

import numpy as np
from pydantic import BaseModel, ConfigDict, ValidationError, computed_field
from pydantic import BaseModel, ConfigDict, Field, ValidationError, computed_field

from numpydantic import NDArray, Shape
from numpydantic.dtype import Float
from numpydantic.interface import Interface
from numpydantic.types import DtypeType, NDArrayType

if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator


class InterfaceCase(ABC):
"""
Expand Down Expand Up @@ -139,6 +142,8 @@ class ValidationCase(BaseModel):
"""The interface test case to generate and validate the array with"""
path: Optional[Path] = None
"""The path to generate arrays into, if any."""
marks: set[str] = Field(default_factory=set)
"""pytest marks to set for this test case"""

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down Expand Up @@ -179,6 +184,19 @@ class Model(BaseModel):

return Model

@property
def pytest_marks(self) -> list["MarkDecorator"]:
"""
Instantiated pytest marks from :attr:`.ValidationCase.marks`
plus the interface name.
"""
import pytest

marks = self.marks.copy()
if self.interface is not None:
marks.add(self.interface.interface.name)
return [getattr(pytest.mark, m) for m in marks]

def validate_case(self, path: Optional[Path] = None) -> bool:
"""
Whether the generated array correctly validated against the annotation,
Expand Down Expand Up @@ -246,7 +264,10 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
return args[0]

dumped = [
m.model_dump(exclude_unset=True, exclude={"model", "annotation"}) for m in args
m.model_dump(
exclude_unset=True, exclude={"model", "annotation", "pytest_marks"}
)
for m in args
]

# self_dump = self.model_dump(exclude_unset=True)
Expand All @@ -263,6 +284,7 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
merged = reduce(ior, dumped, {})
merged["passes"] = passes
merged["id"] = ids
merged["marks"] = set().union(*[v.get("marks", set()) for v in dumped])
return ValidationCase.model_construct(**merged)


Expand Down
4 changes: 2 additions & 2 deletions src/numpydantic/vendor/nptyping/base_meta_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class SubscriptableMeta(ABCMeta):
new type is returned for every unique set of arguments.
"""

_all_types: Dict[Tuple[type, Tuple[Any, ...]], type] = {}
_all_types: Dict[Tuple[type, Tuple[Any, ...], tuple[type, ...]], type] = {}
_parameterized: bool = False

@abstractmethod
Expand Down Expand Up @@ -160,7 +160,7 @@ def __getitem__(cls, item: Any) -> type:
def _create_type(
cls, args: Tuple[Any, ...], additional_values: Dict[str, Any]
) -> type:
key = (cls, args)
key = (cls, args, tuple(type(a) for a in args))
if key not in cls._all_types:
cls._all_types[key] = type(
cls.__name__,
Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def pytest_addoption(parser):


@pytest.fixture(
scope="function", params=[pytest.param(c, id=c.id) for c in SHAPE_CASES]
scope="function",
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in SHAPE_CASES],
)
def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy()
Expand All @@ -23,7 +24,8 @@ def shape_cases(request, tmp_output_dir_func) -> ValidationCase:


@pytest.fixture(
scope="function", params=[pytest.param(c, id=c.id) for c in DTYPE_CASES]
scope="function",
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in DTYPE_CASES],
)
def dtype_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy()
Expand Down
12 changes: 3 additions & 9 deletions tests/test_interface/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ def interface_cases(request) -> InterfaceCase:


@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES
)
params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES)
)
def all_cases(interface_cases, request) -> ValidationCase:
"""
Expand All @@ -83,10 +80,7 @@ def all_cases(interface_cases, request) -> ValidationCase:


@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES_PASSING
)
params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES_PASSING)
)
def all_passing_cases(request) -> ValidationCase:
"""
Expand Down Expand Up @@ -132,7 +126,7 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):

@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
pytest.param(p, id=p.id, marks=p.pytest_marks)
for p in DTYPE_AND_INTERFACE_CASES_PASSING
)
)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_interface/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def test_interface_revalidate(all_passing_cases_instance):
_ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array)


@pytest.mark.json_schema
def test_interface_jsonschema(all_passing_cases_instance):
"""
All interfaces should be able to generate json schema
for all combinations of dtype and shape

Note that this does not test for json schema correctness -
see ndarray tests for that
"""
_ = all_passing_cases_instance.model_json_schema()


@pytest.mark.xfail
def test_interface_rematch(interface_cases, tmp_output_dir_func):
"""
Expand Down
Loading