A helper library for training dm-haiku models.
JAX implementations of various deep reinforcement learning algorithms.
Direct port of TD3_BC to JAX using Haiku and optax.
An (unofficial) implementation of vanilla WaveRNN.
A gradient-free optimization suite written in JAX. We conform to the optax interface and provide ensemble-based optimizers.
Simple JAX implementation of Gaussian process regression
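A minimal sketch of what GP regression in JAX can look like, assuming an RBF kernel and a Cholesky-based posterior; the `rbf_kernel` and `gp_posterior` helpers are illustrative names, not the repository's API:

```python
import jax
import jax.numpy as jnp

def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    # Squared-exponential kernel between two sets of 1-D inputs.
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale ** 2)

def gp_posterior(x_train, y_train, x_test, noise=1e-3):
    # Standard GP regression: posterior mean and marginal variance at x_test.
    K = rbf_kernel(x_train, x_train) + noise * jnp.eye(x_train.shape[0])
    K_s = rbf_kernel(x_train, x_test)
    K_ss = rbf_kernel(x_test, x_test)
    L = jnp.linalg.cholesky(K)
    alpha = jax.scipy.linalg.cho_solve((L, True), y_train)
    mean = K_s.T @ alpha
    v = jax.scipy.linalg.cho_solve((L, True), K_s)
    var = jnp.diag(K_ss - K_s.T @ v)
    return mean, var
```

With a small noise term, the posterior mean interpolates the training targets closely at the training inputs.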
Neural implicit digital elevation model
H-Former is a VAE for generating in-between fonts (or combining fonts). Its encoder uses a PointNet and a transformer to compute a code vector for a glyph. Its decoder is composed of multiple independent decoders, each acting on the code vector to reconstruct a point cloud representing a glyph.
A library that trains the Fermionic Neural Network to find the ground-state wave functions of an atom or a molecule using neural-network quantum states.
An implementation of the Adan optimizer for optax.
JAX implementation of "Classical and Quantum Algorithms for Orthogonal Neural Networks" (Kerenidis et al., 2021).
JAX/Flax implementation of finite-size scaling
dm-haiku implementation of hyperbolic neural networks
Stochastic Weight Averaging (SWA) transforms for Optax with JAX
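The core of SWA is a running average over parameter pytrees collected along the training trajectory. A minimal sketch of that averaging step (the `swa_update` helper is illustrative, not the repository's API):

```python
import jax
import jax.numpy as jnp

def swa_update(avg_params, new_params, n):
    # Incremental running average over pytrees:
    #   avg <- avg + (new - avg) / (n + 1)
    # where n is the number of snapshots already averaged.
    return jax.tree_util.tree_map(
        lambda a, p: a + (p - a) / (n + 1), avg_params, new_params)
```

In practice the repository packages this as an optax-style transform, so it can be chained with a base optimizer; the sketch above only shows the averaging arithmetic.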
Oxford MSc thesis. PriorVAE with graph convolutional networks for learning locally-aware spatial prior distributions
This repository contains some of the code I wrote for the assignments in DSA4212 - Optimisation for Large-Scale Data-Driven Inference.
An implementation of MNIST classification using LeNet-300-100 in JAX (using Haiku and Optax).