Commit b093a08: Play with optimizer options
SamTov committed Aug 9, 2023
1 parent 5d403a4
Showing 6 changed files with 188 additions and 17 deletions.
2 changes: 2 additions & 0 deletions znnl/__init__.py
@@ -35,6 +35,7 @@
     distance_metrics,
     loss_functions,
     models,
+    optimizers,
     point_selection,
     similarity_measures,
     training_recording,
@@ -51,6 +52,7 @@
     distance_metrics.__name__,
     loss_functions.__name__,
     accuracy_functions.__name__,
+    optimizers.__name__,
     models.__name__,
     point_selection.__name__,
     similarity_measures.__name__,
9 changes: 5 additions & 4 deletions znnl/models/jax_model.py
@@ -33,6 +33,7 @@
 import optax
 from flax.training.train_state import TrainState
 
+from znnl.optimizers.partitioned_trace_optimizer import PartitionedTraceOptimizer
 from znnl.optimizers.trace_optimizer import TraceOptimizer
 from znnl.utils.prng import PRNGKey
 
@@ -80,10 +81,10 @@ def __init__(
         self.init_model(seed)
 
         # Prepare NTK calculation
-        self.empirical_ntk = nt.batch(
-            nt.empirical_ntk_fn(f=self._ntk_apply_fn, trace_axes=trace_axes),
-            batch_size=ntk_batch_size,
+        self.empirical_ntk = nt.empirical_ntk_fn(
+            f=self._ntk_apply_fn, trace_axes=trace_axes
         )
 
         self.empirical_ntk_jit = jax.jit(self.empirical_ntk)
 
     def init_model(
@@ -122,7 +123,7 @@ def _create_train_state(
         params = self._init_params(kernel_init, bias_init)
 
         # Set dummy optimizer for case of trace optimizer.
-        if isinstance(self.optimizer, TraceOptimizer):
+        if isinstance(self.optimizer, (TraceOptimizer, PartitionedTraceOptimizer)):
             optimizer = optax.sgd(1.0)
         else:
             optimizer = self.optimizer
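The change above drops the nt.batch wrapper (and with it the ntk_batch_size argument) in favour of a plain nt.empirical_ntk_fn whose output is jitted. A minimal sketch of the resulting pattern, using a toy flax module in place of ZnNL's _ntk_apply_fn (the toy model, shapes, and trace_axes value are assumptions for illustration, not ZnNL code):

import jax
import jax.numpy as jnp
import neural_tangents as nt
from flax import linen as nn


class ToyNet(nn.Module):
    # Stand-in for a ZnNL model: a single dense layer.
    @nn.compact
    def __call__(self, x):
        return nn.Dense(2)(x)


model = ToyNet()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))


def apply_fn(params, x):
    # Matches the f(params, x) signature that empirical_ntk_fn expects.
    return model.apply(params, x)


# Un-batched empirical NTK, traced over the output axis, then jitted.
empirical_ntk = nt.empirical_ntk_fn(f=apply_fn, trace_axes=(-1,))
empirical_ntk_jit = jax.jit(empirical_ntk)

x = jnp.ones((8, 4))
kernel = empirical_ntk_jit(x, None, params)  # (8, 8) NTK matrix

Without nt.batch, the whole kernel is built in one pass, so memory now scales with the full data-set size rather than with the former ntk_batch_size.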
3 changes: 2 additions & 1 deletion znnl/optimizers/__init__.py
@@ -24,6 +24,7 @@
 Summary
 -------
 """
+from znnl.optimizers.partitioned_trace_optimizer import PartitionedTraceOptimizer
 from znnl.optimizers.trace_optimizer import TraceOptimizer
 
-__all__ = [TraceOptimizer.__name__]
+__all__ = [TraceOptimizer.__name__, PartitionedTraceOptimizer.__name__]
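With the registration above, both optimizers are importable directly from the subpackage; a one-line check, assuming znnl is installed:

from znnl.optimizers import PartitionedTraceOptimizer, TraceOptimizer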
159 changes: 159 additions & 0 deletions znnl/optimizers/partitioned_trace_optimizer.py
@@ -0,0 +1,159 @@
+"""
+ZnNL: A Zincwarecode package.
+
+License
+-------
+This program and the accompanying materials are made available under the terms
+of the Eclipse Public License v2.0 which accompanies this distribution, and is
+available at https://www.eclipse.org/legal/epl-v20.html
+
+SPDX-License-Identifier: EPL-2.0
+
+Copyright Contributors to the Zincwarecode Project.
+
+Contact Information
+-------------------
+email: zincwarecode@gmail.com
+github: https://github.com/zincware
+web: https://zincwarecode.com/
+
+Citation
+--------
+If you use this module please cite us with:
+
+Summary
+-------
+"""
+from dataclasses import dataclass
+from typing import Callable
+
+import jax.numpy as np
+import numpy as onp
+import optax
+from flax.training.train_state import TrainState
+
+
+@dataclass
+class PartitionedTraceOptimizer:
+    """
+    Class implementation of the partitioned trace optimizer.
+
+    The data set is split into one partition per class; the NTK trace is
+    computed per partition and summed before the learning rate is rescaled.
+
+    Attributes
+    ----------
+    scale_factor : float
+        Scale factor to apply to the optimizer.
+    rescale_interval : int
+        Number of epochs to wait before re-scaling the learning rate.
+    subset : float
+        Fraction of each class partition to use in the trace calculation.
+    """
+
+    scale_factor: float
+    rescale_interval: float = 1
+    subset: float = None
+
+    _start_value = None
+
+    @optax.inject_hyperparams
+    def optimizer(self, learning_rate):
+        return optax.sgd(learning_rate)
+
+    def apply_optimizer(
+        self,
+        model_state: TrainState,
+        data_set: dict,
+        ntk_fn: Callable,
+        epoch: int,
+    ):
+        """
+        Apply the optimizer to a model state.
+
+        Parameters
+        ----------
+        model_state : TrainState
+            Current state of the model.
+        data_set : dict
+            Data set with "inputs" and one-hot "targets" to use in the
+            computation.
+        ntk_fn : Callable
+            Function to use for the NTK computation.
+        epoch : int
+            Current epoch.
+
+        Returns
+        -------
+        new_state : TrainState
+            New state of the model.
+        """
+        eps = 1e-8
+
+        partitions = {}
+
+        # Unique one-hot target vectors, i.e. one row per class.
+        unique_classes = np.unique(data_set["targets"], axis=0)
+
+        for i in range(unique_classes.shape[0]):
+            indices = np.where(data_set["targets"].argmax(-1) == i)[0]
+            partitions[i] = np.take(data_set["inputs"], indices, axis=0)
+
+        if self._start_value is None:
+            if self.subset is not None:
+                init_data_set = {}
+                for ds in partitions:
+                    subset_size = int(self.subset * partitions[ds].shape[0])
+                    init_data_set[ds] = np.take(
+                        partitions[ds],
+                        onp.random.randint(
+                            0, partitions[ds].shape[0], size=subset_size
+                        ),
+                        axis=0,
+                    )
+            else:
+                # Use the full class partitions when no subset is requested.
+                init_data_set = partitions
+
+            start_trace = 0
+            for ds in init_data_set:
+                ntk = ntk_fn(init_data_set[ds])["empirical"]
+                start_trace += np.trace(ntk)
+
+            # Record the trace summed over all partitions as the reference.
+            self._start_value = start_trace
+
+        # Check if the update should be performed.
+        if epoch % self.rescale_interval == 0:
+            # Select a subset of each class partition.
+            if self.subset is not None:
+                data_set = {}
+                for ds in partitions:
+                    subset_size = int(self.subset * partitions[ds].shape[0])
+                    data_set[ds] = np.take(
+                        partitions[ds],
+                        onp.random.randint(
+                            0, partitions[ds].shape[0], size=subset_size
+                        ),
+                        axis=0,
+                    )
+            else:
+                data_set = partitions
+
+            # Compute the NTK trace, summed over the partitions.
+            trace = 0.0
+            for ds in data_set:
+                ntk = ntk_fn(data_set[ds])["empirical"]
+                trace += np.trace(ntk)
+
+            # Create the new optimizer.
+            new_optimizer = self.optimizer(
+                (self.scale_factor * self._start_value) / (trace + eps)
+            )
+
+            # Create the new state
+            new_state = TrainState.create(
+                apply_fn=model_state.apply_fn,
+                params=model_state.params,
+                tx=new_optimizer,
+            )
+        else:
+            # If no update is needed, return the old state.
+            new_state = model_state
+
+        return new_state
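The core of the new optimizer is the per-class partitioning that precedes the trace computation: targets are assumed one-hot, each class's inputs are gathered into a partition, and the NTK trace is summed over partitions before rescaling the SGD learning rate. A runnable sketch of that partitioning step plus the optimizer construction (the toy arrays and hyperparameter values are assumptions for illustration):

import jax.numpy as jnp

from znnl.optimizers import PartitionedTraceOptimizer

# Toy data: 6 samples, 3 classes, one-hot targets.
targets = jnp.eye(3)[jnp.array([0, 1, 2, 0, 1, 2])]
inputs = jnp.arange(12.0).reshape(6, 2)

# Same partitioning logic as in apply_optimizer above.
partitions = {}
for i in range(jnp.unique(targets, axis=0).shape[0]):
    indices = jnp.where(targets.argmax(-1) == i)[0]
    partitions[i] = jnp.take(inputs, indices, axis=0)

print({k: v.shape for k, v in partitions.items()})  # {0: (2, 2), 1: (2, 2), 2: (2, 2)}

optimizer = PartitionedTraceOptimizer(
    scale_factor=10.0,   # sets the initial learning-rate scale
    rescale_interval=1,  # re-scale every epoch
    subset=0.5,          # trace over half of each partition
)
# optimizer.apply_optimizer(model_state=..., ntk_fn=..., epoch=1,
#     data_set={"inputs": inputs, "targets": targets}) would then rescale SGD.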
14 changes: 10 additions & 4 deletions znnl/optimizers/trace_optimizer.py
@@ -51,8 +51,9 @@ class TraceOptimizer:
     scale_factor: float
     rescale_interval: float = 1
     subset: Union[float, list] = None
+    memory: int = 1
 
-    _start_value = None
+    _start_value = []
 
     @optax.inject_hyperparams
     def optimizer(self, learning_rate):
@@ -84,9 +85,10 @@ def apply_optimizer(
         new_state : TrainState
             New state of the model
         """
+        data_set = data_set["inputs"]
         eps = 1e-8
 
-        if self._start_value is None:
+        if self._start_value == []:
             if self.subset is not None:
                 if isinstance(self.subset, float):
                     subset_size = int(self.subset * data_set.shape[0])
@@ -100,7 +102,7 @@ def apply_optimizer(
             else:
                 init_data_set = data_set
             ntk = ntk_fn(init_data_set)["empirical"]
-            self._start_value = np.trace(ntk)
+            self._start_value.append(np.trace(ntk))
 
         # Check if the update should be performed.
         if epoch % self.rescale_interval == 0:
@@ -120,10 +122,14 @@ def apply_optimizer(
             ntk = ntk_fn(data_set)["empirical"]
             trace = np.trace(ntk)
 
+            memory_index = int(np.clip(epoch - self.memory, 0, epoch))
+            memory_index = 0
+
             # Create the new optimizer.
             new_optimizer = self.optimizer(
-                (self.scale_factor * self._start_value) / (trace + eps)
+                (self.scale_factor * self._start_value[memory_index]) / (trace + eps)
             )
+            self._start_value.append(trace)
 
             # Create the new state
             new_state = TrainState.create(
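The TraceOptimizer change turns _start_value into a trace history: every rescale appends the current trace, and the reference trace is selected by memory_index, which the commit immediately pins to 0 (the initialization trace), so the effective behaviour still matches the old single-value version. The learning-rate rule being experimented with, as a schematic sketch (the toy trace values are assumptions):

scale_factor, eps = 10.0, 1e-8
trace_history = [120.0]  # Tr(NTK) recorded at initialization


def rescaled_learning_rate(current_trace, memory_index=0):
    # lr_t = scale_factor * Tr(K_ref) / (Tr(K_t) + eps); memory_index picks
    # the reference trace (0 pins it to the initial trace, as in the commit).
    lr = scale_factor * trace_history[memory_index] / (current_trace + eps)
    trace_history.append(current_trace)
    return lr


print(rescaled_learning_rate(150.0))  # ~8.0: a growing trace shrinks the step
print(rescaled_learning_rate(100.0))  # ~12.0: a shrinking trace enlarges it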
18 changes: 10 additions & 8 deletions znnl/training_strategies/simple_training.py
@@ -36,6 +36,7 @@
 
 from znnl.accuracy_functions.accuracy_function import AccuracyFunction
 from znnl.models.jax_model import JaxModel
+from znnl.optimizers.partitioned_trace_optimizer import PartitionedTraceOptimizer
 from znnl.optimizers.trace_optimizer import TraceOptimizer
 from znnl.training_recording import JaxRecorder
 from znnl.training_strategies.recursive_mode import RecursiveMode
@@ -199,6 +200,15 @@ def _train_step(self, state: TrainState, batch: dict):
         metrics : dict
             Metrics for the current model.
         """
+        if isinstance(
+            self.model.optimizer, (TraceOptimizer, PartitionedTraceOptimizer)
+        ):
+            state = self.model.optimizer.apply_optimizer(
+                model_state=state,
+                data_set=batch,
+                ntk_fn=self.model.compute_ntk,
+                epoch=1,
+            )
 
         def loss_fn(params):
             """
@@ -373,14 +383,6 @@ def train_model(
 
             loading_bar.set_description(f"Epoch: {i}")
 
-            if isinstance(self.model.optimizer, TraceOptimizer):
-                state = self.model.optimizer.apply_optimizer(
-                    model_state=state,
-                    data_set=train_ds["inputs"],
-                    ntk_fn=self.model.compute_ntk,
-                    epoch=i,
-                )
-
             state, train_metrics = self._train_epoch(
                 state, train_ds, batch_size=batch_size
             )
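The hook for both trace optimizers moves from the per-epoch loop in train_model into _train_step, now receiving the raw batch dict (the optimizers index "inputs"/"targets" themselves after this commit) with epoch pinned to 1. One consequence, sketched below: the epoch % rescale_interval == 0 gate inside the optimizers then fires on every batch when rescale_interval == 1 and never for any larger interval.

# With epoch fixed at 1, the rescale gate is a constant per interval choice.
for rescale_interval in (1, 2, 5):
    print(rescale_interval, 1 % rescale_interval == 0)  # 1 True, 2 False, 5 False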
