Skip to content

Commit

Permalink
Merge pull request #576 from pablovela5620/jaxtyping
Browse files Browse the repository at this point in the history
initial working jaxtyping serializing/deserializing
  • Loading branch information
yukinarit authored Aug 2, 2024
2 parents b29cef4 + 2a012c1 commit 129e030
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 24 deletions.
90 changes: 90 additions & 0 deletions examples/type_numpy_jaxtyping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import numpy
from jaxtyping import (
Float,
Float16,
Float32,
Float64,
Inexact,
Int,
Int8,
Int16,
Int32,
Int64,
Integer,
UInt,
UInt8,
UInt16,
UInt32,
UInt64,
)
from serde import serde
from serde.json import from_json, to_json


@serde
class Foo:
float_: Float[numpy.ndarray, "3 3"]
float16: Float16[numpy.ndarray, "3 3"]
float32: Float32[numpy.ndarray, "3 3"]
float64: Float64[numpy.ndarray, "3 3"]
inexact: Inexact[numpy.ndarray, "3 3"]
int_: Int[numpy.ndarray, "3 3"]
int8: Int8[numpy.ndarray, "3 3"]
int16: Int16[numpy.ndarray, "3 3"]
int32: Int32[numpy.ndarray, "3 3"]
int64: Int64[numpy.ndarray, "3 3"]
integer: Integer[numpy.ndarray, "3 3"]
uint: UInt[numpy.ndarray, "3 3"]
uint8: UInt8[numpy.ndarray, "3 3"]
uint16: UInt16[numpy.ndarray, "3 3"]
uint32: UInt32[numpy.ndarray, "3 3"]
uint64: UInt64[numpy.ndarray, "3 3"]


def main() -> None:
foo = Foo(
float_=numpy.zeros((3, 3), dtype=float),
float16=numpy.zeros((3, 3), dtype=numpy.float16),
float32=numpy.zeros((3, 3), dtype=numpy.float32),
float64=numpy.zeros((3, 3), dtype=numpy.float64),
inexact=numpy.zeros((3, 3), dtype=numpy.inexact),
int_=numpy.zeros((3, 3), dtype=int),
int8=numpy.zeros((3, 3), dtype=numpy.int8),
int16=numpy.zeros((3, 3), dtype=numpy.int16),
int32=numpy.zeros((3, 3), dtype=numpy.int32),
int64=numpy.zeros((3, 3), dtype=numpy.int64),
integer=numpy.zeros((3, 3), dtype=numpy.integer),
uint=numpy.zeros((3, 3), dtype=numpy.uint),
uint8=numpy.zeros((3, 3), dtype=numpy.uint8),
uint16=numpy.zeros((3, 3), dtype=numpy.uint16),
uint32=numpy.zeros((3, 3), dtype=numpy.uint32),
uint64=numpy.zeros((3, 3), dtype=numpy.uint64),
)

print(f"Into Json: {to_json(foo)}")

s = """
{
"float_": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
"float16": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
"float32": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
"float64": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
"inexact": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
"int_": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"int8": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"int16": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"int32": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"int64": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"integer": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"uint": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"uint8": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"uint16": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"uint32": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"uint64": [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
}
"""
print(f"From Json: {from_json(Foo, s)}")


if __name__ == "__main__":
main()
60 changes: 36 additions & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ packages = [
{ include = "serde" },
]
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]

[tool.poetry.dependencies]
python = "^3.9.0"
Expand All @@ -33,11 +33,12 @@ tomli = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional
tomli-w = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional = true }
pyyaml = { version = "*", markers = "extra == 'yaml' or extra == 'all'", optional = true }
numpy = [
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true },
]
jaxtyping = { version = "*", markers = "extra == 'jaxtyping' or extra == 'all'", optional = true }
orjson = { version = "*", markers = "extra == 'orjson' or extra == 'all'", optional = true }
plum-dispatch = ">=2,<2.3"
beartype = ">=0.18.4"
Expand All @@ -49,10 +50,10 @@ tomli = { version = "*", markers = "python_version <= '3.11.0'" }
tomli-w = "*"
msgpack = "*"
numpy = [
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12'" },
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12'" },
]
mypy = "==1.10.1"
pytest = "*"
Expand All @@ -68,6 +69,7 @@ types-PyYAML = "^6.0.9"
msgpack-types = "^0.3"
envclasses = "^0.3.1"
jedi = "*"
jaxtyping = "*"

