Skip to content

Commit

Permalink
chore: simplify type hints with PEP 563 (#250)
Browse files Browse the repository at this point in the history
* chore: add from __future__ import annotations
* chore: clean up typing imports
  No need to import Optional, List, etc anymore
* chore: move group functions to helicity.decay
* ci: update pip constraints and pre-commit config
* fix: remove type aliases if hashable
* style: reorder helicity module API

Co-authored-by: GitHub <noreply@github.com>
  • Loading branch information
redeboer and web-flow authored Mar 7, 2022
1 parent 73df164 commit eac40f7
Show file tree
Hide file tree
Showing 26 changed files with 507 additions and 513 deletions.
1 change: 1 addition & 0 deletions .constraints/py3.10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ flake8-blind-except==0.2.0
flake8-bugbear==22.1.11
flake8-builtins==1.5.3
flake8-comprehensions==3.8.0
flake8-future-annotations==0.0.4
flake8-plugin-utils==1.3.2
flake8-polyfill==1.0.2
flake8-pytest-style==1.6.0
Expand Down
1 change: 1 addition & 0 deletions .constraints/py3.7.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ flake8-blind-except==0.2.0
flake8-bugbear==22.1.11
flake8-builtins==1.5.3
flake8-comprehensions==3.8.0
flake8-future-annotations==0.0.4
flake8-plugin-utils==1.3.2
flake8-polyfill==1.0.2
flake8-pytest-style==1.6.0
Expand Down
1 change: 1 addition & 0 deletions .constraints/py3.8.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ flake8-blind-except==0.2.0
flake8-bugbear==22.1.11
flake8-builtins==1.5.3
flake8-comprehensions==3.8.0
flake8-future-annotations==0.0.4
flake8-plugin-utils==1.3.2
flake8-polyfill==1.0.2
flake8-pytest-style==1.6.0
Expand Down
1 change: 1 addition & 0 deletions .constraints/py3.9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ flake8-blind-except==0.2.0
flake8-bugbear==22.1.11
flake8-builtins==1.5.3
flake8-comprehensions==3.8.0
flake8-future-annotations==0.0.4
flake8-plugin-utils==1.3.2
flake8-polyfill==1.0.2
flake8-pytest-style==1.6.0
Expand Down
22 changes: 17 additions & 5 deletions docs/_relink_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,48 @@
See also https://github.com/sphinx-doc/sphinx/issues/5868.
"""

from typing import List
from __future__ import annotations

import sphinx.domains.python
from docutils import nodes
from sphinx.addnodes import pending_xref
from sphinx.environment import BuildEnvironment

__TARGET_SUBSTITUTIONS = {
"InteractionProperties": "qrules.quantum_numbers.InteractionProperties",
"ParameterValue": "ampform.helicity.ParameterValue",
"ReactionInfo": "qrules.transition.ReactionInfo",
"Slider": "symplot.Slider",
"State": "qrules.transition.State",
"StateTransition": "qrules.transition.StateTransition",
"Topology": "qrules.topology.Topology",
"WignerD": "sympy.physics.quantum.spin.WignerD",
"a set-like object providing a view on D's items": "typing.ItemsView",
"a set-like object providing a view on D's keys": "typing.KeysView",
"ampform.helicity._T": "typing.TypeVar",
"an object providing a view on D's values": "typing.ValuesView",
"sp.Basic": "sympy.core.basic.Basic",
"sp.Expr": "sympy.core.expr.Expr",
"sp.Float": "sympy.core.numbers.Float",
"sp.Indexed": "sympy.tensor.indexed.Indexed",
"sp.IndexedBase": "sympy.tensor.indexed.IndexedBase",
"sp.Symbol": "sympy.core.symbol.Symbol",
"sympy.printing.numpy.NumPyPrinter": "sympy.printing.printer.Printer",
"typing_extensions.Protocol": "typing.Protocol",
}
__REF_TYPE_SUBSTITUTIONS = {
"DecoratedClass": "obj",
"DecoratedExpr": "obj",
"FourMomenta": "obj",
"FourMomentumSymbol": "obj",
"None": "obj",
"ParameterValue": "obj",
"RangeDefinition": "obj",
"ampform.dynamics.builder.BuilderReturnType": "obj",
"ampform.helicity.ParameterValue": "obj",
"ampform.kinematics.FourMomenta": "obj",
"ampform.kinematics.FourMomentumSymbol": "obj",
"ampform.sympy.DecoratedClass": "obj",
"ampform.sympy.DecoratedExpr": "obj",
"symplot.RangeDefinition": "obj",
"symplot.Slider": "obj",
}

Expand Down Expand Up @@ -97,7 +109,7 @@ def __get_env_kwargs(env: BuildEnvironment) -> dict:
return {}


def __create_nodes(env: BuildEnvironment, title: str) -> List[nodes.Node]:
def __create_nodes(env: BuildEnvironment, title: str) -> list[nodes.Node]:
short_name = title.split(".")[-1]
if env.config.python_use_unqualified_type_names:
return [
Expand Down
13 changes: 5 additions & 8 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ def fetch_logo(url: str, output_path: str) -> None:
"BuilderReturnType": "ampform.dynamics.builder.BuilderReturnType",
"FourMomenta": "ampform.kinematics.FourMomenta",
"FourMomentumSymbol": "ampform.kinematics.FourMomentumSymbol",
"ParameterValue": "ampform.helicity.ParameterValue",
"RangeDefinition": "symplot.RangeDefinition",
"Slider": "symplot.Slider",
# https://github.com/sphinx-doc/sphinx/pull/10183
# "ParameterValue": "ampform.helicity.ParameterValue",
# "Slider": "symplot.Slider",
}
autodoc_typehints_format = "short"
codeautolink_concat_default = True
Expand Down Expand Up @@ -223,14 +224,10 @@ def fetch_logo(url: str, output_path: str) -> None:
primary_domain = "py"
nitpicky = True # warn if cross-references are missing
nitpick_ignore = [
("py:class", "ArraySum"),
("py:class", "ampform.sympy._array_expressions.MatrixMultiplication"),
("py:class", "ipywidgets.widgets.widget_float.FloatSlider"),
("py:class", "ipywidgets.widgets.widget_int.IntSlider"),
(
"py:class",
"sympy.tensor.array.expressions.array_expressions.ArraySymbol",
),
("py:class", "ampform.sympy._array_expressions.ArraySum"),
("py:class", "ampform.sympy._array_expressions.MatrixMultiplication"),
("py:class", "typing_extensions.Protocol"),
]

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ flake8 =
flake8-bugbear
flake8-builtins
flake8-comprehensions
flake8-future-annotations
flake8-pytest-style
flake8-rst-docstrings
flake8-type-ignore; python_version >="3.8.0"
Expand Down
29 changes: 15 additions & 14 deletions src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
.. seealso:: :doc:`/usage/dynamics` and
:doc:`/usage/dynamics/analytic-continuation`
"""
from __future__ import annotations

import re
import sys
from typing import Dict, List, Optional, Sequence
from typing import Sequence

import sympy as sp
from sympy.printing.conventions import split_super_sub
Expand Down Expand Up @@ -57,20 +58,20 @@ class BlattWeisskopfSquared(UnevaluatedExpression):
See also :ref:`usage/dynamics:Form factor`.
"""
is_commutative = True
max_angular_momentum: Optional[int] = None
max_angular_momentum: int | None = None
"""Limit the maximum allowed angular momentum :math:`L`.
This improves performance when :math:`L` is a `~sympy.core.symbol.Symbol`
and you are note interested in higher angular momenta.
"""

def __new__(cls, angular_momentum, z, **hints) -> "BlattWeisskopfSquared":
def __new__(cls, angular_momentum, z, **hints) -> BlattWeisskopfSquared:
return create_expression(cls, angular_momentum, z, **hints)

def evaluate(self) -> sp.Expr:
angular_momentum: sp.Expr = self.args[0] # type: ignore[assignment]
z: sp.Expr = self.args[1] # type: ignore[assignment]
cases: Dict[int, sp.Expr] = {
cases: dict[int, sp.Expr] = {
0: sp.S.One,
1: 2 * z / (z + 1),
2: 13 * z**2 / ((z - 3) * (z - 3) + 9 * z),
Expand Down Expand Up @@ -170,7 +171,7 @@ class PhaseSpaceFactor(UnevaluatedExpression):

is_commutative = True

def __new__(cls, s, m_a, m_b, **hints) -> "PhaseSpaceFactor":
def __new__(cls, s, m_a, m_b, **hints) -> PhaseSpaceFactor:
return create_expression(cls, s, m_a, m_b, **hints)

def evaluate(self) -> sp.Expr:
Expand Down Expand Up @@ -203,7 +204,7 @@ class PhaseSpaceFactorAbs(UnevaluatedExpression):

is_commutative = True

def __new__(cls, s, m_a, m_b, **hints) -> "PhaseSpaceFactorAbs":
def __new__(cls, s, m_a, m_b, **hints) -> PhaseSpaceFactorAbs:
return create_expression(cls, s, m_a, m_b, **hints)

def evaluate(self) -> sp.Expr:
Expand Down Expand Up @@ -233,7 +234,7 @@ class PhaseSpaceFactorAnalytic(UnevaluatedExpression):

is_commutative = True

def __new__(cls, s, m_a, m_b, **hints) -> "PhaseSpaceFactorAnalytic":
def __new__(cls, s, m_a, m_b, **hints) -> PhaseSpaceFactorAnalytic:
return create_expression(cls, s, m_a, m_b, **hints)

def evaluate(self) -> sp.Expr:
Expand Down Expand Up @@ -264,7 +265,7 @@ class PhaseSpaceFactorComplex(UnevaluatedExpression):

is_commutative = True

def __new__(cls, s, m_a, m_b, **hints) -> "PhaseSpaceFactorComplex":
def __new__(cls, s, m_a, m_b, **hints) -> PhaseSpaceFactorComplex:
return create_expression(cls, s, m_a, m_b, **hints)

def evaluate(self) -> sp.Expr:
Expand Down Expand Up @@ -336,10 +337,10 @@ def __new__( # pylint: disable=too-many-arguments
m_b,
angular_momentum,
meson_radius,
phsp_factor: Optional[PhaseSpaceFactorProtocol] = None,
name: Optional[str] = None,
phsp_factor: PhaseSpaceFactorProtocol | None = None,
name: str | None = None,
evaluate: bool = False,
) -> "EnergyDependentWidth":
) -> EnergyDependentWidth:
args = sp.sympify(
(s, mass0, gamma0, m_a, m_b, angular_momentum, meson_radius)
)
Expand Down Expand Up @@ -409,7 +410,7 @@ class BreakupMomentumSquared(UnevaluatedExpression):

is_commutative = True

def __new__(cls, s, m_a, m_b, **hints) -> "BreakupMomentumSquared":
def __new__(cls, s, m_a, m_b, **hints) -> BreakupMomentumSquared:
return create_expression(cls, s, m_a, m_b, **hints)

def evaluate(self) -> sp.Expr:
Expand Down Expand Up @@ -441,7 +442,7 @@ def relativistic_breit_wigner_with_ff( # pylint: disable=too-many-arguments
m_b,
angular_momentum,
meson_radius,
phsp_factor: Optional[PhaseSpaceFactorProtocol] = None,
phsp_factor: PhaseSpaceFactorProtocol | None = None,
) -> sp.Expr:
"""Relativistic Breit-Wigner with `.BlattWeisskopfSquared` factor.
Expand Down Expand Up @@ -477,7 +478,7 @@ def _indices_to_subscript(indices: Sequence[int]) -> str:
return "_{" + subscript + "}"


def _determine_indices(symbol) -> List[int]:
def _determine_indices(symbol) -> list[int]:
r"""Extract any indices if available from a `~sympy.core.symbol.Symbol`.
>>> _determine_indices(sp.Symbol("m1"))
Expand Down
23 changes: 12 additions & 11 deletions src/ampform/dynamics/builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Build `~ampform.dynamics` with correct variable names and values."""
from __future__ import annotations

import sys
from typing import Dict, Optional, Tuple
from typing import Dict, Tuple

import sympy as sp
from attrs import field, frozen
Expand Down Expand Up @@ -39,7 +40,7 @@ class TwoBodyKinematicVariableSet:
outgoing_state_mass2: sp.Symbol = field(validator=instance_of(sp.Symbol))
helicity_theta: sp.Symbol = field(validator=instance_of(sp.Symbol))
helicity_phi: sp.Symbol = field(validator=instance_of(sp.Symbol))
angular_momentum: Optional[int] = field(default=None)
angular_momentum: int | None = field(default=None)


BuilderReturnType = Tuple[sp.Expr, Dict[sp.Symbol, float]]
Expand All @@ -66,21 +67,21 @@ class ResonanceDynamicsBuilder(Protocol):

def __call__(
self, resonance: Particle, variable_pool: TwoBodyKinematicVariableSet
) -> "BuilderReturnType":
) -> BuilderReturnType:
"""Formulate a dynamics `~sympy.core.expr.Expr` for this resonance."""
...


