Make a partial function into a pytree #21475
-
Hello, I've written a library which takes any function and converts a partial application of it into a pytree. The idea is that a model has two parts, the parameters and the input: we want to vary the input and differentiate with respect to the parameters. It's like a differentiable version of functools.partial. I think it would be cool if jax provided such a feature by default. It's kinda like equinox, if you only implement … You write:

```python
@funtree
def Model(input, parameter):
    return input + parameter

model = Model(parameter=jnp.array([1.0, 2.0, 3.0]))
# model is a pytree, can get gradients wrt parameter
grad = jax.grad(lambda f, x: f(x).sum())(model, input)
```
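For context on how little machinery this needs: the sketch below is my own rough reconstruction of the mechanism (not the funtree implementation linked further down), assuming only the public `jax.tree_util.register_pytree_node_class` API. The bound keyword arguments become pytree leaves and the wrapped function rides along as static aux data, which is what lets `jax.grad` differentiate "through" the model object.

```python
import functools

import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class FunTree:
    """A partial application whose keyword arguments are pytree leaves."""

    def __init__(self, fn, **params):
        self.fn = fn
        self.params = params

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **self.params, **kwargs)

    def tree_flatten(self):
        names = tuple(sorted(self.params))
        leaves = [self.params[name] for name in names]
        return leaves, (self.fn, names)  # fn and field names are static

    @classmethod
    def tree_unflatten(cls, aux, leaves):
        fn, names = aux
        return cls(fn, **dict(zip(names, leaves)))


def funtree(fn):
    # Calling the decorated function with keyword arguments returns a
    # FunTree that closes over them and is itself a pytree.
    return functools.partial(FunTree, fn)


@funtree
def Model(input, parameter):
    return input + parameter


model = Model(parameter=jnp.array([1.0, 2.0, 3.0]))
grads = jax.grad(lambda f, x: f(x).sum())(model, jnp.ones(3))
print(grads.params["parameter"])  # [1. 1. 1.]
```

The linked funtree.py presumably handles more details than this, but the pytree registration above is the essential trick.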
This lets you write terse code such as the attention below, which makes reading the math easy. The complicated part is all moved into an external function that actually initializes a new model. I think a lot of the complexity of frameworks is magic that runs to work out shape sizes, which should all be handled by the initializer function, which should know every parameter anyway (see the sketch after the code below). I have no idea how to do state, because I haven't tried batch norm yet.

```python
@funtree.makefun
def Mlp(x, key, up, down, dropout_p: float):
    x_norm = rms_norm(x)
    expanded = jax.nn.gelu(einsum(x_norm, up, 'L E, E U -> L U'))
    lowered = einsum(expanded, down, 'L U, U E -> L E')
    return dropout(lowered, key, dropout_p)

@funtree.makefun
def Attention(x, key, qkv, out, heads: int, dropout_p: float):
    x_norm = rms_norm(x)
    parts = einsum(x_norm, qkv, 'L E, E HsplitD -> L HsplitD')
    k, q, v = rearrange(parts, 'L (H split D) -> split H L D', split=3, H=heads)
    q, k = norm(q), norm(k)
    H, L, D = k.shape
    mask = jnp.tril(jnp.ones([L, L]))
    similarity = einsum(k, q, 'H L D, H L2 D -> H L L2') * (D ** -0.5)
    masked_similarity = jnp.where(mask, similarity, -jnp.inf)
    attention = jax.nn.softmax(masked_similarity, axis=-1)
    attention = dropout(attention, key, dropout_p)
    gather = einsum(attention, v, 'H L L2, H L2 V -> H L V')
    gather = rearrange(gather, 'H L V -> L (H V)')
    output = einsum(gather, out, 'L Z, Z E -> L E')
    return output

@funtree.makefun
def GPT(x, key, embedding, positional, layers, unembed):
    L = x.shape[0]
    hidden = embedding[x] + positional[:L, :]
    for layer, k in utils.zipkey(layers, key):
        hidden = hidden + layer(hidden, key=k)
    logits = einsum(unembed, hidden, 'E O, L E -> L O')
    return logits

def init_gpt_model(vocab, embedding, heads, layer_count, expansion, max_length, use_swiglu):
    return GPT(embedding=..., )
```

I have an implementation: https://github.com/randomekek/sequence/blob/main/funtree.py#L17-L52
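To make the point about initializers concrete, here is what an initializer for the Mlp block above might look like. The name `init_mlp` and the scaling choices are my own illustration, not part of the linked library; the pattern is simply "compute every shape here, then partially apply the layer function".

```python
import jax
import jax.numpy as jnp


def init_mlp(key, embed: int, expansion: int, dropout_p: float):
    """All shape bookkeeping lives here, so Mlp itself stays pure math."""
    up_key, down_key = jax.random.split(key)
    hidden = embed * expansion
    return Mlp(
        up=jax.random.normal(up_key, [embed, hidden]) * embed ** -0.5,
        down=jax.random.normal(down_key, [hidden, embed]) * hidden ** -0.5,
        dropout_p=dropout_p,
    )


mlp = init_mlp(jax.random.PRNGKey(0), embed=256, expansion=4, dropout_p=0.1)
```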
-
How does this compare to `jax.tree_util.Partial`?
-
You're right, I think a combination of … The main difference is that I defer the transformation, so you need to call … 3 times, so funtree is a transformation like jit or grad. The benefit is that …
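For readers making the comparison: `jax.tree_util.Partial` already gives the "partial application as a pytree" part out of the box, since the bound arguments are treated as leaves. A minimal example (mine, not from the thread):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def model(parameter, input):
    return input + parameter


# Bind the parameter; the result is callable and is also a pytree whose
# leaf is the bound parameter array.
m = Partial(model, jnp.array([1.0, 2.0, 3.0]))

# Differentiate with respect to the partially-applied function itself.
grads = jax.grad(lambda f, x: f(x).sum())(m, jnp.ones(3))
print(grads.args[0])  # [1. 1. 1.]
```

So the difference discussed in this reply is mostly ergonomics (a decorator form with named keyword fields and a deferred transformation) rather than capability.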