[tool.poetry.extras]
msgpack = ["msgpack"]
Expand All @@ -76,7 +78,8 @@ toml = ["tomli", "tomli-w"]
yaml = ["pyyaml"]
orjson = ["orjson"]
sqlalchemy = ["sqlalchemy"]
all = ["msgpack", "tomli", "tomli-w", "pyyaml", "numpy", "orjson", "sqlalchemy"]
jaxtyping = ["jaxtyping"]
all = ["msgpack", "tomli", "tomli-w", "pyyaml", "numpy", "orjson", "sqlalchemy", "jaxtyping"]

[build-system]
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning"]
Expand Down Expand Up @@ -145,16 +148,25 @@ exclude = [
"tests/test_sqlalchemy.py",
]

[[tool.mypy.overrides]]
# to avoid complaints about generic type ndarray
module = "examples.type_numpy_jaxtyping"
ignore_errors = true

[tool.ruff]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"C", # flake8-comprehensions
"B", # flake8-bugbear
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"C", # flake8-comprehensions
"B", # flake8-bugbear
]
ignore = ["B904"]
line-length = 100

[tool.ruff.lint.mccabe]
max-complexity = 30

[tool.ruff.per-file-ignores]
# https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error
"examples/type_numpy_jaxtyping.py" = ["F722"]
5 changes: 5 additions & 0 deletions serde/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@
deserialize_numpy_array,
deserialize_numpy_scalar,
deserialize_numpy_array_direct,
deserialize_numpy_jaxtyping_array,
is_numpy_array,
is_numpy_jaxtyping,
is_numpy_scalar,
)

Expand Down Expand Up @@ -749,6 +751,9 @@ def render(self, arg: DeField[Any]) -> str:
elif is_numpy_array(arg.type):
self.import_numpy = True
res = deserialize_numpy_array(arg)
elif is_numpy_jaxtyping(arg.type):
self.import_numpy = True
res = deserialize_numpy_jaxtyping_array(arg)
elif is_union(arg.type):
res = self.union_func(arg)
elif is_str_serializable(arg.type):
Expand Down
19 changes: 19 additions & 0 deletions serde/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ def is_numpy_array(typ) -> bool:
typ = origin
return typ is np.ndarray

def is_numpy_jaxtyping(typ) -> bool:
try:
origin = get_origin(typ)
if origin is not None:
typ = origin
return typ is not np.ndarray and issubclass(typ, np.ndarray)
except TypeError:
return False

def serialize_numpy_array(arg) -> str:
return f"{arg.varname}.tolist()"

Expand All @@ -86,6 +95,10 @@ def deserialize_numpy_array(arg) -> str:
dtype = fullname(arg[1][0].type)
return f"numpy.array({arg.data}, dtype={dtype})"

def deserialize_numpy_jaxtyping_array(arg) -> str:
dtype = f"numpy.{arg.type.dtypes[-1]}"
return f"numpy.array({arg.data}, dtype={dtype})"

def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
if is_bare_numpy_array(typ):
return np.array(arg)
Expand All @@ -111,6 +124,9 @@ def deserialize_numpy_scalar(arg):
def is_numpy_array(typ) -> bool:
return False

def is_numpy_jaxtyping(typ) -> bool:
return False

def serialize_numpy_array(arg) -> str:
return ""

Expand All @@ -120,5 +136,8 @@ def serialize_numpy_datetime(arg) -> str:
def deserialize_numpy_array(arg) -> str:
return ""

def deserialize_numpy_jaxtyping_array(arg) -> str:
return ""

