
A reimplementation of Will Whitney's Parallel DNN Training in JAX, using Haiku and Optax.


alexjackson1/hk-parallel


Parallel Training of Neural Networks with JAX

This repository trains several copies of a simple neural network, each initialized with a different random seed, in parallel on a single device using JAX. It is largely a reimplementation of Will Whitney's Parallel DNN Training in JAX, except that this notebook uses Haiku and Optax instead of Flax.

Installation

The following packages are required to run the notebook:

  • JAX (jax)
  • Haiku (dm-haiku)
  • Optax (optax)

All of the dependencies can be installed using pip from the requirements.txt file:

pip install -r requirements.txt
