* Fixing a bug in the ScaleAndShiftDiagonal block, which reduces all dimensions of size 1 when computing gradients, even if the parameter itself has an axis of size 1. This leads to incorrect broadcasting later when multiplying with the curvature matrix.
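
As a minimal sketch of the broadcasting issue (hypothetical shapes, not taken from the library's tests): for activations of shape (batch, time, 1, channels) and a scale parameter of shape (1, channels), the old axis selection reduced every non-leading size-1 axis of the broadcast shape, including the parameter's own size-1 axis, so the reduced gradient no longer matched the parameter's shape:

import jax.numpy as jnp

# Hypothetical shapes for illustration only.
x = jnp.ones((4, 8, 1, 16))   # activations: (batch, time, 1, channels)
dy = jnp.ones((4, 8, 1, 16))  # output tangents, same shape as x
scale_shape = (1, 16)         # scale parameter with a genuine size-1 axis

# Old (buggy) selection: pad the parameter shape to x's rank, then reduce
# every non-leading axis of size 1 -- this also hits axis 2, which belongs
# to the parameter itself.
full_scale_shape = (1,) * (x.ndim - len(scale_shape)) + scale_shape
old_axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0]
print(jnp.sum(x * dy, axis=tuple(old_axis)).shape)  # (4, 16): size-1 axis lost

# Fixed selection: reduce only the axes x has in excess of the parameter.
new_axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
print(jnp.sum(x * dy, axis=tuple(new_axis)).shape)  # (4, 1, 16): (batch,) + scale_shape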

* Fix a few smaller bugs in computing Fisher/GGN factor products in the implicit curvature. This does not affect any previous usage of the Optimizer class.
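
Concretely, the factor B of the Fisher F (or GGN G) satisfies F = BB^T, so multiplying a vector by the factor transpose and then by the factor must agree with multiplying by the full matrix; the new test added below checks exactly this identity. A dense-matrix sketch of the property (illustrative only, not the estimator's internals):

import jax
import jax.numpy as jnp

# A PSD curvature matrix C with factor B, i.e. C = B @ B.T.
b = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
c = b @ b.T

v = jnp.arange(5.0)
c_v_via_factors = b @ (b.T @ v)  # factor(factor_transpose(v))
c_v_direct = c @ v               # multiply by C directly

assert jnp.allclose(c_v_via_factors, c_v_direct, atol=1e-5)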

PiperOrigin-RevId: 453427834
botev authored and KfacJaxDev committed Jun 7, 2022
1 parent bb761e2 commit c31c24d
Showing 3 changed files with 94 additions and 22 deletions.
12 changes: 4 additions & 8 deletions kfac_jax/_src/curvature_blocks.py
@@ -1539,8 +1539,7 @@ def _update_curvature_matrix_estimate(
       assert (state.diagonal_factors[0].raw_value.shape ==
               self.parameters_shapes[0])
       scale_shape = estimation_data["params"][0].shape
-      full_scale_shape = (1,) * (len(x.shape) - len(scale_shape)) + scale_shape
-      axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0]
+      axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
       d_scale = jnp.sum(x * dy, axis=tuple(axis))
       scale_diag_update = jnp.sum(d_scale * d_scale, axis=0) / batch_size
       state.diagonal_factors[0].update(scale_diag_update, ema_old, ema_new)
@@ -1549,8 +1548,7 @@ def _update_curvature_matrix_estimate(
       assert (state.diagonal_factors[-1].raw_value.shape ==
               self.parameters_shapes[-1])
       shift_shape = estimation_data["params"][-1].shape
-      full_shift_shape = (1,) * (len(x.shape) - len(shift_shape)) + shift_shape
-      axis = [i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0]
+      axis = range(x.ndim)[1:(x.ndim - len(shift_shape))]
       d_shift = jnp.sum(dy, axis=tuple(axis))
       shift_diag_update = jnp.sum(d_shift * d_shift, axis=0) / batch_size
       state.diagonal_factors[-1].update(shift_diag_update, ema_old, ema_new)
@@ -1589,17 +1587,15 @@ def update_curvature_matrix_estimate(
     if self._has_scale:
       # Scale tangent
       scale_shape = estimation_data["params"][0].shape
-      full_scale_shape = (1,) * (len(x.shape) - len(scale_shape)) + scale_shape
-      axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0]
+      axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
       d_scale = jnp.sum(x * dy, axis=tuple(axis))
       d_scale = d_scale.reshape([batch_size, -1])
       tangents.append(d_scale)

     if self._has_shift:
       # Shift tangent
       shift_shape = estimation_data["params"][-1].shape
-      full_shift_shape = (1,) * (len(x.shape) - len(shift_shape)) + shift_shape
-      axis = [i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0]
+      axis = range(x.ndim)[1:(x.ndim - len(shift_shape))]
       d_shift = jnp.sum(dy, axis=tuple(axis))
       d_shift = d_shift.reshape([batch_size, -1])
       tangents.append(d_shift)
16 changes: 8 additions & 8 deletions kfac_jax/_src/curvature_estimator.py
@@ -161,7 +161,7 @@ def _multiply_loss_ggn(
   def _multiply_loss_fisher_factor(
       cls,
       losses: Sequence[loss_functions.NegativeLogProbLoss],
-      loss_inner_vectors: Sequence[Sequence[chex.Array]]
+      loss_inner_vectors: Sequence[chex.Array],
   ) -> Tuple[Tuple[chex.Array, ...], ...]:
     """Multiplies the vectors with the Fisher factors of each loss.
@@ -175,14 +175,14 @@ def _multiply_loss_fisher_factor(
       losses.
     """
     assert len(losses) == len(loss_inner_vectors)
-    return tuple(loss.multiply_fisher_factor(*vec)
+    return tuple(loss.multiply_fisher_factor(vec)
                  for loss, vec in zip(losses, loss_inner_vectors))

   @classmethod
   def _multiply_loss_ggn_factor(
       cls,
       losses: Sequence[loss_functions.LossFunction],
-      loss_inner_vectors: Sequence[Sequence[chex.Array]]
+      loss_inner_vectors: Sequence[chex.Array],
   ) -> Tuple[Tuple[chex.Array, ...], ...]:
     """Multiplies the vectors with the GGN factors of each loss.
@@ -195,7 +195,7 @@ def _multiply_loss_ggn_factor(
       The product of all vectors with the factors of the GGN of each of the
       losses.
     """
-    return tuple(loss.multiply_ggn_factor(*vec)
+    return tuple(loss.multiply_ggn_factor(vec)
                  for loss, vec in zip(losses, loss_inner_vectors))

   @classmethod
@@ -396,7 +396,7 @@ def multiply_ggn_factor_transpose(
   def multiply_fisher_factor(
       self,
       func_args: utils.FuncArgs,
-      loss_inner_vectors: Sequence[Sequence[chex.Array]],
+      loss_inner_vectors: Sequence[chex.Array],
   ) -> utils.Params:
     """Multiplies the vector with the factor of the Fisher matrix.
@@ -410,7 +410,7 @@ def multiply_fisher_factor(
       The product ``Bv``, where ``F = BB^T``.
     """
     losses: Sequence[loss_functions.NegativeLogProbLoss]
-    losses, vjp = self._loss_tags_vjp(*func_args)
+    losses, vjp = self._loss_tags_vjp(func_args)
     if any(not isinstance(l, loss_functions.NegativeLogProbLoss)
            for l in losses):
       raise ValueError("To use `multiply_fisher` all registered losses must "
@@ -437,7 +437,7 @@ def multiply_ggn_factor(
     Returns:
       The product ``Bv``, where ``G = BB^T``.
     """
-    losses, vjp = self._loss_tags_vjp(*func_args)
+    losses, vjp = self._loss_tags_vjp(func_args)
     fisher_factor_transpose_vectors = self._multiply_loss_ggn_factor(
         losses, loss_inner_vectors)
     vectors = vjp(fisher_factor_transpose_vectors)
@@ -504,7 +504,7 @@ def dim(self) -> int:
   @abc.abstractmethod
   def init(
       self,
-      rng: chex.Array,
+      rng: chex.PRNGKey,
       func_args: utils.FuncArgs,
       exact_powers_to_cache: Optional[curvature_blocks.ScalarOrSequence],
       approx_powers_to_cache: Optional[curvature_blocks.ScalarOrSequence],
88 changes: 82 additions & 6 deletions tests/test_estimator.py
@@ -32,6 +32,13 @@
 ]


+LINEAR_MODELS_AND_CURVATURE_TYPE = [
+    model + ("ggn",) for model in models.LINEAR_MODELS
+] + [
+    model + ("fisher",) for model in models.LINEAR_MODELS
+]
+
+
 PIECEWISE_LINEAR_MODELS_AND_CURVATURE = [
     model + ("ggn",) for model in models.PIECEWISE_LINEAR_MODELS
 ] + [
@@ -149,14 +156,14 @@ def mul_e_i(index, *_):
     # Compare
     self.assertAllClose(matrix, explicit_exact_matrix)

-  @parameterized.parameters(models.NON_LINEAR_MODELS)
+  @parameterized.parameters(NON_LINEAR_MODELS_AND_CURVATURE_TYPE)
   def test_block_diagonal_full(
       self,
       init_func: Callable[..., models.hk.Params],
       model_func: Callable[..., chex.Array],
       data_point_shapes: Mapping[str, chex.Shape],
       seed: int,
-      curvature_type: str = "fisher",
+      curvature_type: str,
       data_size: int = 4,
   ):
     """Tests that the block diagonal full is equal to the explicit curvature."""
@@ -282,14 +289,14 @@ def mul_e_i(index, *_):
       d = d + block.shape[0]
     self.assertEqual(d, hessian.shape[0])

-  @parameterized.parameters(models.NON_LINEAR_MODELS)
+  @parameterized.parameters(NON_LINEAR_MODELS_AND_CURVATURE_TYPE)
   def test_diagonal(
       self,
       init_func: Callable[..., models.hk.Params],
       model_func: Callable[..., chex.Array],
       data_point_shapes: Mapping[str, chex.Shape],
       seed: int,
-      curvature_type: str = "fisher",
+      curvature_type: str,
       data_size: int = 4,
   ):
     """Tests that the diagonal estimation is the diagonal of the full."""
@@ -348,7 +355,7 @@ def test_diagonal(
     for diagonal, block in zip(diagonals, blocks):
       self.assertAllClose(diagonal, jnp.diag(jnp.diag(block)))

-  @parameterized.parameters(models.LINEAR_MODELS)
+  @parameterized.parameters(LINEAR_MODELS_AND_CURVATURE_TYPE)
   def test_kronecker_factored(
       self,
       init_func: Callable[..., models.hk.Params],
@@ -421,6 +428,12 @@ def test_kronecker_factored(
       (
           dict(images=(32, 32, 3), labels=(10,)),
           1230971,
+          "ggn",
+      ),
+      (
+          dict(images=(32, 32, 3), labels=(10,)),
+          1230971,
+          "fisher",
       ),
   ])
   def test_eigenvalues(
@@ -508,13 +521,19 @@ def test_eigenvalues(
       (
           dict(images=(32, 32, 3), labels=(10,)),
           1230971,
+          "ggn",
+      ),
+      (
+          dict(images=(32, 32, 3), labels=(10,)),
+          1230971,
+          "fisher",
       ),
   ])
   def test_matmul(
       self,
       data_point_shapes: Mapping[str, chex.Shape],
       seed: int,
-      curvature_type: str = "fisher",
+      curvature_type: str,
       data_size: int = 4,
       e: float = 1.0,
   ):
@@ -597,6 +616,63 @@ def test_matmul(
       computed2 = jnp.linalg.solve(m_i_plus_eye, v_i_flat)
       self.assertAllClose(computed2, r2_i_flat, atol=1e-5, rtol=1e-4)

+  @parameterized.parameters([
+      (
+          dict(images=(32, 32, 3), labels=(10,)),
+          1230971,
+          "ggn",
+      ),
+      (
+          dict(images=(32, 32, 3), labels=(10,)),
+          1230971,
+          "fisher",
+      ),
+  ])
+  def test_implicit_factor_products(
+      self,
+      data_point_shapes: Mapping[str, chex.Shape],
+      seed: int,
+      curvature_type: str,
+      data_size: int = 4,
+  ):
+    """Tests that the products of the curvature factors are correct."""
+    num_classes = data_point_shapes["labels"][0]
+    init_func = models.conv_classifier(
+        num_classes=num_classes, layer_channels=[8, 16, 32]).init
+    model_func = functools.partial(
+        models.conv_classifier_loss,
+        num_classes=num_classes,
+        layer_channels=[8, 16, 32])
+
+    rng_key = jax.random.PRNGKey(seed)
+    init_key1, init_key2, data_key = jax.random.split(rng_key, 3)
+
+    # Generate data
+    data = {}
+    for name, shape in data_point_shapes.items():
+      data_key, key = jax.random.split(data_key)
+      data[name] = jax.random.uniform(key, (data_size, *shape))
+      if name == "labels":
+        data[name] = jnp.argmax(data[name], axis=-1)
+
+    params = init_func(init_key1, data)
+    func_args = (params, data)
+    estimator = kfac_jax.ImplicitExactCurvature(model_func)
+
+    v = init_func(init_key2, data)
+    if curvature_type == "fisher":
+      c_factor_v = estimator.multiply_fisher_factor_transpose(func_args, v)
+      c_v_1 = estimator.multiply_fisher_factor(func_args, c_factor_v)
+      c_v_2 = estimator.multiply_fisher(func_args, v)
+    elif curvature_type == "ggn":
+      c_factor_v = estimator.multiply_ggn_factor_transpose(func_args, v)
+      c_v_1 = estimator.multiply_ggn_factor(func_args, c_factor_v)
+      c_v_2 = estimator.multiply_ggn(func_args, v)
+    else:
+      raise NotImplementedError()
+
+    self.assertAllClose(c_v_1, c_v_2, atol=1e-6, rtol=1e-6)

 if __name__ == "__main__":
   absltest.main()
