Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 8, 2023
1 parent 536e828 commit f387081
Show file tree
Hide file tree
Showing 24 changed files with 75 additions and 81 deletions.
10 changes: 6 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ repos:
args:
- --fix

- repo: https://github.com/psf/black
rev: 23.11.0
- repo: https://github.com/psf/black-pre-commit-mirror
rev: PLEASE-UPDATE
hooks:
- id: black
- id: black-jupyter
args: [--line-length=85]
types_or: [jupyter]
args:
- --line-length=85
types_or:
- jupyter

- repo: https://github.com/asottile/blacken-docs
rev: 1.16.0
Expand Down
10 changes: 4 additions & 6 deletions benchmarks/ampform.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ def test_fit(self, backend, benchmark, model, size):
def print_data_sample(data: DataSample, sample_size: int) -> None:
"""Print a `.DataSample`, so it can be pasted into the expected sample."""
print() # noqa: T201
pprint( # noqa: T203
{
i: np.round(four_momenta[:sample_size], decimals=11).tolist()
for i, four_momenta in data.items()
}
)
pprint({ # noqa: T203
i: np.round(four_momenta[:sample_size], decimals=11).tolist()
for i, four_momenta in data.items()
})
2 changes: 1 addition & 1 deletion benchmarks/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_data(backend, benchmark, size):
def test_fit(
backend: str,
benchmark,
optimizer_type: (type[Minuit2] | type[ScipyMinimizer]),
optimizer_type: type[Minuit2] | type[ScipyMinimizer],
size: int,
):
domain, data = generate_data_and_domain(backend, n_data=size, n_domain=10 * size)
Expand Down
1 change: 1 addition & 0 deletions docs/_relink_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
See also https://github.com/sphinx-doc/sphinx/issues/5868.
"""

# pyright: reportMissingImports=false
from __future__ import annotations

Expand Down
24 changes: 10 additions & 14 deletions docs/amplitude-analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -456,13 +456,11 @@
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"pd.DataFrame(\n",
" {\n",
" (k, label): np.transpose(v)[i]\n",
" for k, v in phsp_momenta.items()\n",
" for i, label in enumerate([\"E\", \"px\", \"py\", \"pz\"])\n",
" }\n",
")"
"pd.DataFrame({\n",
" (k, label): np.transpose(v)[i]\n",
" for k, v in phsp_momenta.items()\n",
" for i, label in enumerate([\"E\", \"px\", \"py\", \"pz\"])\n",
"})"
]
},
{
Expand Down Expand Up @@ -628,13 +626,11 @@
" domain_transformer=helicity_transformer,\n",
")\n",
"data_momenta = data_generator.generate(10_000, rng)\n",
"pd.DataFrame(\n",
" {\n",
" (k, label): np.transpose(v)[i]\n",
" for k, v in data_momenta.items()\n",
" for i, label in enumerate([\"E\", \"px\", \"py\", \"pz\"])\n",
" }\n",
")"
"pd.DataFrame({\n",
" (k, label): np.transpose(v)[i]\n",
" for k, v in data_momenta.items()\n",
" for i, label in enumerate([\"E\", \"px\", \"py\", \"pz\"])\n",
"})"
]
},
{
Expand Down
22 changes: 10 additions & 12 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,16 @@ def fetch_logo(url: str, output_path: str) -> None:
relink_references()
shutil.rmtree("api", ignore_errors=True)
subprocess.call(
" ".join(
[
"sphinx-apidoc",
f"../src/{PACKAGE}/",
f"../src/{PACKAGE}/version.py",
"-o api/",
"--force",
"--no-toc",
"--templatedir _templates",
"--separate",
]
),
" ".join([
"sphinx-apidoc",
f"../src/{PACKAGE}/",
f"../src/{PACKAGE}/version.py",
"-o api/",
"--force",
"--no-toc",
"--templatedir _templates",
"--separate",
]),
shell=True, # noqa: S602
)

Expand Down
6 changes: 3 additions & 3 deletions docs/usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -804,9 +804,9 @@
")\n",
"def plot(dphi, k_r, k_phi, sigma):\n",
" global color_mesh, X, Y\n",
" polar_function.update_parameters(\n",
" {R\"\\Delta\\phi\": dphi, \"k_r\": k_r, \"k_phi\": k_phi, \"sigma\": sigma}\n",
" )\n",
" polar_function.update_parameters({\n",
" R\"\\Delta\\phi\": dphi, \"k_r\": k_r, \"k_phi\": k_phi, \"sigma\": sigma\n",
" })\n",
" Z = polar_function(polar_domain)\n",
" if color_mesh is not None:\n",
" color_mesh.remove()\n",
Expand Down
12 changes: 5 additions & 7 deletions docs/usage/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -740,13 +740,11 @@
"outputs": [],
"source": [
"minuit2 = Minuit2(\n",
" callback=CallbackList(\n",
" [\n",
" CSVSummary(\"traceback-1D.csv\"),\n",
" YAMLSummary(\"fit-result-1D.yaml\"),\n",
" TFSummary(),\n",
" ]\n",
" )\n",
" callback=CallbackList([\n",
" CSVSummary(\"traceback-1D.csv\"),\n",
" YAMLSummary(\"fit-result-1D.yaml\"),\n",
" TFSummary(),\n",
" ])\n",
")\n",
"fit_result = minuit2.optimize(estimator, initial_parameters)\n",
"fit_result"
Expand Down
4 changes: 1 addition & 3 deletions docs/usage/unbinned-fit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,7 @@
"\n",
"\n",
"def gaussian(x, mu, sigma):\n",
" return sp.exp(-((x - mu) ** 2) / (2 * sigma**2)) / sp.sqrt(\n",
" 2 * sp.pi * sigma**2\n",
" )\n",
" return sp.exp(-((x - mu) ** 2) / (2 * sigma**2)) / sp.sqrt(2 * sp.pi * sigma**2)\n",
"\n",
"\n",
"x, y, mu, sigma_x, sigma_y = sp.symbols(\"x y mu sigma_x sigma_y\")\n",
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The `.data` module takes care of data generation."""

