Skip to content

Commit

Permalink
Implement jax.stax save/load optimizer states
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Aug 21, 2024
1 parent 717fb0f commit 0eae034
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions phiml/backend/jax/stax_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class JaxOptimizer:

def __init__(self, initialize: Callable, update: Callable, get_params: Callable):
self._initialize, self._update, self._get_params = initialize, update, get_params # Stax functions
self._state = None
self._state = None # List[Tuple[T,T,T]]: (parameter, m, v)
self._step_i = 0
self._update_function_cache = {}

Expand Down Expand Up @@ -160,23 +160,36 @@ def _recursive_add_parameters(param, wrap: bool, prefix: tuple, result: dict):


def save_state(obj: Union[StaxNet, JaxOptimizer], path: str):
    """
    Write the state of a network or an optimizer to disk.

    Networks are stored with `numpy.save` as a single `.npy` file holding `obj.parameters`.
    Optimizers are stored as a `.npz` archive with two 1-D object arrays, `m` and `v`, which hold the second and
    third entry of every `(x, m, v)` tuple in `obj._state.packed_state`. The first entry `x` is not written.

    Args:
        obj: `StaxNet` or `JaxOptimizer` whose state should be saved.
        path: Target file. The matching extension (`.npy` for networks, `.npz` for optimizers) is appended if missing.

    Raises:
        ValueError: If `obj` is neither a network nor an optimizer, or if the optimizer has no state yet.
    """
    if isinstance(obj, StaxNet):
        if not path.endswith('.npy'):
            path += '.npy'
        numpy.save(path, obj.parameters)
    elif isinstance(obj, JaxOptimizer):
        if obj._state is None:  # `_state` stays None until the optimizer has been initialized
            raise ValueError("Cannot save the optimizer state because the optimizer has not been initialized yet.")
        if not path.endswith('.npz'):
            path += '.npz'
        ms, vs = [], []
        for _, m, v in obj._state.packed_state:  # single pass instead of one traversal per list
            ms.append(np.asarray(m))
            vs.append(np.asarray(v))
        # `allow_pickle` is deliberately not forwarded: NumPy versions whose `savez` lacks that keyword would
        # store it as an additional array in the archive, and where the keyword exists True is the default.
        np.savez(path, m=_as_object_array(ms), v=_as_object_array(vs))
    else:
        raise ValueError(f"obj must be a network or optimizer but got {type(obj)}")


def _as_object_array(arrays: list):
    """Pack `arrays` into a 1-D object array with exactly one entry per input array."""
    # `np.array(arrays, dtype=object)` is not used because it inspects the element shapes: equally-shaped entries
    # are merged into one N-D object array and partially matching shapes can raise a broadcasting error.
    packed = np.empty(len(arrays), dtype=object)
    for i, array in enumerate(arrays):
        packed[i] = array
    return packed


def load_state(obj: Union[StaxNet, JaxOptimizer], path: str):
    """
    Restore the state of a network or an optimizer from a file written by `save_state`.

    For networks, the `.npy` file is read and assigned to `obj.parameters` as a tuple of tuples.
    For optimizers, the arrays `m` and `v` are read from the `.npz` archive and combined with the first entry `x`
    of every tuple currently in `obj._state.packed_state`. `tree_def` and `subtree_defs` are taken from the
    current state, so the optimizer must already have a state of matching structure.

    Args:
        obj: `StaxNet` or `JaxOptimizer` to load the state into.
        path: Source file. The matching extension (`.npy` for networks, `.npz` for optimizers) is appended if missing.

    Raises:
        ValueError: If `obj` is neither a network nor an optimizer, if the optimizer has no state yet,
            or if the number of stored entries does not match the optimizer's current state.
    """
    # NOTE(security): `allow_pickle=True` unpickles object arrays, which can execute arbitrary code.
    # Only load files from trusted sources.
    if isinstance(obj, StaxNet):
        if not path.endswith('.npy'):
            path += '.npy'
        state = numpy.load(path, allow_pickle=True)
        obj.parameters = tuple([tuple(layer) for layer in state])
    elif isinstance(obj, JaxOptimizer):
        if obj._state is None:  # `x`, `tree_def` and `subtree_defs` are all taken from the existing state
            raise ValueError("Cannot load the optimizer state because the optimizer has not been initialized yet.")
        if not path.endswith('.npz'):
            path += '.npz'
        xs = [x for x, _, _ in obj._state.packed_state]
        # The context manager closes the archive's file handle once both arrays have been read.
        with np.load(path, allow_pickle=True) as data:
            ms = [jnp.array(m) for m in data['m']]
            vs = [jnp.array(v) for v in data['v']]
        if not len(xs) == len(ms) == len(vs):  # zip() below would otherwise silently drop entries
            raise ValueError(f"Stored optimizer state has {len(ms)} m and {len(vs)} v entries but the optimizer has {len(xs)} parameters.")
        packed_state = list(zip(xs, ms, vs))
        obj._state = OptimizerState(packed_state, obj._state.tree_def, obj._state.subtree_defs)
    else:
        raise ValueError(f"obj must be a network or optimizer but got {type(obj)}")


def update_weights(net: StaxNet, optimizer: JaxOptimizer, loss_function: Callable, *loss_args, **loss_kwargs):
Expand Down

0 comments on commit 0eae034

Please sign in to comment.