diff --git a/eagerpy/__init__.py b/eagerpy/__init__.py
index bf04d8f..31d887a 100644
--- a/eagerpy/__init__.py
+++ b/eagerpy/__init__.py
@@ -33,6 +33,7 @@ def __getitem__(self, index: _T) -> _T:
 from .astensor import astensors  # noqa: F401,E402
 from .astensor import astensor_  # noqa: F401,E402
 from .astensor import astensors_  # noqa: F401,E402
+from .astensor import eager_function  # noqa: F401,E402
 
 from .modules import torch  # noqa: F401,E402
 from .modules import tensorflow  # noqa: F401,E402
diff --git a/eagerpy/astensor.py b/eagerpy/astensor.py
index f547179..61eaa21 100644
--- a/eagerpy/astensor.py
+++ b/eagerpy/astensor.py
@@ -1,6 +1,18 @@
-from typing import TYPE_CHECKING, Union, overload, Tuple, TypeVar, Generic, Any
+import functools
+from typing import (
+    TYPE_CHECKING,
+    Union,
+    overload,
+    Tuple,
+    TypeVar,
+    Generic,
+    Any,
+    Callable,
+)
 import sys
 
+from jax import tree_flatten, tree_unflatten
+
 from .tensor import Tensor
 from .tensor import TensorType
 
@@ -36,7 +48,7 @@ def astensor(x: NativeTensor) -> Tensor:  # type: ignore
     ...
 
 
-def astensor(x: Union[NativeTensor, Tensor]) -> Tensor:  # type: ignore
+def astensor(x: Union[NativeTensor, Tensor, Any]) -> Union[Tensor, Any]:  # type: ignore
     if isinstance(x, Tensor):
         return x
     # we use the module name instead of isinstance
@@ -52,7 +64,9 @@ def astensor(x: Union[NativeTensor, Tensor]) -> Tensor:  # type: ignore
         return JAXTensor(x)
     if name == "numpy" and isinstance(x, m[name].ndarray):  # type: ignore
         return NumPyTensor(x)
-    raise ValueError(f"Unknown type: {type(x)}")
+
+    # non-Tensor types are returned unmodified
+    return x
 
 
 def astensors(*xs: Union[NativeTensor, Tensor]) -> Tuple[Tensor, ...]:  # type: ignore
@@ -84,7 +98,7 @@ def __call__(self, *args: Any) -> Any:
         ...
 
     def __call__(self, *args):  # type: ignore # noqa: F811
-        result = tuple(x.raw for x in args) if self.unwrap else args
+        result = tuple(as_raw_tensor(x) for x in args) if self.unwrap else args
         if len(result) == 1:
             (result,) = result
         return result
@@ -96,3 +110,79 @@ def astensor_(x: T) -> Tuple[Tensor, RestoreTypeFunc[T]]:
 
 def astensors_(x: T, *xs: T) -> Tuple[Tuple[Tensor, ...], RestoreTypeFunc[T]]:
     return astensors(x, *xs), RestoreTypeFunc[T](x)
+
+
+def as_tensors(data: Any) -> Any:
+    leaf_values, tree_def = tree_flatten(data)
+    leaf_values = tuple(astensor(value) for value in leaf_values)
+    return tree_unflatten(tree_def, leaf_values)
+
+
+def has_tensor(tree_def: Any) -> bool:
+    return "<class 'eagerpy.tensor" in str(tree_def)
+
+
+def as_tensors_any(data: Any) -> Tuple[Any, bool]:
+    """Convert data structure leaves into Tensors and detect whether any of the input data contains a Tensor.
+
+    Parameters
+    ----------
+    data
+        data structure.
+
+    Returns
+    -------
+    Any
+        modified data structure.
+    bool
+        True if the input data contains a Tensor type.
+ """ + leaf_values, tree_def = tree_flatten(data) + transformed_leaf_values = tuple(astensor(value) for value in leaf_values) + return tree_unflatten(tree_def, transformed_leaf_values), has_tensor(tree_def) + + +def as_raw_tensor(x: T) -> Any: + if isinstance(x, Tensor): + return x.raw + else: + return x + + +def as_raw_tensors(data: Any) -> Any: + leaf_values, tree_def = tree_flatten(data) + + if not has_tensor(tree_def): + return data + + leaf_values = tuple(as_raw_tensor(value) for value in leaf_values) + unwrap_leaf_values = [] + for x in leaf_values: + name = _get_module_name(x) + m = sys.modules + if name == "torch" and isinstance(x, m[name].Tensor): # type: ignore + unwrap_leaf_values.append((x, True)) + elif name == "tensorflow" and isinstance(x, m[name].Tensor): # type: ignore + unwrap_leaf_values.append((x, True)) + elif (name == "jax" or name == "jaxlib") and isinstance(x, m["jax"].numpy.ndarray): # type: ignore + unwrap_leaf_values.append((x, True)) + elif name == "numpy" and isinstance(x, m[name].ndarray): # type: ignore + unwrap_leaf_values.append((x, True)) + else: + unwrap_leaf_values.append(x) + return tree_unflatten(tree_def, unwrap_leaf_values) + + +def eager_function(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) + def eager_func(*args: Any, **kwargs: Any) -> Any: + (args, kwargs), has_tensor = as_tensors_any((args, kwargs)) + unwrap = not has_tensor + result = func(*args, **kwargs) + if unwrap: + raw_result = as_raw_tensors(result) + return raw_result + else: + return result + + return eager_func diff --git a/eagerpy/tensor/base.py b/eagerpy/tensor/base.py index 41a7da3..f1e2953 100644 --- a/eagerpy/tensor/base.py +++ b/eagerpy/tensor/base.py @@ -1,5 +1,5 @@ from typing_extensions import final -from typing import Any, cast +from typing import Any, Type, Union, cast, Tuple from .tensor import Tensor from .tensor import TensorType @@ -17,6 +17,33 @@ def unwrap1(t: Any) -> Any: class BaseTensor(Tensor): __slots__ = "_raw" + _registered = False + + def __new__(cls: Type["BaseTensor"], *args: Any, **kwargs: Any) -> "BaseTensor": + if not cls._registered: + import jax + + def flatten(t: Tensor) -> Tuple[Any, None]: + return ((t.raw,), None) + + def unflatten(aux_data: None, children: Tuple) -> Union[Tensor, Any]: + assert len(children) == 1 + x = children[0] + del children + + if isinstance(x, tuple): + x, unwrap = x + if unwrap: + return x + + if isinstance(x, Tensor): + return x + return cls(x) + + jax.tree_util.register_pytree_node(cls, flatten, unflatten) + cls._registered = True + return cast("BaseTensor", super().__new__(cls)) + def __init__(self: TensorType, raw: Any): assert not isinstance(raw, Tensor) self._raw = raw diff --git a/eagerpy/tensor/extensions.py b/eagerpy/tensor/extensions.py index 1a0423f..f60ec75 100644 --- a/eagerpy/tensor/extensions.py +++ b/eagerpy/tensor/extensions.py @@ -6,7 +6,6 @@ from .tensor import Tensor - T = TypeVar("T") diff --git a/eagerpy/tensor/jax.py b/eagerpy/tensor/jax.py index 1acf104..74ae75e 100644 --- a/eagerpy/tensor/jax.py +++ b/eagerpy/tensor/jax.py @@ -9,8 +9,8 @@ Optional, overload, Callable, - Type, ) + from typing_extensions import Literal from importlib import import_module import numpy as onp @@ -58,27 +58,11 @@ def getitem_preprocess(x: Any) -> Any: class JAXTensor(BaseTensor): __slots__ = () - # more specific types for the extensions norms: "NormsMethods[JAXTensor]" - _registered = False key = None - def __new__(cls: Type["JAXTensor"], *args: Any, **kwargs: Any) -> "JAXTensor": - if not 
cls._registered: - import jax - - def flatten(t: JAXTensor) -> Tuple[Any, None]: - return ((t.raw,), None) - - def unflatten(aux_data: None, children: Tuple) -> JAXTensor: - return cls(*children) - - jax.tree_util.register_pytree_node(cls, flatten, unflatten) - cls._registered = True - return cast(JAXTensor, super().__new__(cls)) - def __init__(self, raw: "np.ndarray"): # type: ignore global jax global np @@ -434,46 +418,24 @@ def _value_and_grad_fn( def _value_and_grad_fn( # noqa: F811 (waiting for pyflakes > 2.1.1) self: TensorType, f: Callable, has_aux: bool = False ) -> Callable[..., Tuple]: - # f takes and returns JAXTensor instances - # jax.value_and_grad accepts functions that take JAXTensor instances - # because we registered JAXTensor as JAX type, but it still requires - # the output to be a scalar (that is not not wrapped as a JAXTensor) - - # f_jax is like f but unwraps loss - if has_aux: - - def f_jax(*args: Any, **kwargs: Any) -> Tuple[Any, Any]: - loss, aux = f(*args, **kwargs) - return loss.raw, aux - - else: - - def f_jax(*args: Any, **kwargs: Any) -> Any: # type: ignore - loss = f(*args, **kwargs) - return loss.raw - - value_and_grad_jax = jax.value_and_grad(f_jax, has_aux=has_aux) - - # value_and_grad is like value_and_grad_jax but wraps loss - if has_aux: - - def value_and_grad( - x: JAXTensor, *args: Any, **kwargs: Any - ) -> Tuple[JAXTensor, Any, JAXTensor]: - assert isinstance(x, JAXTensor) - (loss, aux), grad = value_and_grad_jax(x, *args, **kwargs) - assert grad.shape == x.shape - return JAXTensor(loss), aux, grad - - else: - - def value_and_grad( # type: ignore - x: JAXTensor, *args: Any, **kwargs: Any - ) -> Tuple[JAXTensor, JAXTensor]: - assert isinstance(x, JAXTensor) - loss, grad = value_and_grad_jax(x, *args, **kwargs) - assert grad.shape == x.shape - return JAXTensor(loss), grad + from eagerpy.astensor import as_tensors, as_raw_tensors + + def value_and_grad( + x: JAXTensor, *args: Any, **kwargs: Any + ) -> Union[Tuple[JAXTensor, JAXTensor], Tuple[JAXTensor, Any, JAXTensor]]: + assert isinstance(x, JAXTensor) + x, args, kwargs = as_raw_tensors((x, args, kwargs)) + + loss_aux, grad = jax.value_and_grad(f, has_aux=has_aux)(x, *args, **kwargs) + assert grad.shape == x.shape + loss_aux, grad = as_tensors((loss_aux, grad)) + + if has_aux: + loss, aux = loss_aux + return loss, aux, grad + else: + loss = loss_aux + return loss, grad return value_and_grad diff --git a/tests/test_main.py b/tests/test_main.py index 36ca7df..f62ef76 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,7 +4,7 @@ import itertools import numpy as np import eagerpy as ep -from eagerpy import Tensor +from eagerpy import Tensor, eager_function from eagerpy.types import Shape, AxisAxes # make sure there are no undecorated tests in the "special tests" section below @@ -147,6 +147,7 @@ def test_value_and_grad_fn(dummy: Tensor) -> None: if isinstance(dummy, ep.NumPyTensor): pytest.skip() + @eager_function def f(x: ep.Tensor) -> ep.Tensor: return x.square().sum() @@ -161,6 +162,7 @@ def test_value_and_grad_fn_with_aux(dummy: Tensor) -> None: if isinstance(dummy, ep.NumPyTensor): pytest.skip() + @eager_function def f(x: Tensor) -> Tuple[Tensor, Tensor]: x = x.square() return x.sum(), x @@ -177,6 +179,7 @@ def test_value_and_grad(dummy: Tensor) -> None: if isinstance(dummy, ep.NumPyTensor): pytest.skip() + @eager_function def f(x: Tensor) -> Tensor: return x.square().sum() @@ -190,6 +193,7 @@ def test_value_aux_and_grad(dummy: Tensor) -> None: if isinstance(dummy, ep.NumPyTensor): 
         pytest.skip()
 
+    @eager_function
     def f(x: Tensor) -> Tuple[Tensor, Tensor]:
         x = x.square()
         return x.sum(), x
 
@@ -205,6 +209,7 @@ def test_value_aux_and_grad_multiple_aux(dummy: Tensor) -> None:
     if isinstance(dummy, ep.NumPyTensor):
         pytest.skip()
 
+    @eager_function
     def f(x: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         x = x.square()
         return x.sum(), (x, x + 1)
 
@@ -221,6 +226,7 @@ def test_value_and_grad_multiple_args(dummy: Tensor) -> None:
     if isinstance(dummy, ep.NumPyTensor):
         pytest.skip()
 
+    @eager_function
     def f(x: Tensor, y: Tensor) -> Tensor:
         return (x * y).sum()
 
@@ -1581,3 +1587,96 @@ def test_norms_lp(t: Tensor) -> Tensor:
 @compare_all
 def test_norms_cache(t: Tensor) -> Tensor:
     return t.norms.l1() + t.norms.l2()
+
+
+@eager_function
+def my_universal_function(a: Tensor, b: Tensor, c: Tensor) -> Tensor:
+    return (a + b * c).square()
+
+
+@pytest.mark.parametrize("astensor", [False, True])
+@compare_all
+def test_eager_function(t: Tensor, astensor: bool) -> Tensor:
+    if astensor:
+        a = t
+    else:
+        a = t.raw
+    b = a
+    c = a
+    result = my_universal_function(a, b, c)
+    assert isinstance(result, type(a))
+    return ep.astensor(result)
+
+
+# define a non-registered pytree container.
+class NonRegisteredDataStruct:
+    def __init__(self, res: Any) -> None:
+        self.res = res
+
+
+@eager_function
+def my_universal_function_return_non_registered_datastruct(
+    a: Tensor, b: Tensor, c: Tensor
+) -> Any:
+    res = (a + b * c).square()
+    return NonRegisteredDataStruct(res)
+
+
+@pytest.mark.parametrize("astensor", [False, True])
+@compare_all
+def test_eager_function_return_non_registered_datastruct(
+    t: Tensor, astensor: bool
+) -> Tensor:
+    if astensor:
+        a = t
+    else:
+        a = t.raw
+    b = a
+    c = a
+    result = my_universal_function_return_non_registered_datastruct(a, b, c)
+
+    # result.res has not been converted because NonRegisteredDataStruct
+    # is not a registered pytree container
+    assert isinstance(result.res, type(t))
+    return ep.astensor(result.res)
+
+
+# apply eager_function to a method instead of a plain function.
+class MyClass:
+    @eager_function
+    def my_universal_method(self, a: Tensor, b: Tensor, c: Tensor) -> Any:
+        res = (a + b * c).square()
+        return res
+
+
+@pytest.mark.parametrize("astensor", [False, True])
+@compare_all
+def test_eager_function_on_method(t: Tensor, astensor: bool) -> Tensor:
+    if astensor:
+        a = t
+    else:
+        a = t.raw
+    b = a
+    c = a
+    result = MyClass().my_universal_method(a, b, c)
+    assert isinstance(result, type(a))
+    return ep.astensor(result)
+
+
+@eager_function
+def my_universal_function_with_non_tensors(a: int, b: Tensor, c: Tensor) -> Tensor:
+    return (a + b * c).square()
+
+
+@pytest.mark.parametrize("astensor", [False, True])
+@compare_all
+def test_eager_function_with_non_tensors(t: Tensor, astensor: bool) -> Tensor:
+    if astensor:
+        b = t
+    else:
+        b = t.raw
+    a = 3
+    c = b
+    result = my_universal_function_with_non_tensors(a, b, c)
+    assert isinstance(result, type(b))
+    return ep.astensor(result)
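
Usage note (illustrative, not part of the patch): the new eager_function
decorator lets a function written once against the eagerpy Tensor API accept
and return either native tensors or eagerpy tensors, as the new tests above
exercise. A minimal sketch with hypothetical names (my_func), assuming eagerpy
with this patch applied and jax installed, since astensor.py now imports it:

    import numpy as np
    import eagerpy as ep

    @ep.eager_function
    def my_func(a: ep.Tensor, b: ep.Tensor, c: ep.Tensor) -> ep.Tensor:
        return (a + b * c).square()

    x = np.array([1.0, 2.0, 3.0])

    # native inputs: arguments are wrapped on entry, the result is unwrapped
    y = my_func(x, x, x)
    assert isinstance(y, np.ndarray)

    # eagerpy inputs: the result stays an eagerpy Tensor
    t = ep.astensor(x)
    u = my_func(t, t, t)
    assert isinstance(u, ep.NumPyTensor)

The decorated function decides whether to unwrap its result based solely on
whether any input leaf was already an eagerpy Tensor (see as_tensors_any).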