-
I really like the design of JAX: at its core, it is a program transformation system. So I was wondering how easy it would be to implement some custom program transformations of my own. For example, I wanted to implement the following transformation:
In particular, what the transformation does is identify the subexpression
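For context on what such a transformation would operate on: `jax.make_jaxpr` exposes the jaxpr intermediate representation, and identifying a subexpression amounts to matching over its equations. A minimal sketch (`f` here is just an arbitrary illustration, not the elided transformation above):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * jnp.cos(x)

# Trace f into its jaxpr intermediate representation.
closed_jaxpr = jax.make_jaxpr(f)(1.0)
print(closed_jaxpr)

# Each equation records one primitive application; a custom
# transformation would pattern-match over these equations to
# find the target subexpression.
names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(names)
```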
-
This seems like it would do better as a GitHub discussion (https://github.com/google/jax/discussions) rather than an issue. There's no action for us to take here. Closing; feel free to reopen in the discussions section. To answer your concrete question, have you seen: ?
-
I also saw the callback transformation: https://github.com/google/jax/blob/master/jax/experimental/callback.py#L32. This does seem relevant, though see the PR that created it: #2665.
I just hope there are some simple working examples (callback.py counts, modulo the warning above) of using Trace and Tracer, so I can dig into writing custom transformations myself.
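Absent such an example, here is a minimal sketch of the related "custom interpreter" pattern, which sidesteps Trace/Tracer subclassing entirely: trace a function to a jaxpr with `jax.make_jaxpr`, then re-evaluate its equations with one primitive rewritten. The `taylorify` name and the `exp(x) → 1 + x` rule (a first-order Taylor expansion around 0) are my own illustration, not anything from the JAX codebase:

```python
import jax
import jax.numpy as jnp

def eval_jaxpr(jaxpr, consts, *args):
    """Interpret a jaxpr, rewriting exp(x) into its first-order
    Taylor approximation 1 + x (expansion around 0)."""
    env = {}

    def read(var):
        # Literals carry their value directly; variables live in env.
        return var.val if hasattr(var, "val") else env[var]

    for var, val in zip(jaxpr.constvars, consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive.name == "exp":
            outvals = [1.0 + invals[0]]  # the rewrite rule
        else:
            # Default: apply the primitive unchanged.
            out = eqn.primitive.bind(*invals, **eqn.params)
            outvals = out if eqn.primitive.multiple_results else [out]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val

    return [read(v) for v in jaxpr.outvars]

def taylorify(fun):
    """Hypothetical transformation: run fun with exp approximated."""
    def wrapped(*args):
        closed = jax.make_jaxpr(fun)(*args)
        return eval_jaxpr(closed.jaxpr, closed.consts, *args)
    return wrapped

# exp(1) + 1 ≈ 3.718, but under the approximation: (1 + 1) + 1 = 3.0
f = lambda x: jnp.exp(x) + x
print(taylorify(f)(1.0))
```

This mirrors the structure of JAX's own `core.eval_jaxpr`; a real transformation would likely want to stage out a new jaxpr rather than interpret eagerly, but the traversal is the same.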
-
I think I'm going to lay out my plan here, just to see whether the JAX team has any advice on, or interest in, all or part of it.
I'm interested in using custom transformations for a scenario of "distributed approximate optimization" (name not determined yet). The idea is that the user first writes a non-distributed model, and my tool will:
1) Generate an approximation to that model, with the "Taylor approximation transformation" I mentioned above
2) Perform optimization on the approximated model, still non-distributed (standard JAX stuff, so no big deal here)
3) Split the non-distributed function into a distributed one, potentially with a JAX transformation. This sort of works like pmap, but in…
Right now, I do sort of see how I can achieve 1) and 2). But I'm a bit afraid of 3), in that a jaxpr may be too high level to perform this analysis (though in theory it's all possible). I do see that 3) is potentially of interest to others and the JAX team, as it's basically a
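For step 3), the standard pmap pattern the plan is comparing against looks roughly like this (a sketch only; the real tool would generate the split automatically, and this assumes the data's leading axis matches the local device count):

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# The non-distributed model the user writes.
def model(params, x):
    return jnp.dot(x, params)

# pmap replicates the computation across devices; params are
# broadcast (in_axes=None) and each device gets one shard along
# the leading axis of x (in_axes=0).
distributed = jax.pmap(model, in_axes=(None, 0))

params = jnp.ones((3,))
x = jnp.ones((n_dev, 4, 3))   # one shard of shape (4, 3) per device
out = distributed(params, x)  # shape (n_dev, 4)
```

The difference in the plan, as I read it, is that pmap requires the user to reshape and shard the inputs themselves, whereas the proposed transformation would derive the split from the non-distributed program.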