diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..4ac443d1 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @ekiefl diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..2cbf2639 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,118 @@ +name: CI + +on: + pull_request: + branches: [main] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -r requirements-dev.txt + + # --- Ruff + + - name: ruff (lint) + id: ruff_lint + continue-on-error: true + run: | + ruff --version + ruff check . --verbose --diff + echo "ruff_lint_failed=$?" >> $GITHUB_ENV + + - name: ruff (format) + id: ruff_format + continue-on-error: true + run: | + ruff --version + ruff format . --check --verbose --diff + echo "ruff_format_failed=$?" >> $GITHUB_ENV + + # --- Pytest + + - name: pytest + id: pytest + continue-on-error: true + run: | + pytest --version + pytest + echo "pytest_failed=$?" >> $GITHUB_ENV + + # --- Pyright + + - name: pyright + id: pyright + continue-on-error: true + run: | + pyright --version + pyright --project ./pyrightconfig.ci.json + echo "pyright_failed=$?" >> $GITHUB_ENV + + # --- Main + + - name: Test results + if: always() + run: | + # Print out test results + passed=() + failed=() + + if [[ "${{ env.pytest_failed }}" != "0" ]]; then + failed+=("pytest") + else + passed+=("pytest") + fi + + if [[ "${{ env.pyright_failed }}" != "0" ]]; then + failed+=("pyright") + else + passed+=("pyright") + fi + + if [[ "${{ env.ruff_lint_failed }}" != "0" ]]; then + failed+=("ruff_lint") + else + passed+=("ruff_lint") + fi + + if [[ "${{ env.ruff_format_failed }}" != "0" ]]; then + failed+=("ruff_format") + else + passed+=("ruff_format") + fi + + if [ ${#passed[@]} -ne 0 ]; then + echo "✅ PASSED:" + for check in "${passed[@]}"; do + echo " - $check" + done + fi + + echo "" + + if [ ${#failed[@]} -ne 0 ]; then + echo "❌ FAILED:" + for check in "${failed[@]}"; do + echo " - $check" + done + else + echo "🚀🚀 ALL TESTS PASSED 🚀🚀" + fi + + echo "" + echo "Click above jobs for details on each success/failure" + + if [ ${#failed[@]} -ne 0 ]; then + exit 1 + fi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7d34f72b..85c036e8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,15 +1,13 @@ repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.0 + hooks: + - id: ruff + args: [ --fix ] + - id: ruff-format + - repo: local hooks: - - id: isort - name: isort - entry: isort - language: system - - id: black - name: black - entry: black - language: system - types_or: [python, pyi] - id: pytest name: pytest entry: pytest diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 0dc3d143..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - // NOTE: I (@ekiefl) don't use VSCode, I think these settings should help you conform to the - // code standards in this repo - "python.analysis.typeCheckingMode": "basic", - "python.analysis.extraPaths": [ - "pooltool" - ], - "[python]": { - "editor.defaultFormatter": "ms-python.black-formatter" - }, - "isort.args": [ - "--profile", - "black" - ] -} diff --git a/docs/conf.py b/docs/conf.py index 57dca0a7..b59eb747 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,7 +12,8 @@ import os import sys -sys.path.insert(0, os.path.abspath('../')) + +sys.path.insert(0, os.path.abspath("../")) # -- Project information ----------------------------------------------------- @@ -46,7 +47,7 @@ # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. # NOTE: Don't use this for excluding python files, use `autoapi_ignore` below -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Global options ---------------------------------------------------------- @@ -59,7 +60,7 @@ # a list of builtin themes. # html_theme = "furo" -html_logo = '../pooltool/logo/logo_small.png' +html_logo = "../pooltool/logo/logo_small.png" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -85,14 +86,14 @@ } # -- copybutton options -copybutton_exclude = '.linenos, .gp, .go' +copybutton_exclude = ".linenos, .gp, .go" # -- myst options myst_enable_extensions = ["colon_fence"] # -- autoapi configuration --------------------------------------------------- -#autodoc_typehints = "signature" # autoapi respects this +# autodoc_typehints = "signature" # autoapi respects this autodoc_typehints = "both" # autoapi respects this autodoc_typehints_description_target = "documented_params" # autoapi respects this autodoc_class_signature = "mixed" @@ -110,29 +111,30 @@ autoapi_keep_files = True autoapi_ignore = [ - '*/test_*.py', + "*/test_*.py", "*/render.py", "*/ai/*", "*/user_config.py", ] # Everything in ani/ except animate.py -autoapi_ignore.extend([ - "*/ani/camera/*", - "*/ani/fonts/*", - "*/ani/image/*", - "*/ani/modes/*", - "*/ani/__init__.py", - "*/ani/action.py", - "*/ani/collision.py", - "*/ani/environment.py", - "*/ani/globals.py", - "*/ani/hud.py", - "*/ani/menu.py", - "*/ani/mouse.py", - "*/ani/tasks.py", - "*/ani/utils.py", -]) - +autoapi_ignore.extend( + [ + "*/ani/camera/*", + "*/ani/fonts/*", + "*/ani/image/*", + "*/ani/modes/*", + "*/ani/__init__.py", + "*/ani/action.py", + "*/ani/collision.py", + "*/ani/environment.py", + "*/ani/globals.py", + "*/ani/hud.py", + "*/ani/menu.py", + "*/ani/mouse.py", + "*/ani/tasks.py", + "*/ani/utils.py", + ] +) # -- custom auto_summary() macro --------------------------------------------- diff --git a/docs/custom_directives.py b/docs/custom_directives.py index 531dde21..c20a941d 100644 --- a/docs/custom_directives.py +++ b/docs/custom_directives.py @@ -1,6 +1,7 @@ from docutils import nodes from sphinx.util.docutils import SphinxDirective + class CachedPropertyDirective(SphinxDirective): required_arguments = 0 optional_arguments = 0 @@ -9,7 +10,7 @@ class CachedPropertyDirective(SphinxDirective): def run(self): targetid = f"cached-property-{self.env.new_serialno('cached-property')}" - targetnode = nodes.target('', '', ids=[targetid]) + targetnode = nodes.target("", "", ids=[targetid]) # Create an admonition node to hold the content admonition_node = nodes.admonition() @@ -23,12 +24,14 @@ def run(self): post_text = ", and should be accessed as an attribute, not as a method call." # Creating the hyperlink - uri = "https://docs.python.org/3/library/functools.html#functools.cached_property" + uri = ( + "https://docs.python.org/3/library/functools.html#functools.cached_property" + ) link_text = "cached property" - hyperlink = nodes.reference('', '', nodes.Text(link_text), refuri=uri) + hyperlink = nodes.reference("", "", nodes.Text(link_text), refuri=uri) # Creating the paragraph and adding the intro text and hyperlink - para = nodes.paragraph('', '') + para = nodes.paragraph("", "") para += nodes.Text(pre_text, pre_text) para += hyperlink para += nodes.Text(post_text, post_text) diff --git a/docs/custom_extensions.py b/docs/custom_extensions.py index 44fd9a4c..3e9c039e 100644 --- a/docs/custom_extensions.py +++ b/docs/custom_extensions.py @@ -1,7 +1,9 @@ -import attr import re + +import attr from sphinx.application import Sphinx + def extract_qualified_name(obj): repr_str = repr(obj) class_or_instance_method_pattern = r" float: - ... +def at_pos(system: System, pos: NDArray[np.float64]) -> float: ... @overload -def at_pos(cue_ball: Ball, pos: NDArray[np.float64]) -> float: - ... +def at_pos(cue_ball: Ball, pos: NDArray[np.float64]) -> float: ... def at_pos(*args) -> float: # type: ignore @@ -38,13 +36,11 @@ def _at_pos(cue_ball: Ball, pos: NDArray[np.float64]) -> float: @overload -def at_ball(system: System, ball_id: str, *, cut: float = 0.0) -> float: - ... +def at_ball(system: System, ball_id: str, *, cut: float = 0.0) -> float: ... @overload -def at_ball(cue_ball: Ball, object_ball: Ball, *, cut: float = 0.0) -> float: - ... +def at_ball(cue_ball: Ball, object_ball: Ball, *, cut: float = 0.0) -> float: ... def at_ball(*args, **kwargs) -> float: # type: ignore @@ -79,7 +75,7 @@ def _at_ball(cue_ball: Ball, object_ball: Ball, cut: float = 0.0) -> float: d = ptmath.norm3d(object_ball.state.rvw[0] - cue_ball.state.rvw[0]) # If for some reason d < 2R, set d = 2R - d = max(d, 2*R) + d = max(d, 2 * R) lower_bound = 0 upper_bound = np.pi / 2 - np.arccos((2 * R) / d) diff --git a/pooltool/ani/animate.py b/pooltool/ani/animate.py index 3dc069a3..6ed7b9c9 100755 --- a/pooltool/ani/animate.py +++ b/pooltool/ani/animate.py @@ -5,7 +5,6 @@ from functools import partial from typing import Generator, Optional, Tuple, Union -import gltf # FIXME at first glance this does nothing? import simplepbr from attrs import define from direct.showbase.ShowBase import ShowBase @@ -401,7 +400,7 @@ def show( # This line takes a view seconds to execute. It will generate a visible # window. Once the window has been generated, script execution continues - gui = pt.ShotViewer() + gui = pt.ShotViewer() # When this line is called, the window is populated with an animated # scene of the shot. diff --git a/pooltool/ani/camera/_camera.py b/pooltool/ani/camera/_camera.py index c3a20a0a..6c775e81 100644 --- a/pooltool/ani/camera/_camera.py +++ b/pooltool/ani/camera/_camera.py @@ -218,12 +218,12 @@ def from_camera(cls, camera: Camera) -> CameraState: return cls( cam_hpr=_vec_to_tuple(camera.node.getHpr()), cam_pos=_vec_to_tuple(camera.node.getPos()), - fixation_hpr=_vec_to_tuple(camera.fixation.getHpr()) - if camera.fixated - else None, - fixation_pos=_vec_to_tuple(camera.fixation.getPos()) - if camera.fixated - else None, + fixation_hpr=( + _vec_to_tuple(camera.fixation.getHpr()) if camera.fixated else None + ), + fixation_pos=( + _vec_to_tuple(camera.fixation.getPos()) if camera.fixated else None + ), ) @classmethod diff --git a/pooltool/ani/collision.py b/pooltool/ani/collision.py index 3a479ce5..a06bf979 100644 --- a/pooltool/ani/collision.py +++ b/pooltool/ani/collision.py @@ -125,9 +125,9 @@ def process_cushion_collision(self, entry): # Correct for cue's cylindrical radius at collision point # distance from cue tip (E) to desired collision point (D) - l = np.sqrt((Dx - Ex) ** 2 + (Dy - Ey) ** 2 + (Dz - Ez) ** 2) - cue_radius = self.get_cue_radius(l) - min_theta += np.arctan2(cue_radius, l) + ll = np.sqrt((Dx - Ex) ** 2 + (Dy - Ey) ** 2 + (Dz - Ez) ** 2) + cue_radius = self.get_cue_radius(ll) + min_theta += np.arctan2(cue_radius, ll) return max(0, min_theta) * 180 / np.pi @@ -185,7 +185,7 @@ def process_ball_collision(self, entry): min_theta = min_theta_no_english + beta return max(0, min_theta) * 180 / np.pi - def get_cue_radius(self, l): + def get_cue_radius(self, length): """Returns cue radius at collision point, given point is distance l from tip""" bounds = visual.cue.get_node("cue_stick").get_tight_bounds() @@ -197,7 +197,7 @@ def get_cue_radius(self, l): m = (R - r) / L # rise/run b = r # intercept - return m * l + b + return m * length + b def get_cushion(self, entry): expected_suffix = "cushion_cplane_" diff --git a/pooltool/ani/image/interface.py b/pooltool/ani/image/interface.py index 6d70b186..d5912f9f 100644 --- a/pooltool/ani/image/interface.py +++ b/pooltool/ani/image/interface.py @@ -15,8 +15,7 @@ class Exporter(Protocol): - def save(self, data: NDArray[np.uint8]) -> Any: - ... + def save(self, data: NDArray[np.uint8]) -> Any: ... def get_graphics_texture() -> Texture: diff --git a/pooltool/ani/modes/ball_in_hand.py b/pooltool/ani/modes/ball_in_hand.py index 553e5adc..31878c38 100644 --- a/pooltool/ani/modes/ball_in_hand.py +++ b/pooltool/ani/modes/ball_in_hand.py @@ -127,8 +127,9 @@ def try_placement(self): If no, places and returns True. If yes, returns False """ - r, pos = self.grabbed_ball._ball.params.R, np.array( - self.grab_ball_node.getPos() + r, pos = ( + self.grabbed_ball._ball.params.R, + np.array(self.grab_ball_node.getPos()), ) for ball in visual.balls.values(): diff --git a/pooltool/ani/modes/datatypes.py b/pooltool/ani/modes/datatypes.py index 39af5901..505261c9 100644 --- a/pooltool/ani/modes/datatypes.py +++ b/pooltool/ani/modes/datatypes.py @@ -52,7 +52,7 @@ def shared_task(self, task): elif self.keymap.get(Action.introspect): self.keymap[Action.introspect] = False - shot = multisystem.active + shot = multisystem.active # noqa F841 pdb.set_trace() elif self.keymap.get(Action.show_help): diff --git a/pooltool/ani/modes/game_over.py b/pooltool/ani/modes/game_over.py index 019b025d..14f7e657 100644 --- a/pooltool/ani/modes/game_over.py +++ b/pooltool/ani/modes/game_over.py @@ -37,7 +37,7 @@ def render_game_over_screen(self): if (winner := Global.game.shot_info.winner) is not None: title = f"Game over! {winner.name} wins!" else: - title = f"Game over! Tie game!" + title = "Game over! Tie game!" self.game_over_menu = GenericMenu( title=title, diff --git a/pooltool/ani/modes/stroke.py b/pooltool/ani/modes/stroke.py index 1668d731..7e3e4814 100644 --- a/pooltool/ani/modes/stroke.py +++ b/pooltool/ani/modes/stroke.py @@ -7,8 +7,6 @@ from pooltool.ani.globals import Global from pooltool.ani.modes.datatypes import BaseMode, Mode from pooltool.ani.mouse import MouseMode, mouse -from pooltool.objects.ball.datatypes import Ball -from pooltool.objects.table.components import Pocket from pooltool.system.datatypes import multisystem from pooltool.system.render import visual diff --git a/pooltool/ani/utils.py b/pooltool/ani/utils.py index ab349a49..02e9cb0c 100644 --- a/pooltool/ani/utils.py +++ b/pooltool/ani/utils.py @@ -14,7 +14,7 @@ class CustomOnscreenText(OnscreenText): def __init__(self, **kwargs): - assert "font" not in kwargs, f"Cannot modify 'font', use 'font_name' instead" + assert "font" not in kwargs, "Cannot modify 'font', use 'font_name' instead" if "font_name" in kwargs: font = load_font(kwargs["font_name"]) @@ -197,4 +197,4 @@ def alignTo(obj, other, selfPos, otherPos=None, gap=(0, 0)): CR = DGG.CR = (1, -1) # CENTER RIGHT CB = DGG.CB = (-1, 2) # CENTER BOTTOM CT = DGG.CT = (-1, 3) # CENTER TOP -O = DGG.O = 0 # ORIGIN +OO = DGG.O = 0 # ORIGIN diff --git a/pooltool/error.py b/pooltool/error.py index 73d01411..b1b279a4 100644 --- a/pooltool/error.py +++ b/pooltool/error.py @@ -3,27 +3,26 @@ """Borrowed from https://github.com/merenlab/anvio/blob/master/anvio/errors.py""" import textwrap +from typing import Optional from pooltool.terminal import color_text -def remove_spaces(text): +def remove_spaces(text: Optional[str]) -> str: if not text: return "" - while True: - if text.find(" ") > -1: - text = text.replace(" ", " ") - else: - break + while " " in text: + text = text.replace(" ", " ") return text class PoolToolError(Exception): - def __init__(self, e=None): - Exception.__init__(self) - return + def __init__(self, e: Optional[str] = None) -> None: + super().__init__() + self.e: str = e if e is not None else "" + self.error_type = "General Error" def __str__(self): max_len = max([len(line) for line in textwrap.fill(self.e, 80).split("\n")]) @@ -47,21 +46,18 @@ def clear_text(self): class ConfigError(PoolToolError): - def __init__(self, e=None): - self.e = remove_spaces(e) + def __init__(self, e: Optional[str] = None) -> None: + super().__init__(remove_spaces(e)) self.error_type = "Config Error" - PoolToolError.__init__(self) class StrokeError(PoolToolError): - def __init__(self, e=None): - self.e = remove_spaces(e) + def __init__(self, e: Optional[str] = None) -> None: + super().__init__(remove_spaces(e)) self.error_type = "Stroke Error" - PoolToolError.__init__(self) class SimulateError(PoolToolError): - def __init__(self, e=None): - self.e = remove_spaces(e) + def __init__(self, e: Optional[str] = None) -> None: + super().__init__(remove_spaces(e)) self.error_type = "Simulate Error" - PoolToolError.__init__(self) diff --git a/pooltool/events/__init__.py b/pooltool/events/__init__.py index 0e6fb906..01dd0836 100644 --- a/pooltool/events/__init__.py +++ b/pooltool/events/__init__.py @@ -1,3 +1,4 @@ +from pooltool.events.datatypes import Agent, AgentType, Event, EventType from pooltool.events.factory import ( ball_ball_collision, ball_circular_cushion_collision, @@ -10,7 +11,6 @@ spinning_stationary_transition, stick_ball_collision, ) -from pooltool.events.datatypes import Agent, AgentType, Event, EventType from pooltool.events.filter import ( by_ball, by_time, diff --git a/pooltool/events/datatypes.py b/pooltool/events/datatypes.py index 4a9f1e20..8bfb9898 100644 --- a/pooltool/events/datatypes.py +++ b/pooltool/events/datatypes.py @@ -33,17 +33,18 @@ class EventType(strenum.StrEnum): BALL_POCKET: A ball pocket "collision". This marks the point at which the ball crosses the *point of no return*. - STICK_BALL: + STICK_BALL: A cue-stick ball collision. - SPINNING_STATIONARY: + SPINNING_STATIONARY: A ball transition from spinning to stationary. - ROLLING_STATIONARY: + ROLLING_STATIONARY: A ball transition from rolling to stationary. - ROLLING_SPINNING: + ROLLING_SPINNING: A ball transition from rolling to spinning. - SLIDING_ROLLING: + SLIDING_ROLLING: A ball transition from sliding to rolling. """ + NONE = strenum.auto() BALL_BALL = strenum.auto() BALL_LINEAR_CUSHION = strenum.auto() @@ -96,6 +97,7 @@ class AgentType(strenum.StrEnum): LINEAR_CUSHION_SEGMENT: A linear cushion segment agent. CIRCULAR_CUSHION_SEGMENT: A circular cushion segment agent. """ + NULL = strenum.auto() CUE = strenum.auto() BALL = strenum.auto() @@ -130,6 +132,7 @@ class Agent: initial: The state of the agent before an event. final: The state of the agent after an event. """ + id: str agent_type: AgentType @@ -270,7 +273,6 @@ def _disambiguate_agent_structuring( ) - @define class Event: """Represents an event. @@ -297,6 +299,7 @@ class Event: time: The time at which the event occurs. """ + event_type: EventType agents: Tuple[Agent, ...] time: float diff --git a/pooltool/evolution/event_based/simulate.py b/pooltool/evolution/event_based/simulate.py index 6f25a41f..0e537ab0 100755 --- a/pooltool/evolution/event_based/simulate.py +++ b/pooltool/evolution/event_based/simulate.py @@ -158,7 +158,9 @@ def simulate( events = 0 while True: event = get_next_event( - shot, transition_cache=transition_cache, quartic_solver=quartic_solver, + shot, + transition_cache=transition_cache, + quartic_solver=quartic_solver, ) if event.time == np.inf: @@ -197,7 +199,7 @@ def _evolve(shot: System, dt: float): partial function so parameters don't continuously need to be passed """ - for ball_id, ball in shot.balls.items(): + for ball in shot.balls.values(): rvw, _ = evolve.evolve_ball_motion( state=ball.state.s, rvw=ball.state.rvw, @@ -252,6 +254,7 @@ def get_next_event( def _null(): return {"null": null_event(time=np.inf)} + @attrs.define class TransitionCache: transitions: Dict[str, Event] = attrs.field(default=_null) diff --git a/pooltool/evolution/event_based/test_simulate.py b/pooltool/evolution/event_based/test_simulate.py index 635619b0..52896bec 100644 --- a/pooltool/evolution/event_based/test_simulate.py +++ b/pooltool/evolution/event_based/test_simulate.py @@ -221,7 +221,7 @@ def test_case4(solver: quartic.QuarticSolver): angle of 45 (and not once in 4500 shots with a cut angle of 0) """ - shot = System.load(TEST_DIR / "case4.msgpack") + shot = System.load(TEST_DIR / "case4.msgpack") # noqa F841 # FIXME This will go on for a very, very, very long time. To introspect, add an # early break after 8 events. This represents one cycle of the loop diff --git a/pooltool/game/datatypes.py b/pooltool/game/datatypes.py index b2e4a498..6edb2a41 100644 --- a/pooltool/game/datatypes.py +++ b/pooltool/game/datatypes.py @@ -12,6 +12,7 @@ class GameType(StrEnum): SANDBOX: SUMTOTHREE: """ + EIGHTBALL = auto() NINEBALL = auto() THREECUSHION = auto() diff --git a/pooltool/game/layouts.py b/pooltool/game/layouts.py index d2c29719..2b522ee9 100755 --- a/pooltool/game/layouts.py +++ b/pooltool/game/layouts.py @@ -470,8 +470,7 @@ def __call__( ballset: Optional[BallSet] = None, ball_params: Optional[BallParams] = None, **kwargs: Any, - ) -> Dict[str, Ball]: - ... + ) -> Dict[str, Ball]: ... _game_rack_map: Dict[str, GetRackProtocol] = { @@ -509,7 +508,7 @@ def get_rack( allocated within a larger, virtual radius defined as ``(1 + spacing_factor) * R``, where ``R`` represents the actual radius of the ball. Within this expanded radius, the ball's position is determined randomly, allowing for a - controlled separation between each ball. The `spacing_factor` therefore + controlled separation between each ball. The ``spacing_factor`` therefore dictates the degree of this separation, with higher values resulting in greater distances between adjacent balls. Setting this to 0 is not recommended. diff --git a/pooltool/game/ruleset/__init__.py b/pooltool/game/ruleset/__init__.py index 9ec75089..a71ed713 100755 --- a/pooltool/game/ruleset/__init__.py +++ b/pooltool/game/ruleset/__init__.py @@ -6,8 +6,8 @@ from pooltool.game.ruleset.nine_ball import _NineBall from pooltool.game.ruleset.sandbox import _SandBox from pooltool.game.ruleset.snooker import _Snooker -from pooltool.game.ruleset.three_cushion import _ThreeCushion from pooltool.game.ruleset.sum_to_three import _SumToThree +from pooltool.game.ruleset.three_cushion import _ThreeCushion _ruleset_classes = { GameType.NINEBALL: _NineBall, diff --git a/pooltool/game/ruleset/datatypes.py b/pooltool/game/ruleset/datatypes.py index c3403e7f..8b5e74fc 100644 --- a/pooltool/game/ruleset/datatypes.py +++ b/pooltool/game/ruleset/datatypes.py @@ -30,11 +30,9 @@ def decide( system: System, game: Ruleset, callback: Optional[Callable[[Action], None]] = None, - ) -> Action: - ... + ) -> Action: ... - def apply(self, system: System, action: Action) -> None: - ... + def apply(self, system: System, action: Action) -> None: ... @attrs.define @@ -47,6 +45,7 @@ class Player: ai: Not implemented yet... """ + name: str ai: Optional[AIPlayer] = None @@ -174,6 +173,7 @@ class ShotInfo: The total game score (tallied after the shot). Keys are player names and values are points. """ + player: Player legal: bool reason: str @@ -256,7 +256,7 @@ def advance(self, shot: System) -> None: if (winner := self.shot_info.winner) is not None: self.log.add_msg(f"Game over! {winner.name} wins!", sentiment="good") else: - self.log.add_msg(f"Game over! Tie game!", sentiment="good") + self.log.add_msg("Game over! Tie game!", sentiment="good") return if self.shot_info.turn_over: diff --git a/pooltool/game/ruleset/sandbox.py b/pooltool/game/ruleset/sandbox.py index 1c850290..a391fcb5 100755 --- a/pooltool/game/ruleset/sandbox.py +++ b/pooltool/game/ruleset/sandbox.py @@ -16,7 +16,7 @@ class _SandBox(Ruleset): - def build_shot_info(self, _: System) -> ShotInfo: + def build_shot_info(self, shot: System) -> ShotInfo: return ShotInfo( player=self.active_player, legal=True, @@ -36,7 +36,7 @@ def initial_shot_constraints(self) -> ShotConstraints: call_shot=False, ) - def next_shot_constraints(self, _: System) -> ShotConstraints: + def next_shot_constraints(self, shot: System) -> ShotConstraints: return self.initial_shot_constraints() def respot_balls(self, shot: System): diff --git a/pooltool/game/ruleset/sum_to_three.py b/pooltool/game/ruleset/sum_to_three.py index a23639c5..06fa2568 100644 --- a/pooltool/game/ruleset/sum_to_three.py +++ b/pooltool/game/ruleset/sum_to_three.py @@ -82,7 +82,7 @@ def initial_shot_constraints(self) -> ShotConstraints: call_shot=False, ) - def next_shot_constraints(self, _: System) -> ShotConstraints: + def next_shot_constraints(self, shot: System) -> ShotConstraints: return self.shot_constraints def get_score(self, score: Counter, turn_over: bool) -> Counter: @@ -92,7 +92,7 @@ def get_score(self, score: Counter, turn_over: bool) -> Counter: score[self.active_player.name] += 1 return score - def respot_balls(self, _: System) -> None: + def respot_balls(self, shot: System) -> None: pass def process_shot(self, shot: System): diff --git a/pooltool/game/ruleset/three_cushion.py b/pooltool/game/ruleset/three_cushion.py index b4035083..f91e46bf 100644 --- a/pooltool/game/ruleset/three_cushion.py +++ b/pooltool/game/ruleset/three_cushion.py @@ -109,7 +109,7 @@ def initial_shot_constraints(self) -> ShotConstraints: call_shot=False, ) - def next_shot_constraints(self, _: System) -> ShotConstraints: + def next_shot_constraints(self, shot: System) -> ShotConstraints: assert (cueable := self.shot_constraints.cueable) is not None if self.shot_info.turn_over: @@ -130,7 +130,7 @@ def get_score(self, score: Counter, turn_over: bool) -> Counter: score[self.active_player.name] += 1 return score - def respot_balls(self, _: System) -> None: + def respot_balls(self, shot: System) -> None: pass def process_shot(self, shot: System): diff --git a/pooltool/game/test_layouts.py b/pooltool/game/test_layouts.py index e097781e..a3f14753 100644 --- a/pooltool/game/test_layouts.py +++ b/pooltool/game/test_layouts.py @@ -7,7 +7,6 @@ import numpy as np import pytest from numpy.typing import NDArray -from pooltool.objects.ball.datatypes import Ball import pooltool.ptmath as ptmath from pooltool.game.layouts import ( @@ -20,6 +19,7 @@ generate_layout, ) from pooltool.objects import BallParams, Table +from pooltool.objects.ball.datatypes import Ball def test_get_ball_ids(): diff --git a/pooltool/logo/logo.blend b/pooltool/logo/logo.blend index c0bf65b6..0e8c186b 100644 Binary files a/pooltool/logo/logo.blend and b/pooltool/logo/logo.blend differ diff --git a/pooltool/logo/logo.png b/pooltool/logo/logo.png index 713a41c0..bfbdc835 100644 Binary files a/pooltool/logo/logo.png and b/pooltool/logo/logo.png differ diff --git a/pooltool/logo/logo_small.png b/pooltool/logo/logo_small.png index 80392ff9..d800b15f 100644 Binary files a/pooltool/logo/logo_small.png and b/pooltool/logo/logo_small.png differ diff --git a/pooltool/objects/ball/datatypes.py b/pooltool/objects/ball/datatypes.py index 0b4327f7..6a77bcbf 100644 --- a/pooltool/objects/ball/datatypes.py +++ b/pooltool/objects/ball/datatypes.py @@ -97,7 +97,7 @@ class BallState: 3 = rolling 4 = pocketed t (float): - The simulated time. + The simulated time. """ rvw: NDArray[np.float64] @@ -259,7 +259,7 @@ def vectorize( def from_vectorization( vectorization: Optional[ Tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]] - ] + ], ) -> BallHistory: """Zips a vectorization into a BallHistory @@ -278,7 +278,7 @@ def from_vectorization( >>> import pooltool as pt >>> history = pt.simulate(pt.System.example(), continuous=True).balls["cue"].history_cts - Illustrate a lossless round trip: + Illustrate a lossless round trip: >>> pt.BallHistory.from_vectorization(history.vectorize()) == history True @@ -322,10 +322,10 @@ class Ball: Attributes: id: An ID for the ball. - + Use strings (e.g. "1" not 1). state: - The ball's state. + The ball's state. This is the current state of the ball. diff --git a/pooltool/objects/ball/params.py b/pooltool/objects/ball/params.py index f61798fa..58ea8ee4 100644 --- a/pooltool/objects/ball/params.py +++ b/pooltool/objects/ball/params.py @@ -121,6 +121,7 @@ def prebuilt(cls, name: PrebuiltBallParams) -> BallParams: class PrebuiltBallParams(StrEnum): """An Enum specifying prebuilt ball parameters""" + POOL_GENERIC = auto() SNOOKER_GENERIC = auto() BILLIARD_GENERIC = auto() diff --git a/pooltool/objects/ball/sets.py b/pooltool/objects/ball/sets.py index b83ce667..47e682bc 100644 --- a/pooltool/objects/ball/sets.py +++ b/pooltool/objects/ball/sets.py @@ -28,7 +28,7 @@ ``generic_snooker`` ballset matches the red ball IDs to the same model ID: .. code:: - + $ cat $(python -c "import pooltool; print(pooltool.__file__[:-12])")/models/balls/generic_snooker/conversion.json { "red_01": "red", @@ -72,8 +72,9 @@ class BallSet: The name of the ballset. During instantiation, the validity of this name will be checked, and a - ValueError will be raised if the ballset doesn't exist. + ValueError will be raised if the ballset doesn't exist. """ + name: str = attrs.field() @name.validator # type: ignore diff --git a/pooltool/objects/ball/test_datatypes.py b/pooltool/objects/ball/test_datatypes.py index 75731c1a..fabe4742 100644 --- a/pooltool/objects/ball/test_datatypes.py +++ b/pooltool/objects/ball/test_datatypes.py @@ -11,7 +11,7 @@ BallState, _null_rvw, ) -from pooltool.objects.ball.sets import BallSet, get_ballset +from pooltool.objects.ball.sets import get_ballset def test__null_rvw(): diff --git a/pooltool/objects/cue/datatypes.py b/pooltool/objects/cue/datatypes.py index 3f7371ae..b8eb7797 100755 --- a/pooltool/objects/cue/datatypes.py +++ b/pooltool/objects/cue/datatypes.py @@ -55,7 +55,7 @@ class Cue: Units are *m/s*. Warning: This is the speed of the cue stick upon impact, not the speed of the - ball upon impact. + ball upon impact. phi: The directional strike angle. @@ -69,7 +69,7 @@ class Cue: - :math:`\\phi = 90` corresponds to striking the cue ball towards the foot rail - :math:`\\phi = 180` corresponds to striking the cue ball to the left - :math:`\\phi = 270` corresponds to striking the cue ball towards the head rail - - :math:`\\phi = 360` corresponds to striking the cue ball to the right + - :math:`\\phi = 360` corresponds to striking the cue ball to the right theta: The cue inclination angle. @@ -79,17 +79,17 @@ class Cue: - :math:`\\theta = 0` corresponds to striking the cue ball parallel with the table (no massé) - :math:`\\theta = 90` corresponds to striking the cue ball downwards into the - table (max massé) + table (max massé) a: The amount and direction of side spin. - :math:`a = -1` is the rightmost side of ball - - :math:`a = +1` is the leftmost side of the ball + - :math:`a = +1` is the leftmost side of the ball b: The amount of top/bottom spin. - :math:`b = -1` is the bottom-most side of the ball - - :math:`b = +1` is the top-most side of the ball + - :math:`b = +1` is the top-most side of the ball cue_ball_id: The ball ID of the ball being cued. specs: diff --git a/pooltool/objects/cue/test_datatypes.py b/pooltool/objects/cue/test_datatypes.py index cc8e1fcf..a52eebe4 100644 --- a/pooltool/objects/cue/test_datatypes.py +++ b/pooltool/objects/cue/test_datatypes.py @@ -14,7 +14,7 @@ def test_cue_copy(): # The specs are the same object but thats ok because `specs` is frozen assert cue.specs is copy.specs with pytest.raises(FrozenInstanceError): - cue.specs.brand = "brunswick" + cue.specs.brand = "brunswick" # type: ignore # modifying cue doesn't affect copy cue.phi += 1 diff --git a/pooltool/objects/datatypes.py b/pooltool/objects/datatypes.py index 9576a212..57d0644d 100644 --- a/pooltool/objects/datatypes.py +++ b/pooltool/objects/datatypes.py @@ -12,6 +12,7 @@ class NullObject: Attributes: id: Object ID. """ + id: str = field(default="dummy") def copy(self) -> NullObject: diff --git a/pooltool/objects/table/components.py b/pooltool/objects/table/components.py index 997d0437..b81b8b54 100644 --- a/pooltool/objects/table/components.py +++ b/pooltool/objects/table/components.py @@ -56,15 +56,15 @@ class LinearCushionSegment: The 3D coordinate where the cushion segment starts. Note: - - p1 and p2 must share the same height (``p1[2] == p2[2]``). + - p1 and p2 must share the same height (``p1[2] == p2[2]``). p2: The 3D coordinate where the cushion segment ends. Note: - - p1 and p2 must share the same height (``p1[2] == p2[2]``). + - p1 and p2 must share the same height (``p1[2] == p2[2]``). direction: The cushion direction (*default* = :attr:`CushionDirection.BOTH`). - + See :class:`CushionDirection` for explanation. """ @@ -207,7 +207,7 @@ class CircularCushionSegment: ``center[0]``, ``center[1]``, and ``center[2]`` are the x-, y-, and z-coordinates of the cushion's center. The circle is assumed to be parallel to - the XY plane, which makes ``center[2]`` is the height of the cushion. + the XY plane, which makes ``center[2]`` is the height of the cushion. radius: The radius of the cushion segment. """ @@ -296,13 +296,13 @@ class CushionSegments: Warning: Keys must match the value IDs, *e.g.* ``{"2": - LinearCushionSegment(id="2", ...)}`` + LinearCushionSegment(id="2", ...)}`` circular: A dictionary of circular cushion segments. Warning: Keys must match the value IDs, *e.g.* ``{"2t": - CircularCushionSegment(id="2t", ...)}`` + CircularCushionSegment(id="2t", ...)}`` """ linear: Dict[str, LinearCushionSegment] = field() @@ -339,7 +339,7 @@ class Pocket: - ``center[0]`` is the x-coordinate of the pocket's center - ``center[1]`` is the y-coordinate of the pocket's center - - ``center[2]`` must be 0.0 + - ``center[2]`` must be 0.0 radius: The radius of the pocket. depth: diff --git a/pooltool/objects/table/datatypes.py b/pooltool/objects/table/datatypes.py index bddb00fa..6b3fbaf3 100644 --- a/pooltool/objects/table/datatypes.py +++ b/pooltool/objects/table/datatypes.py @@ -77,7 +77,7 @@ def w(self) -> float: return x2 - x1 @property - def l(self) -> float: + def l(self) -> float: # noqa F743 """The length of the table. Warning: @@ -183,7 +183,7 @@ def default(cls, table_type: TableType = TableType.POCKET) -> Table: Args: table_type: The type of table. - + Returns: Table: The default table for the given table type. @@ -199,7 +199,7 @@ def from_game_type(cls, game_type: GameType) -> Table: Args: game_type: The game type. - + Returns: Table: The default table for the given game type. diff --git a/pooltool/objects/table/layout.py b/pooltool/objects/table/layout.py index 39a62fdd..ec8aa949 100644 --- a/pooltool/objects/table/layout.py +++ b/pooltool/objects/table/layout.py @@ -60,7 +60,7 @@ def create_billiard_table_cushion_segments( def create_pocket_table_cushion_segments( - specs: Union[PocketTableSpecs, SnookerTableSpecs] + specs: Union[PocketTableSpecs, SnookerTableSpecs], ) -> CushionSegments: # https://ekiefl.github.io/2020/12/20/pooltool-alg/#ball-cushion-collision-times # for diagram @@ -295,7 +295,7 @@ def create_pocket_table_cushion_segments( def create_pocket_table_pockets( - specs: Union[PocketTableSpecs, SnookerTableSpecs] + specs: Union[PocketTableSpecs, SnookerTableSpecs], ) -> Dict[str, Pocket]: cr = specs.corner_pocket_radius sr = specs.side_pocket_radius diff --git a/pooltool/objects/table/specs.py b/pooltool/objects/table/specs.py index 4f4c12d3..09343c6b 100755 --- a/pooltool/objects/table/specs.py +++ b/pooltool/objects/table/specs.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import List, Protocol, Union +from typing import Protocol from attrs import define, field @@ -57,6 +57,7 @@ def null() -> TableModelDescr: class TableType(strenum.StrEnum): """An Enum describing the table type""" + POCKET = strenum.auto() BILLIARD = strenum.auto() SNOOKER = strenum.auto() @@ -65,20 +66,16 @@ class TableType(strenum.StrEnum): class TableSpecs(Protocol): @property - def table_type(self) -> TableType: - ... + def table_type(self) -> TableType: ... @property - def height(self) -> float: - ... + def height(self) -> float: ... @property - def lights_height(self) -> float: - ... + def lights_height(self) -> float: ... @property - def model_descr(self) -> TableModelDescr: - ... + def model_descr(self) -> TableModelDescr: ... @define(frozen=True) @@ -93,7 +90,7 @@ class PocketTableSpecs: """ # 7-foot table (78x39 in^2 playing surface) - l: float = field(default=1.9812) + l: float = field(default=1.9812) # noqa E741 w: float = field(default=1.9812 / 2) cushion_width: float = field(default=2 * 2.54 / 100) @@ -129,7 +126,7 @@ class BilliardTableSpecs: """ # 10-foot table (imprecise) - l: float = field(default=3.05) + l: float = field(default=3.05) # noqa E741 w: float = field(default=3.05 / 2) # FIXME height should be adjusted for 3-cushion sized balls @@ -160,10 +157,11 @@ class SnookerTableSpecs: in the future, snooker tables may have some parameters distinct from standard pool tables (*e.g.* directional cloth), causing these classes to diverge. """ + # https://wpbsa.com/rules/ # The playing area is within the cushion faces and shall measure # 11 ft 8½ in x 5 ft 10 in (3569 mm x 1778 mm) with a tolerance on both dimensions of +/- ½ in (13 mm). - l: float = field(default=3.5445) + l: float = field(default=3.5445) # noqa E741 w: float = field(default=1.7465) cushion_width: float = field(default=1.55 * 25.4 / 1000) diff --git a/pooltool/objects/table/test_components.py b/pooltool/objects/table/test_components.py index 7374bacb..a587cb06 100644 --- a/pooltool/objects/table/test_components.py +++ b/pooltool/objects/table/test_components.py @@ -161,7 +161,7 @@ def test_pocket_copy(pocket): assert pocket == copy # center is read only, so its safe that they share the same reference - pocket.center is copy.center + pocket.center is copy.center # type: ignore # contains is mutable, so separate objects is necessary assert pocket.contains == copy.contains @@ -192,20 +192,12 @@ def test_cushion_segments_copy(lin_seg, circ_seg): def test_cushion_segments_id_clash(lin_seg, circ_seg): # No problem - CushionSegments( - linear={lin_seg.id: lin_seg}, circular={circ_seg.id: circ_seg} - ) + CushionSegments(linear={lin_seg.id: lin_seg}, circular={circ_seg.id: circ_seg}) # Keys don't match value IDs with pytest.raises(AssertionError): - CushionSegments( - linear={"wrong": lin_seg}, circular={circ_seg.id: circ_seg} - ) + CushionSegments(linear={"wrong": lin_seg}, circular={circ_seg.id: circ_seg}) with pytest.raises(AssertionError): - CushionSegments( - linear={lin_seg.id: lin_seg}, circular={"wrong": circ_seg} - ) + CushionSegments(linear={lin_seg.id: lin_seg}, circular={"wrong": circ_seg}) with pytest.raises(AssertionError): - CushionSegments( - linear={":(": lin_seg}, circular={"wrong": circ_seg} - ) + CushionSegments(linear={":(": lin_seg}, circular={"wrong": circ_seg}) diff --git a/pooltool/physics/engine.py b/pooltool/physics/engine.py index 64a51d6c..6cb702d7 100644 --- a/pooltool/physics/engine.py +++ b/pooltool/physics/engine.py @@ -18,4 +18,5 @@ class PhysicsEngine: resolver: The physics engine responsible for resolving events. """ + resolver: Resolver = attrs.field(factory=Resolver.default) diff --git a/pooltool/physics/evolve/__init__.py b/pooltool/physics/evolve/__init__.py index 788cdce4..f8a444d7 100644 --- a/pooltool/physics/evolve/__init__.py +++ b/pooltool/physics/evolve/__init__.py @@ -59,6 +59,8 @@ def evolve_ball_motion(state, rvw, R, m, u_s, u_sp, u_r, g, t): else: return evolve_perpendicular_spin_state(rvw, R, u_sp, g, t), const.spinning + raise ValueError + @jit(nopython=True, cache=const.use_numba_cache) def evolve_slide_state(rvw, R, m, u_s, u_sp, g, t): diff --git a/pooltool/physics/resolve/__init__.py b/pooltool/physics/resolve/__init__.py index 51ff47c1..c968d496 100644 --- a/pooltool/physics/resolve/__init__.py +++ b/pooltool/physics/resolve/__init__.py @@ -1 +1,5 @@ from pooltool.physics.resolve.resolver import Resolver + +__all__ = [ + "Resolver", +] diff --git a/pooltool/physics/resolve/ball_ball/core.py b/pooltool/physics/resolve/ball_ball/core.py index 0c6350ad..8bd4c1c6 100644 --- a/pooltool/physics/resolve/ball_ball/core.py +++ b/pooltool/physics/resolve/ball_ball/core.py @@ -7,18 +7,15 @@ class _BaseStrategy(Protocol): - def make_kiss(self, ball1: Ball, ball2: Ball) -> Tuple[Ball, Ball]: - ... + def make_kiss(self, ball1: Ball, ball2: Ball) -> Tuple[Ball, Ball]: ... def resolve( self, ball1: Ball, ball2: Ball, inplace: bool = False - ) -> Tuple[Ball, Ball]: - ... + ) -> Tuple[Ball, Ball]: ... class BallBallCollisionStrategy(_BaseStrategy, Protocol): - def solve(self, ball1: Ball, ball2: Ball) -> Tuple[Ball, Ball]: - ... + def solve(self, ball1: Ball, ball2: Ball) -> Tuple[Ball, Ball]: ... class CoreBallBallCollision(ABC): diff --git a/pooltool/physics/resolve/ball_cushion/core.py b/pooltool/physics/resolve/ball_cushion/core.py index 6a983636..a3f86a97 100644 --- a/pooltool/physics/resolve/ball_cushion/core.py +++ b/pooltool/physics/resolve/ball_cushion/core.py @@ -13,37 +13,31 @@ class _BaseLinearStrategy(Protocol): - def make_kiss(self, ball: Ball, cushion: LinearCushionSegment) -> Ball: - ... + def make_kiss(self, ball: Ball, cushion: LinearCushionSegment) -> Ball: ... def resolve( self, ball: Ball, cushion: LinearCushionSegment, inplace: bool = False - ) -> Tuple[Ball, LinearCushionSegment]: - ... + ) -> Tuple[Ball, LinearCushionSegment]: ... class _BaseCircularStrategy(Protocol): - def make_kiss(self, ball: Ball, cushion: CircularCushionSegment) -> Ball: - ... + def make_kiss(self, ball: Ball, cushion: CircularCushionSegment) -> Ball: ... def resolve( self, ball: Ball, cushion: CircularCushionSegment, inplace: bool = False - ) -> Tuple[Ball, CircularCushionSegment]: - ... + ) -> Tuple[Ball, CircularCushionSegment]: ... class BallLCushionCollisionStrategy(_BaseLinearStrategy, Protocol): def solve( self, ball: Ball, cushion: LinearCushionSegment - ) -> Tuple[Ball, LinearCushionSegment]: - ... + ) -> Tuple[Ball, LinearCushionSegment]: ... class BallCCushionCollisionStrategy(_BaseCircularStrategy, Protocol): def solve( self, ball: Ball, cushion: CircularCushionSegment - ) -> Tuple[Ball, CircularCushionSegment]: - ... + ) -> Tuple[Ball, CircularCushionSegment]: ... class CoreBallLCushionCollision(ABC): diff --git a/pooltool/physics/resolve/ball_cushion/han_2005/__init__.py b/pooltool/physics/resolve/ball_cushion/han_2005/__init__.py index fb3903ab..e9ac55e8 100644 --- a/pooltool/physics/resolve/ball_cushion/han_2005/__init__.py +++ b/pooltool/physics/resolve/ball_cushion/han_2005/__init__.py @@ -3,3 +3,9 @@ Han2005Linear, han2005, ) + +__all__ = [ + "Han2005Circular", + "Han2005Linear", + "han2005", +] diff --git a/pooltool/physics/resolve/ball_cushion/han_2005/model.py b/pooltool/physics/resolve/ball_cushion/han_2005/model.py index 4021ace0..51cb3518 100644 --- a/pooltool/physics/resolve/ball_cushion/han_2005/model.py +++ b/pooltool/physics/resolve/ball_cushion/han_2005/model.py @@ -49,7 +49,7 @@ def han2005(rvw, normal, R, m, h, e_c, f_c): c = rvw_R[1, 0] * np.cos(theta_a) # 2D assumption # Eqs 16 - I = 2 / 5 * m * R**2 + II = 2 / 5 * m * R**2 A = 7 / 2 / m B = 1 / m @@ -78,9 +78,9 @@ def han2005(rvw, normal, R, m, h, e_c, f_c): # rvw_R[1,2] += PZ/m # Update angular velocity - rvw_R[2, 0] += -R / I * PY * np.sin(theta_a) - rvw_R[2, 1] += R / I * (PX * np.sin(theta_a) - PZ * np.cos(theta_a)) - rvw_R[2, 2] += R / I * PY * np.cos(theta_a) + rvw_R[2, 0] += -R / II * PY * np.sin(theta_a) + rvw_R[2, 1] += R / II * (PX * np.sin(theta_a) - PZ * np.cos(theta_a)) + rvw_R[2, 2] += R / II * PY * np.cos(theta_a) # Change back to table reference frame rvw = ptmath.coordinate_rotation(rvw_R.T, psi).T diff --git a/pooltool/physics/resolve/ball_pocket/__init__.py b/pooltool/physics/resolve/ball_pocket/__init__.py index 11b54993..de04f96a 100644 --- a/pooltool/physics/resolve/ball_pocket/__init__.py +++ b/pooltool/physics/resolve/ball_pocket/__init__.py @@ -4,6 +4,7 @@ collisions, expand this file into a file structure modelled after ../ball_ball or ../ball_cushion """ + from typing import Dict, Optional, Protocol, Tuple, Type import numpy as np @@ -18,8 +19,7 @@ class BallPocketStrategy(Protocol): def resolve( self, ball: Ball, pocket: Pocket, inplace: bool = False - ) -> Tuple[Ball, Pocket]: - ... + ) -> Tuple[Ball, Pocket]: ... class CanonicalBallPocket: diff --git a/pooltool/physics/resolve/resolver.py b/pooltool/physics/resolve/resolver.py index fd77ee6d..b85dc432 100644 --- a/pooltool/physics/resolve/resolver.py +++ b/pooltool/physics/resolve/resolver.py @@ -1,4 +1,5 @@ """Resolve collisions and transitions""" + from __future__ import annotations from pathlib import Path @@ -99,6 +100,7 @@ class Resolver: For everything you need to know about this class, see :doc:`Modular Physics `. """ + ball_ball: BallBallCollisionStrategy ball_linear_cushion: BallLCushionCollisionStrategy ball_circular_cushion: BallCCushionCollisionStrategy diff --git a/pooltool/physics/resolve/stick_ball/core.py b/pooltool/physics/resolve/stick_ball/core.py index 1fdfde73..123bbf4d 100644 --- a/pooltool/physics/resolve/stick_ball/core.py +++ b/pooltool/physics/resolve/stick_ball/core.py @@ -6,13 +6,13 @@ class _BaseStrategy(Protocol): - def resolve(self, cue: Cue, ball: Ball, inplace: bool = False) -> Tuple[Cue, Ball]: - ... + def resolve( + self, cue: Cue, ball: Ball, inplace: bool = False + ) -> Tuple[Cue, Ball]: ... class StickBallCollisionStrategy(_BaseStrategy, Protocol): - def solve(self, cue: Cue, ball: Ball) -> Tuple[Cue, Ball]: - ... + def solve(self, cue: Cue, ball: Ball) -> Tuple[Cue, Ball]: ... class CoreStickBallCollision(ABC): diff --git a/pooltool/physics/resolve/stick_ball/instantaneous_point/__init__.py b/pooltool/physics/resolve/stick_ball/instantaneous_point/__init__.py index 873d05de..7d8b644f 100644 --- a/pooltool/physics/resolve/stick_ball/instantaneous_point/__init__.py +++ b/pooltool/physics/resolve/stick_ball/instantaneous_point/__init__.py @@ -70,7 +70,7 @@ def cue_strike(m, M, R, V0, phi, theta, a, b, throttle_english: bool): phi *= np.pi / 180 theta *= np.pi / 180 - I = 2 / 5 * m * R**2 + II = 2 / 5 * m * R**2 c = np.sqrt(R**2 - a**2 - b**2) @@ -96,7 +96,7 @@ def cue_strike(m, M, R, V0, phi, theta, a, b, throttle_english: bool): vec_z = -a * np.cos(theta) vec = np.array([vec_x, vec_y, vec_z]) - w_B = F / I * vec + w_B = F / II * vec # Rotate to table reference rot_angle = phi + np.pi / 2 diff --git a/pooltool/physics/resolve/transition/__init__.py b/pooltool/physics/resolve/transition/__init__.py index 85688db4..3a145cc3 100644 --- a/pooltool/physics/resolve/transition/__init__.py +++ b/pooltool/physics/resolve/transition/__init__.py @@ -4,6 +4,7 @@ transitions, expand this file into a file structure modelled after ../ball_ball or ../ball_cushion """ + from typing import Dict, Optional, Protocol, Tuple, Type import numpy as np @@ -16,8 +17,9 @@ class BallTransitionStrategy(Protocol): - def resolve(self, ball: Ball, transition: EventType, inplace: bool = False) -> Ball: - ... + def resolve( + self, ball: Ball, transition: EventType, inplace: bool = False + ) -> Ball: ... class CanonicalTransition: diff --git a/pooltool/physics/utils.py b/pooltool/physics/utils.py index 9f382302..2d2b4ca4 100644 --- a/pooltool/physics/utils.py +++ b/pooltool/physics/utils.py @@ -64,8 +64,7 @@ def get_ball_energy(rvw, R, m): LKE = m * ptmath.norm3d(rvw[1]) ** 2 / 2 # Rotational - I = 2 / 5 * m * R**2 - RKE = I * ptmath.norm3d(rvw[2]) ** 2 / 2 + RKE = (2 / 5 * m * R**2) * ptmath.norm3d(rvw[2]) ** 2 / 2 return LKE + RKE diff --git a/pooltool/ptmath/__init__.py b/pooltool/ptmath/__init__.py index 2e6f032d..88226661 100644 --- a/pooltool/ptmath/__init__.py +++ b/pooltool/ptmath/__init__.py @@ -1,3 +1,4 @@ +import pooltool.ptmath.roots as roots from pooltool.ptmath._ptmath import ( angle, angle_between_vectors, @@ -8,6 +9,7 @@ find_intersection_2D, norm2d, norm3d, + orientation, point_on_line_closest_to_point, solve_transcendental, unit_vector, @@ -16,7 +18,9 @@ ) __all__ = [ + "roots", "angle", + "orientation", "angle_between_vectors", "coordinate_rotation", "cross", diff --git a/pooltool/ptmath/_ptmath.py b/pooltool/ptmath/_ptmath.py index e8f53b9f..7fe20188 100644 --- a/pooltool/ptmath/_ptmath.py +++ b/pooltool/ptmath/_ptmath.py @@ -50,6 +50,33 @@ def solve_transcendental(f, a, b, tol=1e-5, max_iter=100) -> float: return c +@jit(nopython=True, cache=const.use_numba_cache) +def orientation(p, q, r): + """Find the orientation of an ordered triplet (p, q, r) + + See https://www.geeksforgeeks.org/orientation-3-ordered-points/amp/ + + Notes + ===== + - 3D points may be passed but only the x and y components are used + + Returns + ======= + output : int + 0 : Collinear points, 1 : Clockwise points, 2 : Counterclockwise + """ + val = ((q[1] - p[1]) * (r[0] - q[0])) - ((q[0] - p[0]) * (r[1] - q[1])) + if val > 0: + # Clockwise orientation + return 1 + elif val < 0: + # Counterclockwise orientation + return 2 + else: + # Collinear orientation + return 0 + + def convert_2D_to_3D(array: NDArray) -> NDArray: """Convert a 2D vector to a 3D vector, setting z=0""" return np.pad(array, (0, 1), "constant", constant_values=(0,)) # type: ignore diff --git a/pooltool/ptmath/roots/__init__.py b/pooltool/ptmath/roots/__init__.py index c98c12b8..725b35bd 100644 --- a/pooltool/ptmath/roots/__init__.py +++ b/pooltool/ptmath/roots/__init__.py @@ -2,3 +2,10 @@ import pooltool.ptmath.roots.quartic as quartic from pooltool.ptmath.roots.core import min_real_root from pooltool.ptmath.roots.quartic import minimum_quartic_root + +__all__ = [ + "quadratic", + "quartic", + "min_real_root", + "minimum_quartic_root", +] diff --git a/pooltool/ptmath/roots/quartic.py b/pooltool/ptmath/roots/quartic.py index 1c00d0a1..2bbe0eba 100644 --- a/pooltool/ptmath/roots/quartic.py +++ b/pooltool/ptmath/roots/quartic.py @@ -63,22 +63,21 @@ def solve_many_numerical(p): the columns are in the order a, b, c, d, e, where these coefficients make up the polynomial equation at^4 + bt^3 + ct^2 + dt + e = 0 - Notes - ===== - - Not yet amenable to numbaization (0.56.4). Problem is the numba implementation of - np.linalg.eigvals, which only supports 2D arrays, but the strategy here is to pass - np.lingalg.eigvals as a vectorized 3D array. Nevertheless, here is a numba - implementation that is just slightly slower (7% slower) than this function: - - n = p.shape[-1] - A = np.zeros(p.shape[:1] + (n - 1, n - 1), dtype=np.complex128) - A[:, 1:, :-1] = np.eye(n - 2) - p0 = np.copy(p[:, 0]).reshape((-1, 1)) - A[:, 0, :] = -p[:, 1:] / p0 - roots = np.zeros((p.shape[0], n - 1), dtype=np.complex128) - for i in range(p.shape[0]): - roots[i, :] = np.linalg.eigvals(A[i, :, :]) - return roots + Notes: + - Not yet amenable to numbaization (0.56.4). Problem is the numba implementation of + np.linalg.eigvals, which only supports 2D arrays, but the strategy here is to pass + np.lingalg.eigvals as a vectorized 3D array. Nevertheless, here is a numba + implementation that is just slightly slower (7% slower) than this function: + + n = p.shape[-1] + A = np.zeros(p.shape[:1] + (n - 1, n - 1), dtype=np.complex128) + A[:, 1:, :-1] = np.eye(n - 2) + p0 = np.copy(p[:, 0]).reshape((-1, 1)) + A[:, 0, :] = -p[:, 1:] / p0 + roots = np.zeros((p.shape[0], n - 1), dtype=np.complex128) + for i in range(p.shape[0]): + roots[i, :] = np.linalg.eigvals(A[i, :, :]) + return roots """ n = p.shape[-1] A = np.zeros(p.shape[:1] + (n - 1, n - 1), np.float64) @@ -294,9 +293,9 @@ def analytic(p: NDArray[np.complex128]) -> NDArray[np.complex128]: x25 = b * x0 / 4 x26 = x24 + x25 x27 = -(c**2) * x2 / 12 - x13 - x28 = ( - x12 / 16 - x14 * x5 / 6 + x6 / 216 + np.sqrt(x15**2 / 4 + x27**3 / 27) - ) ** (1 / 3) or const.EPS + x28 = (x12 / 16 - x14 * x5 / 6 + x6 / 216 + np.sqrt(x15**2 / 4 + x27**3 / 27)) ** ( + 1 / 3 + ) or const.EPS x29 = 2 * x28 x30 = 2 * x27 / (3 * x28) x31 = -x29 + x30 @@ -329,7 +328,7 @@ def analytic(p: NDArray[np.complex128]) -> NDArray[np.complex128]: def _truth(a_val, b_val, c_val, d_val, e_val, digits=50): - import sympy + import sympy # type: ignore x, a, b, c, d, e = sympy.symbols("x a b c d e") general_solution = sympy.solve(a * x**4 + b * x**3 + c * x**2 + d * x + e, x) diff --git a/pooltool/ptmath/test_ptmath.py b/pooltool/ptmath/test_ptmath.py index b3661be5..783777d0 100644 --- a/pooltool/ptmath/test_ptmath.py +++ b/pooltool/ptmath/test_ptmath.py @@ -1,9 +1,6 @@ import pytest -from pooltool.ptmath._ptmath import ( - are_points_on_same_side, - solve_transcendental, -) +from pooltool.ptmath._ptmath import are_points_on_same_side, solve_transcendental def test_are_points_on_same_side(): @@ -39,18 +36,18 @@ def test_are_points_on_same_side(): def test_transcendental_linear_equation(): - f = lambda x: x - 5 + f = lambda x: x - 5 # noqa E731 root = solve_transcendental(f, 0, 10) assert pytest.approx(root, 0.00001) == 5.0 def test_transcendental_nonlinear_equation(): - f = lambda x: x**2 - 4 * x + 3 + f = lambda x: x**2 - 4 * x + 3 # noqa E731 root = solve_transcendental(f, 0, 2.5) assert pytest.approx(root, 0.00001) == 1.0 def test_transcendental_no_root_error(): - f = lambda x: x**2 + 1 + f = lambda x: x**2 + 1 # noqa E731 with pytest.raises(ValueError): solve_transcendental(f, 0, 10) diff --git a/pooltool/serialize/__init__.py b/pooltool/serialize/__init__.py index 14d06560..b5f129f1 100644 --- a/pooltool/serialize/__init__.py +++ b/pooltool/serialize/__init__.py @@ -62,4 +62,5 @@ "from_json", "from_msgpack", "from_yaml", + "Pathish", ] diff --git a/pooltool/serialize/convert.py b/pooltool/serialize/convert.py index 919c5302..e92bedcd 100644 --- a/pooltool/serialize/convert.py +++ b/pooltool/serialize/convert.py @@ -1,13 +1,14 @@ -from typing import Dict, Iterable, TypeVar, Callable, Type, Any, Optional - from pathlib import Path +from typing import Any, Callable, Dict, Iterable, Optional, Type, TypeVar + from attrs import define from cattrs.converters import Converter + from pooltool.serialize.serializers import ( + Pathish, SerializeFormat, - serializers, deserializers, - Pathish, + serializers, ) T = TypeVar("T") diff --git a/pooltool/serialize/serializers.py b/pooltool/serialize/serializers.py index 6b05cf88..158ff57f 100644 --- a/pooltool/serialize/serializers.py +++ b/pooltool/serialize/serializers.py @@ -43,7 +43,9 @@ def from_yaml(path: Pathish) -> Any: def to_msgpack(o: Any, path: Pathish) -> None: with open(path, "wb") as fp: - fp.write(msgpack.packb(o, default=m.encode)) + packed = msgpack.packb(o, default=m.encode) + assert isinstance(packed, bytes), "msgpack.packb must return bytes" + fp.write(packed) def from_msgpack(path: Pathish) -> Any: diff --git a/pooltool/system/datatypes.py b/pooltool/system/datatypes.py index 887dbea2..273c5576 100644 --- a/pooltool/system/datatypes.py +++ b/pooltool/system/datatypes.py @@ -17,6 +17,7 @@ from pooltool.serialize import conversion from pooltool.serialize.serializers import Pathish + @define class System: """A class representing the billiards system. @@ -43,11 +44,11 @@ class System: table: A table. balls: - A dictionary of balls. + A dictionary of balls. Warning: Each key must match each value's ``id`` (`e.g.` ``{"2": Ball(id="1")}`` - is invalid). + is invalid). t: The elapsed simulation time. If the system is in the process of being simulated, ``t`` is updated to be the number of seconds the system has @@ -92,6 +93,7 @@ class System: >>> gui = pt.ShotViewer() >>> gui.show(system) """ + cue: Cue = field() table: Table = field() balls: Dict[str, Ball] = field() @@ -632,6 +634,7 @@ class MultiSystem: >>> gui = pt.ShotViewer() >>> gui.show(multisystem, title="Press 'n' for next, 'p' for previous") """ + multisystem: List[System] = field(factory=list) active_index: Optional[int] = field(default=None, init=False) diff --git a/pooltool/terminal.py b/pooltool/terminal.py index 1a3396fb..83ca8cb7 100644 --- a/pooltool/terminal.py +++ b/pooltool/terminal.py @@ -14,8 +14,7 @@ import textwrap import time from collections import OrderedDict - -import pandas as pd +from typing import Optional def get_color_objects(): @@ -37,7 +36,7 @@ def get_color_objects(): else: class NoColored(object): - def __getattr__(self, attr): + def __getattr__(self, _): return "" class Fore(NoColored): @@ -135,8 +134,8 @@ class Progress: def __init__(self, verbose=True): self.pid = None self.verbose = verbose - self.terminal_width = None self.is_tty = sys.stdout.isatty() + self.terminal_width: int self.get_terminal_width() @@ -214,39 +213,44 @@ def write(self, c, dont_update_current=False): # see a full list of color codes: https://gitlab.com/dslackw/colored if p_length >= break_point: sys.stderr.write( - back.CYAN - + fore.BLACK + getattr(back, "CYAN") + + getattr(fore, "BLACK") + c[:break_point] - + back.GREY_30 - + fore.WHITE + + getattr(back, "GREY_30") + + getattr(fore, "WHITE") + c[break_point:end_point] - + back.CYAN - + fore.CYAN + + getattr(back, "CYAN") + + getattr(fore, "CYAN") + c[end_point] - + back.GREY_50 - + fore.LIGHT_CYAN + + getattr(back, "GREY_50") + + getattr(fore, "LIGHT_CYAN") + c[end_point:] - + style.RESET + + getattr(style, "RESET") ) else: sys.stderr.write( - back.CYAN - + fore.BLACK + getattr(back, "CYAN") + + getattr(fore, "BLACK") + c[: break_point - p_length] - + back.SALMON_1 - + fore.BLACK + + getattr(back, "SALMON_1") + + getattr(fore, "BLACK") + p_text - + back.GREY_30 - + fore.WHITE + + getattr(back, "GREY_30") + + getattr(fore, "WHITE") + c[break_point:end_point] - + back.GREY_50 - + fore.LIGHT_CYAN + + getattr(back, "GREY_50") + + getattr(fore, "LIGHT_CYAN") + c[end_point:] - + style.RESET + + getattr(style, "RESET") ) sys.stderr.flush() else: - sys.stderr.write(back.CYAN + fore.BLACK + c + style.RESET) + sys.stderr.write( + getattr(back, "CYAN") + + getattr(fore, "BLACK") + + c + + getattr(style, "RESET") + ) sys.stderr.flush() def reset(self): @@ -313,7 +317,7 @@ def update(self, msg, increment=False): self.clear() self.write("\r[%s] %s" % (self.pid, msg)) - def end(self, timing_filepath=None): + def end(self): """End the current progress Parameters @@ -323,13 +327,11 @@ def end(self, timing_filepath=None): will only be made if a progress_total_items parameter was made during self.new() """ - - if timing_filepath and self.progress_total_items is not None: - self.t.gen_file_report(timing_filepath) - self.pid = None + if not self.verbose: return + self.clear() @@ -442,10 +444,10 @@ def info_single( if progress: progress.clear() - self.write(message_line, overwrite_verbose=False) + self.write(message_line, overwrite_verbose=overwrite_verbose) progress.update(progress.msg) else: - self.write(message_line, overwrite_verbose=False) + self.write(message_line, overwrite_verbose=overwrite_verbose) def warning( self, @@ -569,48 +571,6 @@ def make_checkpoint(self, checkpoint_key=None, increment_to=None): return checkpoint - def gen_report(self, title="Time Report", run=Run()): - checkpoint_last = self.initial_checkpoint_key - - run.warning("", header=title, lc="yellow", nl_before=1, nl_after=0) - - for checkpoint_key, checkpoint in self.checkpoints.items(): - if checkpoint_key == self.initial_checkpoint_key: - continue - - run.info( - str(checkpoint_key), - "+%s" - % self.timedelta_to_checkpoint( - checkpoint, checkpoint_key=checkpoint_last - ), - ) - checkpoint_last = checkpoint_key - - run.info( - "Total elapsed", - "=%s" - % self.timedelta_to_checkpoint( - checkpoint, checkpoint_key=self.initial_checkpoint_key - ), - ) - - def gen_dataframe_report(self): - """Returns a dataframe""" - - d = {"key": [], "time": [], "score": []} - for checkpoint_key, checkpoint in self.checkpoints.items(): - d["key"].append(checkpoint_key) - d["time"].append(checkpoint) - d["score"].append(self.scores[checkpoint_key]) - - return pd.DataFrame(d) - - def gen_file_report(self, filepath): - """Writes to filepath, will overwrite""" - - self.gen_dataframe_report().to_csv(filepath, sep="\t", index=False) - def calculate_time_remaining(self, infinite_default="∞:∞:∞"): if self.complete: return datetime.timedelta(seconds=0) @@ -656,7 +616,12 @@ def time_elapsed(self, fmt=None): def time_elapsed_precise(self): return self.timedelta_to_checkpoint(self.timestamp(), checkpoint_key=0) - def format_time(self, timedelta, fmt="{hours}:{minutes}:{seconds}", zero_padding=2): + def format_time( + self, + timedelta, + fmt: Optional[str] = "{hours}:{minutes}:{seconds}", + zero_padding: int = 2, + ): """Formats time Examples of `fmt`. Suppose the timedelta is seconds = 1, minutes = 1, hours = 1. @@ -676,7 +641,7 @@ def format_time(self, timedelta, fmt="{hours}:{minutes}:{seconds}", zero_padding "seconds": 1, } - if not fmt: + if fmt is None: # use the highest two non-zero units, e.g. if it is 7200s, use # {hours}h{minutes}m seconds = int(timedelta.total_seconds()) @@ -704,6 +669,8 @@ def format_time(self, timedelta, fmt="{hours}:{minutes}:{seconds}", zero_padding else: m *= unit_denominations[unit] + assert isinstance(fmt, str) + # parse units present in fmt format_order = [] for i, x in enumerate(fmt): @@ -885,6 +852,9 @@ def __exit__(self, exception_type, exception_value, traceback): return_code = 0 if exception_type is None else 1 msg, color = (self.s_msg, self.sc) if not return_code else (self.f_msg, self.fc) + + assert msg is not None + self.run.info_single( msg + str(self.time), nl_before=1, mc=color, level=return_code ) @@ -918,8 +888,8 @@ def ioctl_GWINSZ(fd): import fcntl import termios - cr = struct.unpack("hh", fcntl.ioctl(fd, termios.TIOCGWINSZ, "1234")) - except: + cr = struct.unpack("hh", fcntl.ioctl(fd, termios.TIOCGWINSZ, "1234")) # type: ignore + except Exception: return None return cr @@ -929,11 +899,11 @@ def ioctl_GWINSZ(fd): fd = os.open(os.ctermid(), os.O_RDONLY) cr = ioctl_GWINSZ(fd) os.close(fd) - except: + except Exception: pass if not cr: try: cr = (os.environ["LINES"], os.environ["COLUMNS"]) - except: + except Exception: cr = (25, 80) return int(cr[1]), int(cr[0]) diff --git a/pooltool/utils/__init__.py b/pooltool/utils/__init__.py index a8b6c4fb..fa24f30a 100644 --- a/pooltool/utils/__init__.py +++ b/pooltool/utils/__init__.py @@ -8,11 +8,8 @@ import pandas as pd import pprofile -from numba import jit from panda3d.core import Filename -import pooltool.constants as c - class classproperty(property): """Decorator for a class property @@ -25,8 +22,8 @@ class classproperty(property): >>> return cls.__name__ """ - def __get__(self, owner_self, owner_cls): - return self.fget(owner_cls) + def __get__(self, owner_self, owner_cls): # type: ignore + return self.fget(owner_cls) # type: ignore def save_pickle(x, path): @@ -226,48 +223,21 @@ def human_readable_file_size(nbytes): return "%s %s" % (f, suffixes[i]) -@jit(nopython=True, cache=c.use_numba_cache) -def orientation(p, q, r): - """Find the orientation of an ordered triplet (p, q, r) - - See https://www.geeksforgeeks.org/orientation-3-ordered-points/amp/ - - Notes - ===== - - 3D points may be passed but only the x and y components are used - - Returns - ======= - output : int - 0 : Collinear points, 1 : Clockwise points, 2 : Counterclockwise - """ - val = ((q[1] - p[1]) * (r[0] - q[0])) - ((q[0] - p[0]) * (r[1] - q[1])) - if val > 0: - # Clockwise orientation - return 1 - elif val < 0: - # Counterclockwise orientation - return 2 - else: - # Collinear orientation - return 0 - - class PProfile(pprofile.Profile): """Small wrapper for pprofile that accepts a filepath and outputs cachegrind file""" - def __init__(self, path, run=True): - self.run = run + def __init__(self, path, should_run: bool = True): + self.should_run = should_run self.path = path pprofile.Profile.__init__(self) def __enter__(self): - if self.run: - return pprofile.Profile.__enter__(self) + if self.should_run: + return super().__enter__() else: return self def __exit__(self, *args): - if self.run: + if self.should_run: pprofile.Profile.__exit__(self, *args) self.dump_stats(self.path) diff --git a/pooltool/utils/strenum.py b/pooltool/utils/strenum.py index ff95a145..7f7ccde0 100644 --- a/pooltool/utils/strenum.py +++ b/pooltool/utils/strenum.py @@ -20,8 +20,9 @@ def __new__(cls: Type[_S], *values: str) -> _S: raise TypeError("%r is not a string" % (values[0],)) if len(values) >= 2: # check that encoding argument is a string - if not isinstance(values[1], str): - raise TypeError("encoding must be a string, not %r" % (values[1],)) + value = values[1] # type: ignore + if not isinstance(value, str): + raise TypeError("encoding must be a string, not %r" % (value,)) if len(values) == 3: # check that errors argument is a string if not isinstance(values[2], str): @@ -31,7 +32,7 @@ def __new__(cls: Type[_S], *values: str) -> _S: member._value_ = value return member - __str__ = str.__str__ + __str__ = str.__str__ # type: ignore @staticmethod def _generate_next_value_( diff --git a/pyproject.toml b/pyproject.toml index 69588835..0f4c574a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,13 @@ -[tool.black] +[tool.ruff] line-length = 88 -[tool.isort] -profile = "black" +[tool.ruff.lint] +extend-select = ["I"] + +[tool.pytest.ini_options] +testpaths = [ + "pooltool", +] + +# [tool.pyright] +# pyright config is stored in pyrightconfig.json (local) and pyrightconfig.ci.json (CI) diff --git a/pyrightconfig.ci.json b/pyrightconfig.ci.json new file mode 100644 index 00000000..32a2cbf3 --- /dev/null +++ b/pyrightconfig.ci.json @@ -0,0 +1,14 @@ +{ + "include": [ + "pooltool" + ], + + "exclude": [ + "**/render.py", + "**/__pycache__" + ], + + "ignore": [ + "pooltool/ani" + ] +} diff --git a/requirements-dev.txt b/requirements-dev.txt index 17614423..68ba8a3a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,8 +3,8 @@ colored ipython pytest pre-commit -black -isort +pyright +ruff types-Pillow twine types-PyYAML diff --git a/sandbox/arena.py b/sandbox/arena.py index ede9c385..1f36611b 100644 --- a/sandbox/arena.py +++ b/sandbox/arena.py @@ -6,7 +6,7 @@ import pooltool as pt from pooltool.physics.utils import is_overlapping -get_pos = lambda table, ball: ( +get_pos = lambda table, ball: ( # noqa E731 (table.w - 2 * ball.params.R) * np.random.rand() + ball.params.R, (table.l - 2 * ball.params.R) * np.random.rand() + ball.params.R, ball.params.R, diff --git a/sandbox/custom_table.py b/sandbox/custom_table.py index 953aa0d2..a44dc469 100644 --- a/sandbox/custom_table.py +++ b/sandbox/custom_table.py @@ -2,10 +2,13 @@ """This examples how to make a custom pool table and ball parameters""" from typing import Optional -import pooltool as pt + import numpy as np + +import pooltool as pt from pooltool.ptmath import norm3d + def custom_ball_params() -> pt.BallParams: return pt.BallParams( m=0.170097, @@ -18,22 +21,23 @@ def custom_ball_params() -> pt.BallParams: g=9.81, ) + def custom_table_specs() -> pt.PocketTableSpecs: return pt.PocketTableSpecs( - l = 1.9812, - w = 1.9812 / 2, - cushion_width = 2 * 2.54 / 100, - cushion_height = 0.64 * 2 * 0.028575, - corner_pocket_width = 0.10, - corner_pocket_angle = 1, - corner_pocket_depth = 0.0398, - corner_pocket_radius = 0.124 / 2, - corner_jaw_radius = 0.08, - side_pocket_width = 0.08, - side_pocket_angle = 3, - side_pocket_depth = 0.00437, - side_pocket_radius = 0.129 / 2, - side_jaw_radius = 0.03, + l=1.9812, + w=1.9812 / 2, + cushion_width=2 * 2.54 / 100, + cushion_height=0.64 * 2 * 0.028575, + corner_pocket_width=0.10, + corner_pocket_angle=1, + corner_pocket_depth=0.0398, + corner_pocket_radius=0.124 / 2, + corner_jaw_radius=0.08, + side_pocket_width=0.08, + side_pocket_angle=3, + side_pocket_depth=0.00437, + side_pocket_radius=0.129 / 2, + side_jaw_radius=0.03, ) @@ -56,7 +60,6 @@ def closest_ball(system: pt.System) -> str: return closest_id - def main(): # Ball parameters and table specifications ball_params = custom_ball_params()