Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/beartype exceptions #38

Merged
merged 11 commits into from
May 6, 2024
17 changes: 16 additions & 1 deletion .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@ permissions:

jobs:
pylint:

runs-on: ubuntu-latest

permissions: # Job-level permissions configuration starts here
contents: write # 'write' access to repository contents
pull-requests: write # 'write' access to pull requests

steps:
- uses: actions/checkout@master
with:
persist-credentials: false # otherwise, the token used is the GITHUB_TOKEN, instead of your personal access token.
fetch-depth: 0 # otherwise, there would be errors pushing refs to the destination repository.

- name: Setup Python
uses: actions/setup-python@v2
with:
Expand Down Expand Up @@ -97,5 +106,11 @@ jobs:

git add .github/.pylint_cache
git commit -m "Add .github/.pylint_cache on push event"
git push

- name: Push changes
if: github.event_name == 'push'
uses: ad-m/github-push-action@master
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
branch: ${{ github.ref }}

20 changes: 17 additions & 3 deletions PQAnalysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,26 @@

from PQAnalysis.utils.custom_logging import CustomLogger

beartype_this_package()

__base_path__ = Path(__file__).parent

__package_name__ = __name__

##################
# BEARTYPE SETUP #
##################

# TODO: change the default level to "RELEASE" after all changes are implemented
__beartype_default_level__ = "DEBUG"
__beartype_level__ = os.getenv(
"PQANALYSIS_BEARTYPE_LEVEL", __beartype_default_level__
)

if __beartype_level__.upper() == "DEBUG":
beartype_this_package()

Check warning on line 31 in PQAnalysis/__init__.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/__init__.py#L31

Added line #L31 was not covered by tests

#################
# LOGGING SETUP #
#################

logging_env_var = os.getenv("PQANALYSIS_LOGGING_LEVEL")

if logging_env_var and logging_env_var not in logging.getLevelNamesMapping():
Expand Down
9 changes: 6 additions & 3 deletions PQAnalysis/analysis/rdf/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,36 @@
from PQAnalysis.io import TrajectoryReader, RestartFileReader, MoldescriptorReader
from PQAnalysis.traj import MDEngineFormat
from PQAnalysis.topology import Topology
from PQAnalysis.type_checking import runtime_type_checking

from .rdf import RDF
from .rdf_input_file_reader import RDFInputFileReader
from .rdf_output_file_writer import RDFDataWriter, RDFLogWriter


@runtime_type_checking
def rdf(input_file: str, md_format: MDEngineFormat | str = MDEngineFormat.PQ):
"""
Calculates the radial distribution function (RDF) using a given input file.

This is just a wrapper function combining the underlying classes and functions.

For more information on the input file keys please
For more information on the input file keys please
visit :py:mod:`~PQAnalysis.analysis.rdf.rdfInputFileReader`.
For more information on the exact calculation of
the RDF please visit :py:class:`~PQAnalysis.analysis.rdf.rdf.RDF`.

Parameters
----------
input_file : str
The input file. For more information on the input file
The input file. For more information on the input file
keys please visit :py:mod:`~PQAnalysis.analysis.rdf.rdfInputFileReader`.
md_format : MDEngineFormat | str, optional
the format of the input trajectory. Default is "PQ".
the format of the input trajectory. Default is "PQ".
For more information on the supported formats please visit
:py:class:`~PQAnalysis.traj.formats.MDEngineFormat`.
"""

md_format = MDEngineFormat(md_format)

