The goal
I have been following these two tutorials [1] [2] for exposing my own C++ class methods to JAX. The ultimate idea is to create an operator that can be plugged into numpyro's NUTS/HMC samplers. A lot of development has already been done on the C++ side, so I just want to wrap some of the C++ objects, expose them to Python through pybind11, and then build a JAX primitive out of them.
The problem
I need to pass a class object pointer from the Python side to the C++ side so that my XLA custom_call can call the corresponding class method on the correct instance of the class. The code for this custom call looks like this:
static void XLACPU_custom_call(void *out, const void **in) {
    // The class instance that I initialized on the Python side
    const pyhelperWrapper *phW = reinterpret_cast<const pyhelperWrapper *>(in[0]);
    // 3D array of double values passed from the Python side, e.g. sampled by numpyro
    const py::array_t<double> *pyshat = reinterpret_cast<const py::array_t<double> *>(in[1]);
    // The result is simply a double coming from the likelihood
    double *result = reinterpret_cast<double *>(out);
    // Here I call the method of the class instance
    *result = phW->compute_like(*pyshat);
}
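For reference, with the legacy CPU custom-call signature used above, XLA hands the function raw pointers to each operand's data (`in[i]`) and, for a single non-tuple result, a raw pointer to the output buffer (`out`). A `float64` array therefore arrives as a `const double*` over contiguous memory, not as a `py::array_t<double>` object, so its shape has to be supplied separately (for example as extra scalar operands). The sketch below simulates that convention with plain buffers; `sum_of_squares_call`, `call_like_xla` and the element-count operand are hypothetical stand-ins for `compute_like`, not part of the original code.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical CPU custom call in the legacy form void(void* out, const void** in).
// XLA passes raw pointers to operand data, so a float64 array arrives as
// const double*, not as a py::array_t<double>. Assumed operand order:
//   in[0] -> int64 scalar: total number of elements n
//   in[1] -> n contiguous doubles (e.g. a flattened 3D array)
static void sum_of_squares_call(void *out, const void **in) {
    const std::int64_t n = *reinterpret_cast<const std::int64_t *>(in[0]);
    const double *data = reinterpret_cast<const double *>(in[1]);
    double *result = reinterpret_cast<double *>(out);  // single scalar result

    double acc = 0.0;
    for (std::int64_t i = 0; i < n; ++i) {
        acc += data[i] * data[i];  // stand-in for a real likelihood evaluation
    }
    *result = acc;
}

// Simulates what XLA hands to the custom call: an array of raw operand pointers.
double call_like_xla(const std::vector<double> &shat) {
    const std::int64_t n = static_cast<std::int64_t>(shat.size());
    const void *in[] = {&n, shat.data()};
    double result = 0.0;
    sum_of_squares_call(&result, in);
    return result;
}
```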
I need to pass the class object (a pointer of type pyhelperWrapper) because some of its member variables have already been properly initialized on the Python side, and I want the C++ side to reuse that same instance when evaluating *result = phW->compute_like(*pyshat);.
Therefore, I am wondering: is it even possible to pass pointers to class objects through a JAX primitive's arguments?
P.S. Note that I had to make the function static here, because a pointer to this function is passed to pybind11::capsule, which fails if it is a non-static class member, since such a member would also implicitly take the this pointer.
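The point in the P.S. can be illustrated with standard C++ alone. A static member function has an ordinary function-pointer type, which is what gets cast to `void*` and stored in the capsule (in the extending-JAX tutorials this is done roughly as `pybind11::capsule(reinterpret_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET")`). A non-static member function has a pointer-to-member type that also needs an object (the implicit `this`), so it cannot be converted to a plain function pointer. `Wrapper`, `CustomCallFn` and `invoke` below are hypothetical names.

```cpp
#include <cassert>
#include <type_traits>

// Plain function-pointer type of a legacy CPU custom call.
using CustomCallFn = void (*)(void *out, const void **in);

struct Wrapper {  // hypothetical stand-in for pyhelperWrapper
    double scale = 2.0;

    // Static member: no implicit `this`, so its address has type CustomCallFn.
    static void static_call(void *out, const void **in) {
        const double x = *reinterpret_cast<const double *>(in[0]);
        *reinterpret_cast<double *>(out) = 2.0 * x;
    }

    // Non-static member: its address has type
    // `void (Wrapper::*)(void*, const void**) const` and needs a Wrapper object to call.
    void member_call(void *out, const void **in) const {
        const double x = *reinterpret_cast<const double *>(in[0]);
        *reinterpret_cast<double *>(out) = scale * x;
    }
};

// Compile-time checks of the two pointer types.
static_assert(std::is_convertible<decltype(&Wrapper::static_call), CustomCallFn>::value,
              "a static member function converts to a plain function pointer");
static_assert(!std::is_convertible<decltype(&Wrapper::member_call), CustomCallFn>::value,
              "a non-static member function pointer does not");

// Calls through the plain function pointer, the way XLA would.
double invoke(CustomCallFn fn, double x) {
    const void *in[] = {&x};
    double out = 0.0;
    fn(&out, in);
    return out;
}
```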
More details on the problem
For more context: when building the JAX primitive on the Python side, the first step, as far as I understand, is to tell JAX how to bind the inputs, i.e.:
# This function exposes the primitive to user code and this is the only
# public-facing function in this module
def pylefty_ll(shat, pyleftyobj):
    return _pylefty_ll_prim.bind(shat, pyleftyobj)
Here pyleftyobj is the pointer to the instantiated object. Of course, JAX doesn't know how to handle this class object pointer, so it raises the following error:
File "...", line 59, in <module>
test_init()
File ".../test_pylefty.py", line 55, in test_init
pylefty_ll(shat, pyleftyobj)
File ".../pylefty_jax.py", line 33, in pylefty_ll
return _pylefty_ll_prim.bind(shat, pyleftyobj)
File "~/.local/lib/python3.9/site-packages/jax/_src/core.py", line 343, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "~/.local/lib/python3.9/site-packages/jax/_src/core.py", line 346, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "~/.local/lib/python3.9/site-packages/jax/_src/core.py", line 728, in process_primitive
return primitive.impl(*tracers, **params)
File "~/.local/lib/python3.9/site-packages/jax/_src/dispatch.py", line 122, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "~/.local/lib/python3.9/site-packages/jax/_src/dispatch.py", line 107, in arg_spec
aval = xla.abstractify(x)
File "~/.local/lib/python3.9/site-packages/jax/_src/interpreters/xla.py", line 273, in abstractify
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
TypeError: Argument '<pyleftfield.pylefty object at 0x14b4a42f46b0>' of type '<class 'pyleftfield.pylefty'>' is not a valid JAX type
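The TypeError arises because JAX primitive operands must be arrays or scalars of a supported dtype, and a pybind11 object is neither. One common workaround, sketched here under stated assumptions, is to pass the object's address as an int64 scalar operand and turn it back into a pointer inside the custom call. This assumes (hypothetically) that the binding exposes the address, e.g. through a method returning `reinterpret_cast<std::uintptr_t>(&self)`, that 64-bit types are enabled in JAX (`jax_enable_x64`), since an address does not fit in the default 32-bit integers, and that the Python object outlives every call of the compiled function, because XLA treats the value as a plain integer and keeps nothing alive. `Model`, `encode_address`, `model_custom_call` and `run_custom_call` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

struct Model {  // hypothetical stand-in for pyhelperWrapper
    double offset;
    double compute_like(const double *x, std::int64_t n) const {
        double acc = offset;
        for (std::int64_t i = 0; i < n; ++i) acc += x[i];
        return acc;
    }
};

// What a (hypothetical) binding method would return to Python; on the Python side
// it would then be passed to the primitive as an int64 scalar array.
std::int64_t encode_address(const Model &m) {
    return static_cast<std::int64_t>(reinterpret_cast<std::uintptr_t>(&m));
}

// Custom call: in[0] -> int64 address, in[1] -> int64 n, in[2] -> n doubles.
static void model_custom_call(void *out, const void **in) {
    const std::int64_t addr = *reinterpret_cast<const std::int64_t *>(in[0]);
    const Model *model =
        reinterpret_cast<const Model *>(static_cast<std::uintptr_t>(addr));
    const std::int64_t n = *reinterpret_cast<const std::int64_t *>(in[1]);
    const double *x = reinterpret_cast<const double *>(in[2]);
    *reinterpret_cast<double *>(out) = model->compute_like(x, n);
}

// Simulates XLA invoking the custom call with raw operand buffers.
double run_custom_call(const Model &m, const double *x, std::int64_t n) {
    const std::int64_t addr = encode_address(m);
    const void *in[] = {&addr, &n, x};
    double out = 0.0;
    model_custom_call(&out, in);
    return out;
}
```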
Here is also an example of how I am testing my implementation so far and which line causes the above error (see #FIXME):
Follow up problem
Another problem that I expect will arise is how to handle the MLIR lowering step when calling jaxlib.hlo_helpers.custom_call, since, as far as I understand, I would need to do something like:
out = custom_call(
    b"pylefty_ll_forward_f64",
    # Output of the XLACPU_custom_call call is just a double
    out_types=[ir.F64Type],
    # Inputs of the XLACPU_custom_call
    operands=[pyleftyobj, shat],
    # Layout specification
    # FIXME: How to see what is the layout of the `pyleftyobj`?
    operand_layouts=[(sys.getsizeof(pyleftyobj),), layout],
    # Result is just a number!
    result_layouts=[()],
)
I am fairly sure that sys.getsizeof(pyleftyobj) is not the correct thing to pass, given how hlo.CustomCallOp is structured; I just guessed at the structure of the call. In any case, I first need to solve the problem of passing the class object pointer before dealing with this one.
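On the layout question: in `jaxlib.hlo_helpers.custom_call`, each entry of `operand_layouts`/`result_layouts` is, as far as I know, a minor-to-major ordering of the operand's dimension indices rather than a size in bytes. A scalar operand (such as an int64 address) therefore gets `()`, and a C-contiguous rank-3 array gets `(2, 1, 0)`, which is what the `tuple(range(len(dims) - 1, -1, -1))` idiom in the extending-JAX tutorials produces. The C++ sketch below mirrors that idiom and shows how the custom call would index such a row-major rank-3 buffer; `row_major_layout` and `flat_index` are hypothetical helper names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Minor-to-major layout of a row-major (C-order) array of the given rank:
// the last dimension varies fastest, so it is listed first, e.g. rank 3 -> {2, 1, 0}.
// A scalar (rank 0) gets an empty layout.
std::vector<std::int64_t> row_major_layout(std::int64_t rank) {
    std::vector<std::int64_t> layout;
    for (std::int64_t d = rank - 1; d >= 0; --d) layout.push_back(d);
    return layout;
}

// Offset of element (i, j, k) in a row-major buffer with dims (n0, n1, n2);
// the leading extent n0 is not needed to compute the offset.
std::size_t flat_index(std::size_t i, std::size_t j, std::size_t k,
                       std::size_t n1, std::size_t n2) {
    return (i * n1 + j) * n2 + k;
}
```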
Possible solutions I am not satisfied with
One way around this issue is to always pass all the necessary objects (dicts, lists, strings, ints, bools) and re-initialize the class instance on the C++ side every time XLACPU_custom_call is called; in other words, pass all of these as arguments to pylefty_ll. But this doesn't seem like a practical solution.
If anyone is interested: the solution was simply to declare an extern pointer of the class type in the header file, initialize it when the class is instantiated, and reuse that pointer inside XLA's custom_call. I will close this discussion due to the lack of interest.
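A minimal single-file sketch of the extern-pointer pattern described in this reply, assuming one active instance at a time. In a real project the `extern` declaration would sit in the header and the definition in exactly one source file; here both are combined, and `Likelihood`, `g_active`, `global_custom_call` and `run_global_call` are hypothetical names. The trade-offs are worth keeping in mind: only the most recently constructed instance is visible to the custom call, the pointer would dangle if that instance were destroyed (the destructor below clears it), and several instances or threads would need extra care.

```cpp
#include <cassert>

// Header part: declaration only.
struct Likelihood;
extern Likelihood *g_active;

struct Likelihood {  // hypothetical stand-in for pyhelperWrapper
    double scale;

    explicit Likelihood(double s) : scale(s) { g_active = this; }  // register on construction
    ~Likelihood() {
        if (g_active == this) g_active = nullptr;  // avoid a dangling global pointer
    }
    Likelihood(const Likelihood &) = delete;  // copies would not re-register
    Likelihood &operator=(const Likelihood &) = delete;

    double compute_like(double x) const { return scale * x * x; }
};

// Source-file part: the single definition.
Likelihood *g_active = nullptr;

// The custom call no longer receives the object as an operand: it reads the global.
static void global_custom_call(void *out, const void **in) {
    const double x = *reinterpret_cast<const double *>(in[0]);
    assert(g_active != nullptr && "no Likelihood instance has been constructed");
    *reinterpret_cast<double *>(out) = g_active->compute_like(x);
}

// Simulates XLA invoking the custom call with one scalar operand.
double run_global_call(double x) {
    const void *in[] = {&x};
    double out = 0.0;
    global_custom_call(&out, in);
    return out;
}
```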