diff --git a/kfac_jax/__init__.py b/kfac_jax/__init__.py index 6b4dfa1..70ec2bc 100644 --- a/kfac_jax/__init__.py +++ b/kfac_jax/__init__.py @@ -24,7 +24,7 @@ from kfac_jax._src import utils -__version__ = "0.0.4" +__version__ = "0.0.5" # Patches Second Moments patches_moments = patches_second_moment.patches_moments @@ -53,16 +53,25 @@ NegativeLogProbLoss = loss_functions.NegativeLogProbLoss DistributionNegativeLogProbLoss = loss_functions.DistributionNegativeLogProbLoss NormalMeanNegativeLogProbLoss = loss_functions.NormalMeanNegativeLogProbLoss -NormalMeanVarianceNegativeLogProbLoss = loss_functions.NormalMeanVarianceNegativeLogProbLoss -MultiBernoulliNegativeLogProbLoss = loss_functions.MultiBernoulliNegativeLogProbLoss -CategoricalLogitsNegativeLogProbLoss = loss_functions.CategoricalLogitsNegativeLogProbLoss -OneHotCategoricalLogitsNegativeLogProbLoss = loss_functions.OneHotCategoricalLogitsNegativeLogProbLoss -register_sigmoid_cross_entropy_loss = loss_functions.register_sigmoid_cross_entropy_loss -register_multi_bernoulli_predictive_distribution = loss_functions.register_multi_bernoulli_predictive_distribution -register_softmax_cross_entropy_loss = loss_functions.register_softmax_cross_entropy_loss -register_categorical_predictive_distribution = loss_functions.register_categorical_predictive_distribution +NormalMeanVarianceNegativeLogProbLoss = ( + loss_functions.NormalMeanVarianceNegativeLogProbLoss) +MultiBernoulliNegativeLogProbLoss = ( + loss_functions.MultiBernoulliNegativeLogProbLoss) +CategoricalLogitsNegativeLogProbLoss = ( + loss_functions.CategoricalLogitsNegativeLogProbLoss) +OneHotCategoricalLogitsNegativeLogProbLoss = ( + loss_functions.OneHotCategoricalLogitsNegativeLogProbLoss) +register_sigmoid_cross_entropy_loss = ( + loss_functions.register_sigmoid_cross_entropy_loss) +register_multi_bernoulli_predictive_distribution = ( + loss_functions.register_multi_bernoulli_predictive_distribution) +register_softmax_cross_entropy_loss = ( + loss_functions.register_softmax_cross_entropy_loss) +register_categorical_predictive_distribution = ( + loss_functions.register_categorical_predictive_distribution) register_squared_error_loss = loss_functions.register_squared_error_loss -register_normal_predictive_distribution = loss_functions.register_normal_predictive_distribution +register_normal_predictive_distribution = ( + loss_functions.register_normal_predictive_distribution) # Curvature blocks CurvatureBlock = curvature_blocks.CurvatureBlock @@ -83,16 +92,20 @@ ScaleAndShiftFull = curvature_blocks.ScaleAndShiftFull set_max_parallel_elements = curvature_blocks.set_max_parallel_elements get_max_parallel_elements = curvature_blocks.get_max_parallel_elements -set_default_eigen_decomposition_threshold = curvature_blocks.set_default_eigen_decomposition_threshold -get_default_eigen_decomposition_threshold = curvature_blocks.get_default_eigen_decomposition_threshold +set_default_eigen_decomposition_threshold = ( + curvature_blocks.set_default_eigen_decomposition_threshold) +get_default_eigen_decomposition_threshold = ( + curvature_blocks.get_default_eigen_decomposition_threshold) # Curvature estimators CurvatureEstimator = curvature_estimator.CurvatureEstimator BlockDiagonalCurvature = curvature_estimator.BlockDiagonalCurvature ExplicitExactCurvature = curvature_estimator.ExplicitExactCurvature ImplicitExactCurvature = curvature_estimator.ImplicitExactCurvature -set_default_tag_to_block_ctor = curvature_estimator.set_default_tag_to_block_ctor -get_default_tag_to_block_ctor = curvature_estimator.get_default_tag_to_block_ctor +set_default_tag_to_block_ctor = ( + curvature_estimator.set_default_tag_to_block_ctor) +get_default_tag_to_block_ctor = ( + curvature_estimator.get_default_tag_to_block_ctor) # Optimizers Optimizer = optimizer.Optimizer