Skip to content

Commit

Permalink
* Bumping up the package version due to changes in the curvature bloc…
Browse files Browse the repository at this point in the history
…k classes.

PiperOrigin-RevId: 532513111
  • Loading branch information
botev authored and KfacJaxDev committed May 16, 2023
1 parent 55e3b4e commit aa73ceb
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions kfac_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from kfac_jax._src import utils


__version__ = "0.0.4"
__version__ = "0.0.5"

# Patches Second Moments
patches_moments = patches_second_moment.patches_moments
Expand Down Expand Up @@ -53,16 +53,25 @@
NegativeLogProbLoss = loss_functions.NegativeLogProbLoss
DistributionNegativeLogProbLoss = loss_functions.DistributionNegativeLogProbLoss
NormalMeanNegativeLogProbLoss = loss_functions.NormalMeanNegativeLogProbLoss
NormalMeanVarianceNegativeLogProbLoss = loss_functions.NormalMeanVarianceNegativeLogProbLoss
MultiBernoulliNegativeLogProbLoss = loss_functions.MultiBernoulliNegativeLogProbLoss
CategoricalLogitsNegativeLogProbLoss = loss_functions.CategoricalLogitsNegativeLogProbLoss
OneHotCategoricalLogitsNegativeLogProbLoss = loss_functions.OneHotCategoricalLogitsNegativeLogProbLoss
register_sigmoid_cross_entropy_loss = loss_functions.register_sigmoid_cross_entropy_loss
register_multi_bernoulli_predictive_distribution = loss_functions.register_multi_bernoulli_predictive_distribution
register_softmax_cross_entropy_loss = loss_functions.register_softmax_cross_entropy_loss
register_categorical_predictive_distribution = loss_functions.register_categorical_predictive_distribution
NormalMeanVarianceNegativeLogProbLoss = (
loss_functions.NormalMeanVarianceNegativeLogProbLoss)
MultiBernoulliNegativeLogProbLoss = (
loss_functions.MultiBernoulliNegativeLogProbLoss)
CategoricalLogitsNegativeLogProbLoss = (
loss_functions.CategoricalLogitsNegativeLogProbLoss)
OneHotCategoricalLogitsNegativeLogProbLoss = (
loss_functions.OneHotCategoricalLogitsNegativeLogProbLoss)
register_sigmoid_cross_entropy_loss = (
loss_functions.register_sigmoid_cross_entropy_loss)
register_multi_bernoulli_predictive_distribution = (
loss_functions.register_multi_bernoulli_predictive_distribution)
register_softmax_cross_entropy_loss = (
loss_functions.register_softmax_cross_entropy_loss)
register_categorical_predictive_distribution = (
loss_functions.register_categorical_predictive_distribution)
register_squared_error_loss = loss_functions.register_squared_error_loss
register_normal_predictive_distribution = loss_functions.register_normal_predictive_distribution
register_normal_predictive_distribution = (
loss_functions.register_normal_predictive_distribution)

# Curvature blocks
CurvatureBlock = curvature_blocks.CurvatureBlock
Expand All @@ -83,16 +92,20 @@
ScaleAndShiftFull = curvature_blocks.ScaleAndShiftFull
set_max_parallel_elements = curvature_blocks.set_max_parallel_elements
get_max_parallel_elements = curvature_blocks.get_max_parallel_elements
set_default_eigen_decomposition_threshold = curvature_blocks.set_default_eigen_decomposition_threshold
get_default_eigen_decomposition_threshold = curvature_blocks.get_default_eigen_decomposition_threshold
set_default_eigen_decomposition_threshold = (
curvature_blocks.set_default_eigen_decomposition_threshold)
get_default_eigen_decomposition_threshold = (
curvature_blocks.get_default_eigen_decomposition_threshold)

# Curvature estimators
CurvatureEstimator = curvature_estimator.CurvatureEstimator
BlockDiagonalCurvature = curvature_estimator.BlockDiagonalCurvature
ExplicitExactCurvature = curvature_estimator.ExplicitExactCurvature
ImplicitExactCurvature = curvature_estimator.ImplicitExactCurvature
set_default_tag_to_block_ctor = curvature_estimator.set_default_tag_to_block_ctor
get_default_tag_to_block_ctor = curvature_estimator.get_default_tag_to_block_ctor
set_default_tag_to_block_ctor = (
curvature_estimator.set_default_tag_to_block_ctor)
get_default_tag_to_block_ctor = (
curvature_estimator.get_default_tag_to_block_ctor)

# Optimizers
Optimizer = optimizer.Optimizer
Expand Down

0 comments on commit aa73ceb

Please sign in to comment.