from __future__ import annotations

import logging
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/data/_data_sample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Helper functions for modifying `.DataSample` instances."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/data/phasespace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implementations of a `.DataGenerator` for four-momentum samples."""

from __future__ import annotations

import logging
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/data/rng.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implementations of `.RealNumberGenerator`."""

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Union
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/data/transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implementations of `.DataTransformer`."""

from __future__ import annotations

from typing import TYPE_CHECKING, Mapping
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
All estimators have to implement the `.Estimator` interface.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Iterable, Mapping
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/function/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Express mathematical expressions in terms of computational functions."""

from __future__ import annotations

import inspect
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/function/_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Computational back-end handling."""

from __future__ import annotations

from functools import partial
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/function/sympy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Lambdify `sympy` expression trees to a `.Function`."""

from __future__ import annotations

import logging
Expand Down
34 changes: 16 additions & 18 deletions src/tensorwaves/function/sympy/_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,22 @@ def _get_numpy_code(self: _T, expr: sp.Expr, *args: Any) -> str:
return decorator


@_forward_to_numpy_printer(
[
"ArrayAxisSum",
"ArrayMultiplication",
"BoostZ",
"BoostZMatrix",
"RotationY",
"RotationYMatrix",
"RotationZ",
"RotationZMatrix",
"_ArraySize",
"_BoostZMatrixImplementation",
"_OnesArray",
"_RotationYMatrixImplementation",
"_RotationZMatrixImplementation",
"_ZerosArray",
]
)
@_forward_to_numpy_printer([
"ArrayAxisSum",
"ArrayMultiplication",
"BoostZ",
"BoostZMatrix",
"RotationY",
"RotationYMatrix",
"RotationZ",
"RotationZMatrix",
"_ArraySize",
"_BoostZMatrixImplementation",
"_OnesArray",
"_RotationYMatrixImplementation",
"_RotationZMatrixImplementation",
"_ZerosArray",
])
class TensorflowPrinter(CustomNumPyPrinter):
module_imports = {"tensorflow.experimental": {"numpy as tnp"}}
_module = "tnp"
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Defines top-level interface of tensorwaves."""

from __future__ import annotations

from abc import ABC, abstractmethod
Expand Down
1 change: 1 addition & 0 deletions src/tensorwaves/optimizer/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Collection of loggers that can be inserted into an optimizer as callback."""

from __future__ import annotations

import csv
Expand Down
10 changes: 4 additions & 6 deletions tests/data/test_phasespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,10 @@ def test_generate_deterministic(self, pdg: "ParticleCollection"):
weights = phsp_momenta.get("weights", []) # type: ignore[var-annotated]
del phsp_momenta["weights"]
print("Expected values, get by running pytest with the -s flag") # noqa: T201
pprint( # noqa: T203
{
i: np.round(four_momenta, decimals=10).tolist()
for i, four_momenta in phsp_momenta.items()
}
)
pprint({ # noqa: T203
i: np.round(four_momenta, decimals=10).tolist()
for i, four_momenta in phsp_momenta.items()
})
expected_sample = {
"p0": [
[0.7059154068, 0.3572095625, 0.251997269, 0.2441281612],
Expand Down
2 changes: 1 addition & 1 deletion tests/optimizer/test_fit_simple_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_optimize_all_parameters(
backend: str,
domain_and_data_sample: tuple[DataSample, DataSample],
expression_and_parameters: tuple[sp.Expr, dict[sp.Symbol, float]],
optimizer_type: (type[Minuit2] | type[ScipyMinimizer]),
optimizer_type: type[Minuit2] | type[ScipyMinimizer],
output_dir: Path,
):
domain, data = domain_and_data_sample
Expand Down
8 changes: 2 additions & 6 deletions tests/optimizer/test_minuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,12 @@ def test_mock_callback(self, mocker: MockerFixture) -> None:
{"x": 2.5}, # 2 (x - 1) - 3 == 0 -> x = 3/2 + 1
),
(
Polynomial1DMinimaEstimator(
lambda x: x**3 + (x - 1) ** 2 - 3 * x + 1
),
Polynomial1DMinimaEstimator(lambda x: x**3 + (x - 1) ** 2 - 3 * x + 1),
{"x": -1.0},
{"x": 1.0},
),
(
Polynomial1DMinimaEstimator(
lambda x: x**3 + (x - 1) ** 2 - 3 * x + 1
),
Polynomial1DMinimaEstimator(lambda x: x**3 + (x - 1) ** 2 - 3 * x + 1),
{"x": -2.0},
None, # no convergence
),
Expand Down

0 comments on commit f387081

Please sign in to comment.