New: Tags, and apply (already in code, to appear in README soon)
Zephyr is a new FP-oriented neural network library/framework on top of JAX that helps you write short, simple, declarative neural networks quickly and easily, with a minimal learning curve.
For those coming from other frameworks: The main difference is that zephyr is oriented towards writing in an FP-style. No initialization of models is needed because models or nets are just regular functions - no need for separate init/build/construct and a call/forward.
For those new to deep learning: As seen in many textbooks or materials, a neural network
Zephyr, at its core, is just the `trace` and `validate` functions plus extra utilities. `trace` gives you parameters. `validate` checks expressions related to parameters.
The main mindset in writing zephyr is to think in an FP and declarative manner. Think of composable transformations instead of methods: transformations of both data (arrays) AND functions. The examples below progressively rewrite procedural/imperative code to use function transformations. This puts the focus on what the transformation will be, rather than on what the arrays become after each step.
Before we start: a neural network is just a function, usually of `params`, `x`, and hyperparameters: `f(params, x, **hyperparameters)`. If we want a function without the hyperparameters, since those never change, we can use Python's `partial`, rewrite it as `new_f = partial(f, **hyperparameters)`, and call `new_f(params, x)`. However, `partial` can get tedious because it doesn't give you signature hints in your editor. Instead, you can use zephyr's more readable `_` notation, an alias for `placeholder_hole`, which zephyr nets accept in order to partially apply the function automatically. So we could write `new_f = f(_, _, **hyperparameters)`, where `_` stands in for values we pass in later. To make your own function accept `_` holes, you can use the `flexible` decorator.
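To make the comparison between `partial` and placeholder holes concrete, here is a minimal pure-Python sketch. `HOLE`, `holey`, and `scale_and_shift` are hypothetical names invented for this illustration; they are not zephyr's `_` or `flexible`, and zephyr's real implementation may work differently.

```python
from functools import partial

HOLE = object()  # hypothetical sentinel standing in for zephyr's `_`


def holey(f):
    """Toy decorator: calling f with HOLE arguments returns a function awaiting them."""
    def wrapper(*args, **kwargs):
        if not any(a is HOLE for a in args):
            return f(*args, **kwargs)  # no holes: call f directly

        def filled(*late_args):
            late = iter(late_args)
            # Replace each hole, left to right, with the next late argument.
            full = [next(late) if a is HOLE else a for a in args]
            return f(*full, **kwargs)

        return filled
    return wrapper


@holey
def scale_and_shift(params, x, scale=1.0):
    return params["w"] * x * scale + params["b"]


params = {"w": 2.0, "b": 1.0}

with_partial = partial(scale_and_shift, scale=10.0)   # standard library route
with_holes = scale_and_shift(HOLE, HOLE, scale=10.0)  # placeholder route

print(with_partial(params, 3.0))  # 61.0
print(with_holes(params, 3.0))    # 61.0
```

Both routes fix the hyperparameter and leave `params` and `x` open; the placeholder form keeps the positions of the missing arguments visible at the call site.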
One more thing: this library was heavily inspired by Haiku, where `params` is a dictionary whose leaves are Arrays. Zephyr uses the same convention.
pip install z-zephyr --upgrade
Examples | Sharp Bits | Direction | Motivation
Note: OUTDATED : I've changed the API to make them even easier to write
Look at the Following Examples
- Imports: Common Gateway for Imports
- Encoder and Decoder: This example will show you some of the layers in `zephyr.nets`. We use zephyr's `chain` function to chain functions (neural networks) together.
- Parameter Creation: This example will show you how to use custom parameters in your functions/nets.
- Dealing with random keys: This example will show you that keys are just Arrays and part of your input. Nevertheless, there are some zephyr utilities you could use to transform functions in ways that are useful for dealing with keys.
These are the imports for all the examples
from zephyr.functools.composition import thread_key, thread_params
from jax import numpy as jnp, random, jit, nn
from zephyr import nets, trace
from zephyr.nets import chain
from zephyr.functools.partial import placeholder_hole as _, flexible
Let's write a random encoder and decoder. Notice that we access `params` as if we already had a `params` made. Indeed, this declarative style is something you will have to get used to. Don't worry, zephyr handles making these parameters for you.
For each of the `encoder`, `decoder`, and `model`, we offer 2 versions: one focusing on `x`, and the other building the transformation and then applying it to `x`. These 2 versions are at the extremes, with the first being several lines of code and the second being a single line of code (broken up). The next examples will use other rewrites that are less extreme.
Encoder: Notice that these neural networks are used just like normal functions. Within each use, we can explicitly see everything: the params, the input(s), and the hyperparameters. This makes code short and concise.
@flexible
def encoder(params, x):
    x = nets.mlp(params["mlp"], x, [256,256,256]) # b 256
    x = nets.layer_norm(params["ln"], x, -1)
    x = nets.branch_linear(params["br"], x, 64) # b 64 256
    for i in range(3):
        x = nets.conv_1d(params["conv"][i], x, 64, 5)
        x = nn.relu(x)
        x = nets.max_pool(params, x, (3,3), 2)
    x = jnp.reshape(x, [x.shape[0], -1]) # b 256
    x = nets.linear(params["linear"], x, 4) # b 4
    return x
@flexible
def encoder(params, x):
    return chain([
        nets.mlp(params["mlp"], _, [256, 256, 256]),
        nets.layer_norm(params["ln"], _, -1),
        nets.branch_linear(params["br"], _, 64),
        *[
            chain([
                nets.conv_1d(params["conv"][i], _, 64, 5),
                nn.relu,
                nets.max_pool(params, _, (3,3), 2),
            ]) for i in range(3)
        ],
        lambda x: jnp.reshape(x, [x.shape[0], -1]),
        nets.linear(params["linear"], _, 4)
    ])(x)
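As a rough mental model for the chained version, left-to-right function composition can be sketched in a few lines of plain Python. `toy_chain` is a hypothetical stand-in written for this illustration; it assumes `chain` applies its functions in list order, as the two encoder versions above suggest, and it is not zephyr's actual implementation.

```python
from functools import reduce


def toy_chain(functions):
    """Compose functions left to right: toy_chain([f, g])(x) == g(f(x))."""
    def chained(x):
        return reduce(lambda value, f: f(value), functions, x)
    return chained


def double(x):
    return 2 * x


def add_one(x):
    return x + 1


# Imperative style: follow what x becomes after each step.
def pipeline_imperative(x):
    x = double(x)
    x = add_one(x)
    return x


# Declarative style: build the transformation, then apply it.
pipeline_chained = toy_chain([double, add_one])

print(pipeline_imperative(5))  # 11
print(pipeline_chained(5))     # 11
```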
Decoder: Notice that skip connections can be wrapped within a `skip` function/network that automatically adds a skip connection as `skip(f)(x) = f(x) + x`.
@flexible
def decoder(params, z):
    x = nets.mlp(params["mlp"], z, [256,256,256]) # b 256
    x = nets.branch_linear(params["br"], x, 64) # b 64 256
    for i in range(3):
        x = nets.multi_head_self_attention(params["mha"][i], x, 64, 5)
        x = x + nets.mlp(params["attn_mlp"][i], x, [256, 256])
        x = nets.layer_norm(params["attn_ln"][i], x, -1)
    x = jnp.reshape(x, [x.shape[0], -1]) # b (64 * 256) = b 16384
    x = nets.linear(params["linear"], x, 128) # b 128
    return x
@flexible
def decoder(params, z):
    return chain([
        nets.mlp(params["mlp"], _, [256, 256, 256]),
        nets.branch_linear(params["br"], _, 64),
        *[
            chain([
                nets.multi_head_self_attention(params["mha"][i], _, 64, 5),
                nets.skip(nets.mlp(params["attn_mlp"][i], _, [256,256])),
                nets.layer_norm(params["attn_ln"][i], _, -1),
            ]) for i in range(3)
        ],
        lambda x: jnp.reshape(x, [x.shape[0], -1]),
        nets.linear(params["linear"], _, 128) # b 128
    ])(z)
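The skip wrapper used in the decoder follows the rule `skip(f)(x) = f(x) + x`. A minimal sketch of that rule, using plain numbers instead of arrays, might look like this. `toy_skip` and `halve` are hypothetical names for illustration only, not zephyr's `nets.skip`.

```python
def toy_skip(f):
    """Wrap f so that its input is added back to its output."""
    def with_skip(x):
        return f(x) + x
    return with_skip


def halve(x):
    return x / 2


# Explicit residual, as in the imperative decoder.
x = 8.0
explicit = x + halve(x)

# Wrapped residual, as in the chained decoder.
wrapped = toy_skip(halve)(x)

print(explicit)  # 12.0
print(wrapped)   # 12.0
```

Because the wrapper returns an ordinary function of `x`, it can sit inside a list of chained functions like any other step.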
Model:
def model(params, x):
    z = encoder(params["encoder"], x)
    reconstructed_x = decoder(params["decoder"], z)
    return reconstructed_x
def model(params, x):
    return chain([
        encoder(params["encoder"], _),
        decoder(params["decoder"], _),
    ])(x)
To get an initial `params`, we simply use the `trace` function as follows.
key = random.PRNGKey(0) # needed to randomly initialize weights
x = jnp.ones([64, 8]) # sample input batch
params = trace(model, key, x)
fast_model = jit(model) # `trace` cannot trace a jit-ed function; use the non-jit-ed version when tracing
sample_outputs = fast_model(params, x) # b 128
For model surgery or study: if you want to use just the encoder, you can do `z = encoder(params["encoder"], x)`. You can do the same with any function/layer.
To illustrate this, we will make our own `linear` layer using zephyr. In line with the declarative thinking, we specify what the shape of the parameters should look like. Ideally, we could put this in the type annotation, but that's ignored by Python, so we instead use zephyr's `validate` as an alternative. One main use of `validate` is to specify a parameter's shape, initializer, and other relationships it might have with hyperparameters.
@flexible
def linear(params, x, out_target):
    validate(params["weights"], (x.shape[-1], out_target))
    validate(params["bias"], (out_target,))
    x = x @ params["weights"] + params["bias"]
    return x
As said earlier, we will show rewrites; which one you use is up to you. This is just to show what is possible. There is a way to write this that resembles the pattern in other FP languages, where you assume some variables exist and define them afterwards with a `where` keyword, similar to math statements.
@flexible
def linear(params, x, out_target):
    return (lambda w, b: x @ w + b)(
        validate(params["weights"], (x.shape[-1], out_target)),
        validate(params["bias"], (out_target,)),
    )
Notice the use of `validate` here. `validate` is actually just a way to enforce "type annotations" (albeit dependent types, because we're really specifying shapes), since they have to be specified somewhere for zephyr to trace the function. Nevertheless, `validate` acts like the identity function and returns its first parameter unchanged.
To use it, we call the `trace` function to get the params and then use the function normally, as follows.
key = random.PRNGKey(0)
model = linear(_,_, 256)
params = trace(model, key, x)
model(params, x) # use it like this
# or jit it
fast_model = jit(model)
fast_model(params, x)
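To make the idea of reading `params` before they exist more concrete, here is one possible mechanism sketched in pure Python with nested lists instead of arrays. This is only a guess at the general idea, not a description of zephyr's internals: `Slot`, `Template`, `toy_validate`, `toy_trace`, and `toy_linear` are all hypothetical names, the sketch handles only a flat `params` dictionary, and it uses Python's `random` module in place of a JAX key.

```python
import random


class Slot:
    """Marks a parameter that was requested but does not exist yet."""

    def __init__(self, owner, key):
        self.owner, self.key = owner, key


class Template(dict):
    """Dict whose missing keys come back as Slots instead of raising KeyError."""

    def __missing__(self, key):
        return Slot(self, key)


_rng = random.Random(0)  # toy stand-in for a JAX PRNG key


def make_array(shape):
    """Nested list of small random floats with the given shape."""
    if not shape:
        return _rng.uniform(-0.1, 0.1)
    return [make_array(shape[1:]) for _ in range(shape[0])]


def shape_of(value):
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def toy_validate(value, shape):
    """Create the parameter while tracing; otherwise check it and return it as is."""
    if isinstance(value, Slot):
        value.owner[value.key] = make_array(shape)
        return value.owner[value.key]
    if shape_of(value) != tuple(shape):
        raise ValueError(f"expected shape {tuple(shape)}, got {shape_of(value)}")
    return value


def toy_trace(f, seed, *args):
    """Run f once on an empty Template and collect the parameters it asked for."""
    global _rng
    _rng = random.Random(seed)
    template = Template()
    f(template, *args)
    return dict(template)


def toy_linear(params, x, out_target):
    w = toy_validate(params["weights"], (len(x), out_target))
    b = toy_validate(params["bias"], (out_target,))
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j] for j in range(out_target)]


x = [1.0, 2.0, 3.0]
params = toy_trace(lambda p, v: toy_linear(p, v, 2), 0, x)
print(shape_of(params["weights"]), shape_of(params["bias"]))  # (3, 2) (2,)
print(len(toy_linear(params, x, 2)))  # 2
```

In this sketch, the same function body serves both purposes: during tracing the shape declarations create parameters, and during normal use they act as checks that return their first argument unchanged.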
Random keys (RNGs) are often an unfamiliar concept, since in FP you have to be explicit about them. When you try to hide them using OO, they tend to stick out like a sore thumb in the end. In zephyr, we embrace this and treat a key like anything else: it is just part of the input.
Here is a simple model using dropout.
def model(params, x, key):
    for i in range(3):
        x = nets.mlp(params["mlp"][i], x, [256, 256])
        key, subkey = random.split(key)
        x = nets.dropout(params, subkey, x, 0.2)
    x = nn.sigmoid(x)
    return x
As with previous examples, we offer rewrites of this, none of which is "more elegant". Choose the one that best suits you.
Zephyr has a `thread` function with specific variants such as `thread_key`, `thread_params`, and `thread_identity`, which should be enough for most cases.
Another rewrite would factor out the repeating block into its own function as follows.
def block(params, key, x):
    return chain([
        nets.mlp(params["mlp"], _, [256,256]),
        nets.dropout(params, key, _, 0.2)
    ])(x)
def model(params, x, key):
    blocks = thread_params([block for i in range(3)], params) # each block is block(key,x)
    blocks = thread_key(blocks, key) # each block is block(x)
    return chain(blocks + [nn.relu])(x)
To use it, we call the `trace` function to get the params and then use the model normally, as follows.
trace_key, apply_key_1, apply_key_2, key = random.split(key, 4) # split the keys ;p
params = trace(model, trace_key, x, apply_key_1)
model(params, x, apply_key_2) # use it like this
# or jit it
fast_model = jit(model)
fast_model(params, x, apply_key_2)
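The threading helpers used in the rewrite above can be pictured as partial application over a list of functions. The sketch below is a guess at that idea in pure Python, not zephyr's implementation: `toy_thread_params`, `toy_thread_key`, and `toy_split` are hypothetical, the extra `split` argument is an assumption added so the example runs without JAX, and indexing `params` by position is also an assumption. With JAX, `jax.random.split` could plausibly play the role of `split`.

```python
from functools import partial


def toy_thread_params(functions, params):
    """Bind params[i] as the first argument of the i-th function."""
    return [partial(f, params[i]) for i, f in enumerate(functions)]


def toy_thread_key(functions, key, split):
    """Split key into one subkey per function and bind it as the first argument."""
    subkeys = split(key, len(functions))
    return [partial(f, subkey) for f, subkey in zip(functions, subkeys)]


def toy_split(key, n):
    """Toy stand-in for a PRNG split: derive n distinct keys from one."""
    return [(key, i) for i in range(n)]


def block(params, key, x):
    # A real block would use the key for dropout; here we only record which key arrived.
    return x + [(params["name"], key)]


params = [{"name": "a"}, {"name": "b"}]
blocks = toy_thread_params([block, block], params)  # each is block(key, x)
blocks = toy_thread_key(blocks, 0, toy_split)       # each is block(x)

x = []
for f in blocks:
    x = f(x)
print(x)  # [('a', (0, 0)), ('b', (0, 1))]
```

Each block receives its own parameters and its own subkey, so no key is reused and the model body no longer has to split keys by hand.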
- Documentation Strings are sparse: I'll add them soon :3.
- JAX Sharp Bits: You'll sometimes run into JAX sharp bits, such as "str and int can't be compared", which is a JAX thing, since Zephyr is such a thin library on top of JAX (it isn't even a thin wrapper). If you have any trouble, you can open an issue and I'll help.
- Bugs: If you use it, there'll probably be bugs. If you report them, I'll work on them immediately.
- Missing nets: like RNNs. I'll add them when I need them or when they are requested.
- Instability: Things are still changing a lot. I might implement other nets/layers in a different way or change names or move things.
I would like to provide more FP tooling for Python in zephyr so that I can write zephyr nets in a more FP style. Zephyr's core is probably close to stable: mainly `trace` and `validate`. Anything else is just there to make coding easier or shorter.
This library is heavily inspired by Haiku's `transform` function, which converts impure functions/class-method calls into a pure function paired with an initialized `params` PyTree. This is my favorite approach so far because it is the closest to pure functional programming. Zephyr tries to push this to its simplest form and make neural networks simply functions.
This library is also inspired by other frameworks I have tried in the past: TensorFlow and PyTorch. TensorFlow allows shape inference to happen after the first pass of inputs, while PyTorch (before the Lazy Modules) needs the input shapes at layer creation. Zephyr wants to be as easy as possible and will strive to always use at-inference-time shape inference and relative axis positions whenever possible.