Skip to content

Commit

Permalink
Autoformatting with black and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Aug 9, 2023
1 parent fc6a954 commit 1c1f631
Show file tree
Hide file tree
Showing 25 changed files with 194 additions and 224 deletions.
20 changes: 10 additions & 10 deletions src/mici/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@
from abc import ABC, abstractmethod
from math import exp, log
from typing import TYPE_CHECKING

import numpy as np
from mici.errors import IntegratorError, AdaptationError
from mici.matrices import PositiveDiagonalMatrix, DensePositiveDefiniteMatrix

from mici.errors import AdaptationError, IntegratorError
from mici.matrices import DensePositiveDefiniteMatrix, PositiveDiagonalMatrix

if TYPE_CHECKING:
from typing import Collection, Optional, Iterable, Union
from typing import Collection, Iterable, Optional, Union

from numpy.random import Generator
from numpy.typing import ArrayLike

from mici.integrators import Integrator
from mici.states import ChainState
from mici.systems import System
from mici.transitions import Transition
from mici.types import (
AdaptationStatisticFunction,
AdapterState,
ReducerFunction,
TransitionStatistics,
)
from mici.types import (AdaptationStatisticFunction, AdapterState,
ReducerFunction, TransitionStatistics)


class Adapter(ABC):
Expand Down Expand Up @@ -486,7 +486,7 @@ def finalize(
mean_est /= n_iter
var_est += adapt_state["sum_diff_sq"]
var_est += (
mean_diff ** 2 * (adapt_state["iter"] * n_iter_prev) / n_iter
mean_diff**2 * (adapt_state["iter"] * n_iter_prev) / n_iter
)
if n_iter < 2:
raise AdaptationError(
Expand Down
1 change: 1 addition & 0 deletions src/mici/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import mici.autograd_wrapper as autograd_wrapper

if TYPE_CHECKING:
Expand Down
15 changes: 5 additions & 10 deletions src/mici/autograd_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,19 @@

AUTOGRAD_AVAILABLE = True
try:
from autograd.wrap_util import unary_to_nary
import autograd.numpy as np
from autograd.builtins import tuple as atuple
from autograd.core import make_vjp
from autograd.extend import vspace
import autograd.numpy as np
from autograd.wrap_util import unary_to_nary
except ImportError:
AUTOGRAD_AVAILABLE = False

if TYPE_CHECKING:
from typing import Callable
from mici.types import (
ScalarLike,
ArrayLike,
ScalarFunction,
ArrayFunction,
MatrixHessianProduct,
MatrixTressianProduct,
)

from mici.types import (ArrayFunction, ArrayLike, MatrixHessianProduct,
MatrixTressianProduct, ScalarFunction, ScalarLike)


def _wrapped_unary_to_nary(func: Callable) -> Callable:
Expand Down
18 changes: 9 additions & 9 deletions src/mici/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import ArrayLike
from mici.errors import NonReversibleStepError, AdaptationError
from mici.solvers import (
maximum_norm,
solve_fixed_point_direct,
solve_projection_onto_manifold_newton,
FixedPointSolver,
ProjectionSolver,
)

from mici.errors import AdaptationError, NonReversibleStepError
from mici.solvers import (FixedPointSolver, ProjectionSolver, maximum_norm,
solve_fixed_point_direct,
solve_projection_onto_manifold_newton)

if TYPE_CHECKING:
from typing import Any, Callable, Optional, Sequence

from mici.states import ChainState
from mici.systems import ConstrainedTractableFlowSystem, System, TractableFlowSystem
from mici.systems import (ConstrainedTractableFlowSystem, System,
TractableFlowSystem)
from mici.types import NormFunction


Expand Down
9 changes: 6 additions & 3 deletions src/mici/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@

import importlib
import os
import mici
import numpy as np
from typing import TYPE_CHECKING

import numpy as np

import mici

if TYPE_CHECKING:
from typing import Literal, Optional, Union
from numpy.typing import ArrayLike

import arviz
import pymc
from numpy.typing import ArrayLike


def convert_to_inference_data(
Expand Down
16 changes: 10 additions & 6 deletions src/mici/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@
import abc
import numbers
from typing import TYPE_CHECKING

import numpy as np
from mici.errors import LinAlgError
import numpy.linalg as nla
import scipy.linalg as sla

from mici.errors import LinAlgError
from mici.utils import hash_array

if TYPE_CHECKING:
from typing import Iterable, Literal, Optional, Tuple, Union

from numpy.typing import NDArray

from mici.types import MatrixLike, ScalarLike


Expand Down Expand Up @@ -648,7 +652,7 @@ def grad_log_abs_det(self):
return self.shape[0] / self._scalar

def grad_quadratic_form_inv(self, vector: NDArray) -> float:
return -np.sum(vector ** 2) / self._scalar ** 2
return -np.sum(vector**2) / self._scalar**2

def __str__(self) -> str:
return f"(shape={self.shape}, scalar={self._scalar})"
Expand Down Expand Up @@ -681,7 +685,7 @@ def _construct_inv(self) -> PositiveScaledIdentityMatrix:
return PositiveScaledIdentityMatrix(1 / self._scalar, self.shape[0])

def _construct_sqrt(self) -> PositiveScaledIdentityMatrix:
return PositiveScaledIdentityMatrix(self._scalar ** 0.5, self.shape[0])
return PositiveScaledIdentityMatrix(self._scalar**0.5, self.shape[0])


class DiagonalMatrix(SymmetricMatrix, DifferentiableMatrix, ImplicitArrayMatrix):
Expand Down Expand Up @@ -766,7 +770,7 @@ def _construct_inv(self) -> PositiveDiagonalMatrix:
return PositiveDiagonalMatrix(1.0 / self.diagonal)

def _construct_sqrt(self) -> PositiveDiagonalMatrix:
return PositiveDiagonalMatrix(self.diagonal ** 0.5)
return PositiveDiagonalMatrix(self.diagonal**0.5)


def _make_array_triangular(array: NDArray, lower: bool) -> NDArray:
Expand Down Expand Up @@ -1042,7 +1046,7 @@ def __init__(
def _scalar_multiply(self, scalar: ScalarLike) -> TriangularFactoredDefiniteMatrix:
if scalar > 0:
return TriangularFactoredPositiveDefiniteMatrix(
factor=scalar ** 0.5 * self.factor
factor=scalar**0.5 * self.factor
)
else:
return super()._scalar_multiply(scalar)
Expand Down Expand Up @@ -1542,7 +1546,7 @@ def _construct_inv(self) -> EigendecomposedPositiveDefiniteMatrix:
return EigendecomposedPositiveDefiniteMatrix(self.eigvec, 1 / self.eigval)

def _construct_sqrt(self) -> EigendecomposedPositiveDefiniteMatrix:
return EigendecomposedPositiveDefiniteMatrix(self.eigvec, self.eigval ** 0.5)
return EigendecomposedPositiveDefiniteMatrix(self.eigvec, self.eigval**0.5)


class SoftAbsRegularizedPositiveDefiniteMatrix(
Expand Down
2 changes: 1 addition & 1 deletion src/mici/progressbars.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def __init__(

@property
def description(self) -> str:
""""Description of task being tracked."""
"""Description of task being tracked."""
return self._description

@property
Expand Down
86 changes: 44 additions & 42 deletions src/mici/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,62 +2,49 @@

from __future__ import annotations

import logging
import os
import queue
import signal
import tempfile
from contextlib import ExitStack, contextmanager, nullcontext
from pathlib import Path
from pickle import PicklingError
import logging
import tempfile
import signal
from warnings import warn
from typing import TYPE_CHECKING, NamedTuple
from warnings import warn

import numpy as np
from numpy.random import default_rng
from mici.transitions import (
IndependentMomentumTransition,
MetropolisRandomIntegrationTransition,
MetropolisStaticIntegrationTransition,
MultinomialDynamicIntegrationTransition,
SliceDynamicIntegrationTransition,
euclidean_no_u_turn_criterion,
riemannian_no_u_turn_criterion,
)
from mici.states import ChainState
from mici.progressbars import (
SequenceProgressBar,
LabelledSequenceProgressBar,
DummyProgressBar,
_ProxySequenceProgressBar,
)
from mici.errors import AdaptationError

from mici.adapters import DualAveragingStepSizeAdapter
from mici.errors import AdaptationError
from mici.progressbars import (DummyProgressBar, LabelledSequenceProgressBar,
SequenceProgressBar, _ProxySequenceProgressBar)
from mici.stagers import WarmUpStager, WindowedWarmUpStager
from mici.states import ChainState
from mici.transitions import (IndependentMomentumTransition,
MetropolisRandomIntegrationTransition,
MetropolisStaticIntegrationTransition,
MultinomialDynamicIntegrationTransition,
SliceDynamicIntegrationTransition,
euclidean_no_u_turn_criterion,
riemannian_no_u_turn_criterion)

if TYPE_CHECKING:
from typing import (
Container,
Generator,
Iterable,
Optional,
Sequence,
Union,
)
from typing import (Container, Generator, Iterable, Optional, Sequence,
Union)

from numpy.typing import ArrayLike, DTypeLike, NDArray

from mici.adapters import Adapter
from mici.integrators import Integrator
from mici.progressbars import ProgressBar
from mici.stagers import Stager
from mici.systems import System
from mici.transitions import IntegrationTransition, MomentumTransition, Transition
from mici.types import (
AdapterState,
ChainIterator,
ScalarLike,
PyTree,
TraceFunction,
TerminationCriterion,
)
from mici.transitions import (IntegrationTransition, MomentumTransition,
Transition)
from mici.types import (AdapterState, ChainIterator, PyTree, ScalarLike,
TerminationCriterion, TraceFunction)

# Preferentially import from multiprocess library if available as able to
# serialize much wider range of types including autograd functions
Expand Down Expand Up @@ -145,7 +132,10 @@ def _generate_memmap_filenames(


def _open_new_memmap(
file_path: str, shape: tuple[int, ...], default_val: ScalarLike, dtype: DTypeLike,
file_path: str,
shape: tuple[int, ...],
default_val: ScalarLike,
dtype: DTypeLike,
) -> np.memmap:
"""Open a new memory-mapped array object and fill with a default-value.
Expand Down Expand Up @@ -298,7 +288,15 @@ def _init_traces(
]
else:
traces[key] = list(
np.full((n_chain, n_iter,) + val.shape, init, val.dtype)
np.full(
(
n_chain,
n_iter,
)
+ val.shape,
init,
val.dtype,
)
)
return traces

Expand Down Expand Up @@ -961,7 +959,11 @@ def sample_chains(
)
)
stats = _init_stats(
self.transitions, n_chain, n_trace_iter, use_memmap, memmap_path,
self.transitions,
n_chain,
n_trace_iter,
use_memmap,
memmap_path,
)
per_chain_rngs = _get_per_chain_rngs(self.rng, n_chain)
per_chain_traces = (
Expand Down Expand Up @@ -1051,7 +1053,7 @@ class HMCSampleChainsOutputs(NamedTuple):
and main sampling stages otherwise. The key for each value is the
corresponding key in the dictionary returned by the trace function which
computed the traced value.
statistics: Dictionary of chain transition statistic dictionaries. Values in
statistics: Dictionary of chain transition statistic dictionaries. Values in
dictionary are lists of arrays of chain statistic values with each array in
the list corresponding to a single chain and the leading dimension of each
array corresponding to the iteration (draw) index, within the main
Expand Down
14 changes: 7 additions & 7 deletions src/mici/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@

from __future__ import annotations

from typing import Protocol, TYPE_CHECKING
from mici.errors import ConvergenceError, LinAlgError
from typing import TYPE_CHECKING, Protocol

import numpy as np

from mici.errors import ConvergenceError, LinAlgError

if TYPE_CHECKING:
from mici.states import ChainState
from mici.systems import (
ConstrainedEuclideanMetricSystem,
ConstrainedTractableFlowSystem,
)
from mici.types import ScalarFunction, ArrayFunction, ArrayLike
from mici.systems import (ConstrainedEuclideanMetricSystem,
ConstrainedTractableFlowSystem)
from mici.types import ArrayFunction, ArrayLike, ScalarFunction


def euclidean_norm(vct):
Expand Down
3 changes: 2 additions & 1 deletion src/mici/stagers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from __future__ import annotations

import abc
from typing import NamedTuple, TYPE_CHECKING
from typing import TYPE_CHECKING, NamedTuple

if TYPE_CHECKING:
from typing import Iterable, Optional

from mici.adapters import Adapter
from mici.types import TraceFunction

Expand Down
Loading

0 comments on commit 1c1f631

Please sign in to comment.