input_reader = RDFInputFileReader(input_file)
Expand Down
92 changes: 92 additions & 0 deletions PQAnalysis/type_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
A module for type checking of arguments passed to functions at runtime.
"""

import logging

from decorator import decorator
from beartype.door import is_bearable

from PQAnalysis.utils.custom_logging import setup_logger
from .types import (
Np1DIntArray,
Np2DIntArray,
Np1DNumberArray,
Np2DNumberArray,
Np3x3NumberArray,
NpnDNumberArray,
)

__logger_name__ = "PQAnalysis.TypeChecking"

if not logging.getLogger(__logger_name__).handlers:
logger = setup_logger(logging.getLogger(__logger_name__))
else:
logger = logging.getLogger(__logger_name__)

Check warning on line 25 in PQAnalysis/type_checking.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/type_checking.py#L25

Added line #L25 was not covered by tests


@decorator
def runtime_type_checking(func, *args, **kwargs):
"""
A decorator to check the type of the arguments passed to a function at runtime.
"""

# Get the type hints of the function
type_hints = func.__annotations__

# Check the type of each argument
for arg_name, arg_value in zip(func.__code__.co_varnames, args):
if arg_name in type_hints:
if not is_bearable(arg_value, type_hints[arg_name]):
logger.error(
_get_type_error_message(
arg_name,
arg_value,
type_hints[arg_name],
),
exception=TypeError,
)

# Check the type of each keyword argument
for kwarg_name, kwarg_value in kwargs.items():
if kwarg_name in type_hints:
if not is_bearable(kwarg_value, type_hints[kwarg_name]):
logger.error(

Check warning on line 54 in PQAnalysis/type_checking.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/type_checking.py#L51-L54

Added lines #L51 - L54 were not covered by tests
_get_type_error_message(
kwarg_name,
kwarg_value,
type_hints[kwarg_name],
),
exception=TypeError,
)

# Call the function
return func(*args, **kwargs)

Check warning on line 64 in PQAnalysis/type_checking.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/type_checking.py#L64

Added line #L64 was not covered by tests


def _get_type_error_message(arg_name, value, expected_type):
"""
Get the error message for a type error.
"""

actual_type = type(value)

header = (
f"Argument '{arg_name}' with {value=} should be "
f"of type {expected_type}, but got {actual_type}."
)

if expected_type is Np1DIntArray:
header += " Expected a 1D numpy integer array."

Check warning on line 80 in PQAnalysis/type_checking.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/type_checking.py#L80

Added line #L80 was not covered by tests
elif expected_type is Np2DIntArray:
header += " Expected a 2D numpy integer array."

Check warning on line 82 in PQAnalysis/type_checking.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/type_checking.py#L82

Added line #L82 was not covered by tests
elif expected_type is Np1DNumberArray:
header += " Expected a 1D numpy number array."

Check warning on line 84 in PQAnalysis/type_checking.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/type_checking.py#L84

Added line #L84 was not covered by tests
elif expected_type is Np2DNumberArray:
header += " Expected a 2D numpy number array."

Check warning on line 86 in PQAnalysis/type_checking.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/type_checking.py#L86

Added line #L86 was not covered by tests
elif expected_type is Np3x3NumberArray:
header += " Expected a 3x3 numpy number array."

Check warning on line 88 in PQAnalysis/type_checking.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/type_checking.py#L88

Added line #L88 was not covered by tests
elif expected_type is NpnDNumberArray:
header += " Expected an n-dimensional numpy number array."

Check warning on line 90 in PQAnalysis/type_checking.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/type_checking.py#L90

Added line #L90 was not covered by tests

return header
27 changes: 18 additions & 9 deletions PQAnalysis/utils/custom_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@
)

if level in [logging.CRITICAL, logging.ERROR]:

exception = exception or Exception

if self.isEnabledFor(logging.DEBUG):
back_tb = None

try:
if exception is not None:
raise exception

raise Exception # pylint: disable=broad-exception-raised
except Exception: # pylint: disable=broad-except
raise exception # pylint: disable=broad-exception-raised
except exception: # pylint: disable=broad-except

Check warning on line 132 in PQAnalysis/utils/custom_logging.py

View check run for this annotation

Codecov / codecov/patch

PQAnalysis/utils/custom_logging.py#L131-L132

Added lines #L131 - L132 were not covered by tests
traceback = sys.exc_info()[2]
back_frame = traceback.tb_frame.f_back

Expand All @@ -140,12 +140,21 @@
tb_lineno=back_frame.f_lineno
)

if exception is not None:
raise Exception(msg).with_traceback(back_tb)

raise exception(msg).with_traceback(back_tb)

sys.exit(1)
class DevNull:
"""
Dummy class to redirect the sys.stderr to /dev/null.
"""

def write(self, _):
"""
Dummy write method.
"""

sys.stderr = DevNull()

raise exception(msg) # pylint: disable=raise-missing-from

def _original_log(self,
level: Any,
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ markers =
topology
traj
io
analysis

testpaths =
tests
Expand Down
7 changes: 7 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
Unit tests for the PQAnalysis package.
"""

