Have a decorator to wrap universal functions? #34
Yes, a nice generic decorator that can handle an arbitrary number of arguments (and return values) would be great. I am pretty sure I thought about this before, but I cannot recall why I didn't do it.
Have you seen the
Thanks for your response!
Below is a first proof of concept (POC) I wrote for the wrapper function:

```python
import numbers
from collections import defaultdict
from functools import wraps
from typing import Any, Tuple

import eagerpy as ep
import numpy as np


def _tuple_as(template, data):
    """Create a list, tuple or namedtuple of the same type as ``template``."""
    data = list(data)
    if hasattr(template, "_fields"):
        # namedtuple case: fields are passed positionally
        return type(template)(*data)
    # list, tuple case
    return type(template)(data)


def _dict_as(template, data):
    """Create a dictionary-like data structure from a template object.

    Parameters
    ----------
    template
        object used as template
    data
        data used to fill the created object.
    """
    if isinstance(template, defaultdict):
        return type(template)(template.default_factory, data)
    return type(template)(data)


def as_eager_tensors(data: Any) -> Any:
    return as_eager_tensors_(data)[0]


def as_eager_tensors_(data: Any) -> Tuple[Any, bool]:
    """Convert to eagerpy tensors.

    Parameters
    ----------
    data : (tuple, list, dict, namedtuple, defaultdict)
        data structure to convert

    Returns
    -------
    data
        the same structure with raw tensors replaced by eagerpy tensors
    unwrap : bool
        if True, at least one raw tensor has been converted
        to an eagerpy tensor.
    """
    if isinstance(data, dict):
        # dict, defaultdict
        if not data:
            return data, False
        keys, res_values, unwrap_values = zip(
            *[(dim,) + as_eager_tensors_(var) for dim, var in data.items()]
        )
        return _dict_as(data, dict(zip(keys, res_values))), any(unwrap_values)
    elif isinstance(data, (list, tuple)):
        # list, tuple, namedtuple
        if not data:
            return data, False
        res_values, unwrap_values = zip(*[as_eager_tensors_(var) for var in data])
        return _tuple_as(data, res_values), any(unwrap_values)
    elif isinstance(data, ep.Tensor):
        return data, False
    elif isinstance(data, np.datetime64):
        # datetime not managed by ep.tensors
        return data, False
    elif isinstance(data, numbers.Number):
        return data, False
    return ep.astensor(data), True


def as_raw_tensors(data: Any) -> Any:
    """Convert from eagerpy tensors to raw tensors.

    Parameters
    ----------
    data
        data to convert
    """
    if isinstance(data, dict):
        return _dict_as(data, {dim: as_raw_tensors(var) for dim, var in data.items()})
    elif isinstance(data, (list, tuple)):
        return _tuple_as(data, (as_raw_tensors(var) for var in data))
    elif isinstance(data, ep.Tensor):
        return data.raw
    return data


def restore_tensor_type(data: Any, unwrap: bool) -> Any:
    # Return raw tensors only if some inputs were raw tensors
    return as_raw_tensors(data) if unwrap else data


def eager_function(func):
    # Heuristic: a qualified name such as "MyClass.method" means ``func`` is a
    # method, so its first positional argument (self) must not be converted.
    # Nested functions ("outer.<locals>.inner") are not methods.
    parts = func.__qualname__.split(".")
    is_method = len(parts) > 1 and parts[-2] != "<locals>"

    @wraps(func)
    def eager_func(*args, **kwargs):
        args = list(args)
        self_args = [args.pop(0)] if is_method and args else []
        args, args_unwrap = as_eager_tensors_(args)
        kwargs, kwargs_unwrap = as_eager_tensors_(kwargs)
        unwrap = args_unwrap or kwargs_unwrap
        result = func(*self_args, *args, **kwargs)
        return restore_tensor_type(result, unwrap)

    return eager_func
```
Another possibility would be to use the pytrees implemented in JAX. This would make it possible to handle more data structures and also to rely on the existing
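The pytree idea (in JAX, `jax.tree_util.tree_flatten` and `jax.tree_util.tree_unflatten`) is to split a nested structure into a flat list of leaves plus a way to rebuild it, so a decorator only has to convert leaves. The following is a stdlib-only imitation, not JAX's implementation; `flatten` and the returned rebuild function are hypothetical names, and it assumes only plain `dict`, `list` and `tuple` containers (anything else is treated as a leaf).

```python
from typing import Any, Callable, List, Tuple


def flatten(data: Any) -> Tuple[List[Any], Callable[[List[Any]], Any]]:
    """Split nested dict/list/tuple data into a flat list of leaves and a
    function that rebuilds the same structure from a list of new leaves."""
    if type(data) is dict:
        keys = list(data)
        parts = [flatten(data[k]) for k in keys]
    elif type(data) in (list, tuple):
        keys = None
        parts = [flatten(v) for v in data]
    else:
        # leaf: a tensor, a number, or any unsupported object
        return [data], lambda leaves: leaves[0]

    leaves = [leaf for part_leaves, _ in parts for leaf in part_leaves]
    sizes = [len(part_leaves) for part_leaves, _ in parts]
    rebuilders = [rebuild for _, rebuild in parts]

    def unflatten(new_leaves: List[Any]) -> Any:
        # Hand each child its slice of the leaves, then rebuild the container
        children, start = [], 0
        for size, rebuild in zip(sizes, rebuilders):
            children.append(rebuild(new_leaves[start:start + size]))
            start += size
        if keys is not None:
            return dict(zip(keys, children))
        return type(data)(children)

    return leaves, unflatten
```

A decorator built on this would flatten `args` and `kwargs`, convert each leaf, and call the rebuild function; supporting a new container type then means teaching `flatten` about it once, which is what registering a type in JAX's pytree registry achieves.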
I propose an implementation based on pytrees in #41.
Actually, I think it is not a good idea to leave JAXTensor unregistered in pytrees: that would break compatibility with JAX functionality. I will restore the registration in an update of the review.
To simplify writing universal functions, it would be great to have a decorator that hides the technical part of the code (conversion of the inputs and outputs of the wrapped function or method).
For example, the code:
would become:
In addition, if the input tensors are already eagerpy tensors, the output tensors should not be converted back to the raw format.
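A minimal sketch of the requested behaviour, runnable without eagerpy: `Wrapper` is a hypothetical stand-in for `ep.Tensor` (it keeps the raw value in `.raw`), Python lists stand in for raw framework tensors, and `universal` is a hypothetical decorator name. It wraps raw inputs, calls the function, and unwraps the result only if at least one input was raw. For brevity it handles positional arguments and a single return value only.

```python
from functools import wraps


class Wrapper:
    """Hypothetical stand-in for ep.Tensor: holds a raw value in ``.raw``."""

    def __init__(self, raw):
        self.raw = raw

    def square(self):
        return Wrapper([v * v for v in self.raw])

    def sum(self):
        # ``sum`` here is the builtin, not this method
        return Wrapper([sum(self.raw)])


def universal(func):
    """Wrap raw inputs, call ``func``, unwrap the output only if inputs were raw."""

    @wraps(func)
    def wrapper(*args):
        was_raw = any(not isinstance(a, Wrapper) for a in args)
        wrapped = [a if isinstance(a, Wrapper) else Wrapper(a) for a in args]
        result = func(*wrapped)
        return result.raw if was_raw else result

    return wrapper


@universal
def sum_of_squares(x):
    # The body only deals with wrapped tensors: no conversion code here
    return x.square().sum()
```

Calling `sum_of_squares([1.0, 2.0])` returns the raw value `[5.0]`, while `sum_of_squares(Wrapper([1.0, 2.0]))` returns a `Wrapper`, so callers already working with wrapped tensors never pay for a round trip.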
I wrote a prototype of such a decorator function. It will not work with arbitrary argument types, so using it would require the wrapped function to have a rather "simple" signature: args and kwargs made of tensors, or of nested containers with tensors as leaves (dict, list, tuple or namedtuple-like containers).
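The core of such a prototype is a traversal that applies a conversion to every leaf while rebuilding each container with its original type. A sketch, assuming dict-like containers whose constructor accepts a mapping (plus `defaultdict`) and namedtuples detected through their `_fields` attribute; `tree_map` is a hypothetical name. In the decorator, `fn` would be the conversion, e.g. `ep.astensor` on the way in and `lambda t: t.raw` on the way out.

```python
from collections import defaultdict, namedtuple
from typing import Any, Callable


def tree_map(fn: Callable[[Any], Any], data: Any) -> Any:
    """Apply ``fn`` to every leaf of nested dict/list/tuple/namedtuple data,
    rebuilding each container with its original type."""
    if isinstance(data, defaultdict):
        # keep the default factory of the original container
        items = {k: tree_map(fn, v) for k, v in data.items()}
        return type(data)(data.default_factory, items)
    if isinstance(data, dict):
        return type(data)({k: tree_map(fn, v) for k, v in data.items()})
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple: fields must be passed positionally
        return type(data)(*(tree_map(fn, v) for v in data))
    if isinstance(data, (list, tuple)):
        return type(data)(tree_map(fn, v) for v in data)
    return fn(data)  # leaf


Point = namedtuple("Point", ["x", "y"])
nested = {"a": [1, 2], "b": (Point(3, 4), defaultdict(int, {"c": 5}))}
doubled = tree_map(lambda v: v * 2, nested)
```

Here `doubled` equals `{"a": [2, 4], "b": (Point(6, 8), defaultdict(int, {"c": 10}))}`, with the namedtuple and the `defaultdict` (including its `int` factory) preserved.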
Would you consider adding this feature to eagerpy?