def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
return arg
3 changes: 3 additions & 0 deletions serde/se.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
)
from .numpy import (
is_numpy_array,
is_numpy_jaxtyping,
is_numpy_datetime,
is_numpy_scalar,
serialize_numpy_array,
Expand Down Expand Up @@ -751,6 +752,8 @@ def render(self, arg: SeField[Any]) -> str:
res = serialize_numpy_scalar(arg)
elif is_numpy_array(arg.type):
res = serialize_numpy_array(arg)
elif is_numpy_jaxtyping(arg.type):
res = serialize_numpy_array(arg)
elif is_primitive(arg.type):
res = self.primitive(arg)
elif is_union(arg.type):
Expand Down
61 changes: 61 additions & 0 deletions tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import numpy.typing as npt
import jaxtyping
import pytest

import serde
Expand Down Expand Up @@ -89,6 +90,66 @@ class NumpyDate:

assert de(NumpyDate, se(date_test)) == date_test

@serde.serde(**opt)
class NumpyJaxtyping:
float_: jaxtyping.Float[np.ndarray, "2 2"] # noqa: F722
float16: jaxtyping.Float16[np.ndarray, "2 2"] # noqa: F722
float32: jaxtyping.Float32[np.ndarray, "2 2"] # noqa: F722
float64: jaxtyping.Float64[np.ndarray, "2 2"] # noqa: F722
inexact: jaxtyping.Inexact[np.ndarray, "2 2"] # noqa: F722
int_: jaxtyping.Int[np.ndarray, "2 2"] # noqa: F722
int8: jaxtyping.Int8[np.ndarray, "2 2"] # noqa: F722
int16: jaxtyping.Int16[np.ndarray, "2 2"] # noqa: F722
int32: jaxtyping.Int32[np.ndarray, "2 2"] # noqa: F722
int64: jaxtyping.Int64[np.ndarray, "2 2"] # noqa: F722
integer: jaxtyping.Integer[np.ndarray, "2 2"] # noqa: F722
uint: jaxtyping.UInt[np.ndarray, "2 2"] # noqa: F722
uint8: jaxtyping.UInt8[np.ndarray, "2 2"] # noqa: F722
uint16: jaxtyping.UInt16[np.ndarray, "2 2"] # noqa: F722
uint32: jaxtyping.UInt32[np.ndarray, "2 2"] # noqa: F722
uint64: jaxtyping.UInt64[np.ndarray, "2 2"] # noqa: F722

def __eq__(self, other):
return (
(self.float_ == other.float_).all()
and (self.float16 == other.float16).all()
and (self.float32 == other.float32).all()
and (self.float64 == other.float64).all()
and (self.inexact == other.inexact).all()
and (self.int_ == other.int_).all()
and (self.int8 == other.int8).all()
and (self.int16 == other.int16).all()
and (self.int32 == other.int32).all()
and (self.int64 == other.int64).all()
and (self.integer == other.integer).all()
and (self.uint == other.uint).all()
and (self.uint8 == other.uint8).all()
and (self.uint16 == other.uint16).all()
and (self.uint32 == other.uint32).all()
and (self.uint64 == other.uint64).all()
)

jaxtyping_test = NumpyJaxtyping(
float_=np.array([[1, 2], [3, 4]], dtype=np.float_),
float16=np.array([[5, 6], [7, 8]], dtype=np.float16),
float32=np.array([[9, 10], [11, 12]], dtype=np.float32),
float64=np.array([[13, 14], [15, 16]], dtype=np.float64),
inexact=np.array([[17, 18], [19, 20]], dtype=np.float_),
int_=np.array([[21, 22], [23, 24]], dtype=np.int_),
int8=np.array([[25, 26], [27, 28]], dtype=np.int8),
int16=np.array([[29, 30], [31, 32]], dtype=np.int16),
int32=np.array([[33, 34], [35, 36]], dtype=np.int32),
int64=np.array([[37, 38], [39, 40]], dtype=np.int64),
integer=np.array([[41, 42], [43, 44]], dtype=np.int_),
uint=np.array([[45, 46], [47, 48]], dtype=np.uint),
uint8=np.array([[49, 50], [51, 52]], dtype=np.uint8),
uint16=np.array([[53, 54], [55, 56]], dtype=np.uint16),
uint32=np.array([[57, 58], [59, 60]], dtype=np.uint32),
uint64=np.array([[61, 62], [63, 64]], dtype=np.uint64),
)

assert de(NumpyJaxtyping, se(jaxtyping_test)) == jaxtyping_test


@pytest.mark.parametrize("opt", opt_case, ids=opt_case_ids())
@pytest.mark.parametrize("se,de", format_json + format_msgpack)
Expand Down

0 comments on commit 129e030

Please sign in to comment.