import os

os.environ['PQANALYSIS_BEARTYPE_LEVEL'] = "RELEASE"
3 changes: 3 additions & 0 deletions tests/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import pytest

pytestmark = pytest.mark.analysis
Empty file added tests/analysis/rdf/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions tests/analysis/rdf/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
A module to test the RDF API.
"""

import pytest # pylint: disable=unused-import

from PQAnalysis.analysis.rdf.api import rdf
from PQAnalysis.type_checking import _get_type_error_message

from .. import pytestmark # pylint: disable=unused-import
from ...conftest import assert_logging_with_exception


class TestRDFAPI:
def test_wrong_param_types(self, caplog):
assert_logging_with_exception(
caplog=caplog,
logging_name="TypeChecking",
logging_level="ERROR",
message_to_test=_get_type_error_message(
"input_file", 1, str,
),
exception=TypeError,
function=rdf,
input_file=1,
)
4 changes: 4 additions & 0 deletions tests/analysis/rdf/test_rdfInputFileReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from PQAnalysis.analysis.rdf.rdf_input_file_reader import RDFInputFileReader
from PQAnalysis.io.input_file_reader.exceptions import InputFileError

# import topology marker
from .. import pytestmark # pylint: disable=unused-import
from ...conftest import assert_logging


class TestRDFInputFileReader:
@pytest.mark.parametrize("example_dir", ["rdf"], indirect=False)
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def assert_logging_with_exception(caplog, logging_name, logging_level, message_t
result = None
try:
result = function(*args, **kwargs)
except SystemExit:
except:
pass

record = caplog.records[0]
Expand Down
9 changes: 2 additions & 7 deletions tests/io/test_frameReader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import pytest
import numpy as np

from beartype.roar import BeartypeException

from . import pytestmark

from PQAnalysis.io import FrameReader
from PQAnalysis.io.traj_file.exceptions import FrameReaderError
from PQAnalysis.core import Cell, Atom
from PQAnalysis.traj.exceptions import TrajectoryFormatError
from PQAnalysis.traj import TrajectoryFormat
from PQAnalysis.topology import Topology

from . import pytestmark


class TestFrameReader:

Expand Down Expand Up @@ -67,9 +65,6 @@ def test__read_scalar(self):
def test_read(self):
reader = FrameReader()

with pytest.raises(BeartypeException):
reader.read(["tmp"])

frame = reader.read(
"2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no 2.0 2.0 2.0")
assert frame.n_atoms == 2
Expand Down
10 changes: 2 additions & 8 deletions tests/io/test_infoFileReader.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
import pytest

from beartype.roar import BeartypeException

from . import pytestmark

from PQAnalysis.io import InfoFileReader
from PQAnalysis.traj import MDEngineFormat
from PQAnalysis.traj.exceptions import MDEngineFormatError

from . import pytestmark


@pytest.mark.parametrize("example_dir", ["readInfoFile"], indirect=False)
def test__init__(test_with_data_dir):
with pytest.raises(FileNotFoundError) as exception:
InfoFileReader("tmp")
assert str(exception.value) == "File tmp not found."

with pytest.raises(BeartypeException) as exception:
InfoFileReader(
"md-01.info", engine_format=None)

with pytest.raises(MDEngineFormatError) as exception:
InfoFileReader(
"md-01.info", engine_format="tmp")
Expand Down
Loading
Loading