def create_non_dynamic(
resonance: Particle, variable_pool: TwoBodyKinematicVariableSet
) -> "BuilderReturnType":
) -> BuilderReturnType:
# pylint: disable=unused-argument
return (sp.S.One, {})


def create_non_dynamic_with_ff(
resonance: Particle, variable_pool: TwoBodyKinematicVariableSet
) -> "BuilderReturnType":
) -> BuilderReturnType:
"""Generate (only) a Blatt-Weisskopf form factor for a two-body decay.
Returns the `~sympy.functions.elementary.miscellaneous.sqrt` of a
Expand Down Expand Up @@ -132,7 +133,7 @@ def __init__(
self,
form_factor: bool = False,
energy_dependent_width: bool = False,
phsp_factor: Optional[PhaseSpaceFactorProtocol] = None,
phsp_factor: PhaseSpaceFactorProtocol | None = None,
) -> None:
if phsp_factor is None:
phsp_factor = PhaseSpaceFactor
Expand All @@ -142,7 +143,7 @@ def __init__(

def __call__(
self, resonance: Particle, variable_pool: TwoBodyKinematicVariableSet
) -> "BuilderReturnType":
) -> BuilderReturnType:
"""Formulate a relativistic Breit-Wigner for this resonance."""
if self.energy_dependent_width:
expr, parameter_defaults = self.__energy_dependent_breit_wigner(
Expand All @@ -163,7 +164,7 @@ def __call__(
@staticmethod
def __simple_breit_wigner(
resonance: Particle, variable_pool: TwoBodyKinematicVariableSet
) -> "BuilderReturnType":
) -> BuilderReturnType:
inv_mass = variable_pool.incoming_state_mass
res_identifier = resonance.latex if resonance.latex else resonance.name
res_mass = sp.Symbol(f"m_{{{res_identifier}}}")
Expand All @@ -181,7 +182,7 @@ def __simple_breit_wigner(

def __energy_dependent_breit_wigner(
self, resonance: Particle, variable_pool: TwoBodyKinematicVariableSet
) -> "BuilderReturnType":
) -> BuilderReturnType:
if variable_pool.angular_momentum is None:
raise ValueError(
"Angular momentum is not defined but is required in the"
Expand Down Expand Up @@ -217,7 +218,7 @@ def __energy_dependent_breit_wigner(

def __create_form_factor(
self, resonance: Particle, variable_pool: TwoBodyKinematicVariableSet
) -> "BuilderReturnType":
) -> BuilderReturnType:
if variable_pool.angular_momentum is None:
raise ValueError(
"Angular momentum is not defined but is required in the"
Expand Down Expand Up @@ -247,7 +248,7 @@ def __create_form_factor(
@staticmethod
def __create_symbols(
resonance: Particle,
) -> Tuple[sp.Symbol, sp.Symbol, sp.Symbol]:
) -> tuple[sp.Symbol, sp.Symbol, sp.Symbol]:
res_identifier = resonance.latex if resonance.latex else resonance.name
res_mass = sp.Symbol(f"m_{{{res_identifier}}}")
res_width = sp.Symbol(Rf"\Gamma_{{{res_identifier}}}")
Expand Down
Loading

0 comments on commit eac40f7

Please sign in to comment.