Skip to content

Commit

Permalink
Add type hints to the utilities module (#1252)
Browse files Browse the repository at this point in the history
* Add type hints to json.py

* Update MyPy ini file to include the utilities module

* Add type hints to decimal_places.py

* Add type hints to metadata_parser.py

* Add type hints to metadata_parser_v2.py

* Add type hints to metadata_validators.py

* Add type hints to request_session.py

* Add type hints to schema.py

* Add type hints to supplementary_data_parser.py

* Format python code

* Fix variable name issue

* Tidy up pylint comments

* Add more type hint definitions

* test new changes to ensure the deserialize method uses default format (iso8601) unless a format is specified

* Tidy up type hints and add missing hints for empty collection

* Refactor code to improve type hints

* Refactor/improve type hints and install new dev package

* Refactor more type hints

* Update type ignore message

* Fix type hinting

* Correct type hint for parameter

* Remove types-simplejson package

* Revert pipfiles

* Revert pipfile.lock
  • Loading branch information
VirajP1002 authored Dec 1, 2023
1 parent eda53b9 commit a1a0b10
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 54 deletions.
15 changes: 9 additions & 6 deletions app/utilities/decimal_places.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import flask_babel
from babel import Locale, numbers, units
from babel.numbers import get_currency_precision
from babel.numbers import NumberPattern, get_currency_precision

UnitLengthType: TypeAlias = Literal["short", "long", "narrow"]

Expand All @@ -27,7 +27,7 @@ def get_formatted_currency(
*,
value: float | Decimal,
currency: str = "GBP",
locale: str | None = None,
locale: str | Locale | None = None,
decimal_limit: int | None = None,
) -> str:
"""
Expand Down Expand Up @@ -80,7 +80,7 @@ def custom_format_unit(
measurement_unit: str,
locale: Locale | str,
length: UnitLengthType = "short",
):
) -> str:
"""
This function provides a wrapper for the numbers `format_unit` method, generating the
number format (including the desired number of decimals), based on the value entered by the user and
Expand All @@ -92,14 +92,17 @@ def custom_format_unit(
value=value,
measurement_unit=measurement_unit,
length=length,
format=number_format,
# Type ignore: babel function has incorrect type hinting, NumberPattern is valid here
format=number_format, # type: ignore
locale=locale,
)

return formatted_unit


def get_number_format(value: int | float | Decimal, locale: Locale | str) -> str:
def get_number_format(
value: int | float | Decimal, locale: Locale | str
) -> NumberPattern:
"""
Generates the number format based on the value entered by the user and the locale
Expand All @@ -111,7 +114,7 @@ def get_number_format(value: int | float | Decimal, locale: Locale | str) -> str
"""
decimal_places = _get_decimal_places(value)
locale = Locale.parse(locale)
locale_decimal_format = locale.decimal_formats[None]
locale_decimal_format: NumberPattern = locale.decimal_formats[None]
locale_decimal_format.frac_prec = (decimal_places, decimal_places)
return locale_decimal_format

Expand Down
8 changes: 5 additions & 3 deletions app/utilities/json.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import IO, Any

import simplejson as json


def json_load(file, **kwargs):
def json_load(file: IO[str], **kwargs: Any) -> Any:
return json.load(file, use_decimal=True, **kwargs)


def json_loads(data, **kwargs):
def json_loads(data: str, **kwargs: Any) -> Any:
return json.loads(data, use_decimal=True, **kwargs)


def json_dumps(data, **kwargs) -> str:
def json_dumps(data: Any, **kwargs: Any) -> str:
return json.dumps(data, for_json=True, use_decimal=True, **kwargs)
27 changes: 15 additions & 12 deletions app/utilities/metadata_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
from datetime import datetime, timezone
from typing import Mapping
from typing import Any, Mapping, MutableMapping

from marshmallow import (
EXCLUDE,
Expand Down Expand Up @@ -34,9 +34,9 @@

class StripWhitespaceMixin:
@pre_load()
def strip_whitespace(
self, items, **kwargs
): # pylint: disable=no-self-use, unused-argument
def strip_whitespace( # pylint: disable=no-self-use, unused-argument
self, items: MutableMapping, **kwargs: Any
) -> MutableMapping:
for key, value in items.items():
if isinstance(value, str):
items[key] = value.strip()
Expand Down Expand Up @@ -87,8 +87,9 @@ class RunnerMetadataSchema(Schema, StripWhitespaceMixin):
eq_id = VALIDATORS["string"](required=False) # type:ignore

@validates_schema
def validate_schema_name(self, data, **kwargs):
# pylint: disable=no-self-use, unused-argument
def validate_schema_name( # pylint: disable=no-self-use, unused-argument
self, data: Mapping, **kwargs: Any
) -> None:
"""Function to validate the business schema parameters"""
if not data.get("schema_name"):
business_schema_claims = (
Expand All @@ -101,8 +102,9 @@ def validate_schema_name(self, data, **kwargs):
)

@post_load
def update_schema_name(self, data, **kwargs):
# pylint: disable=no-self-use, unused-argument
def update_schema_name( # pylint: disable=no-self-use, unused-argument
self, data: MutableMapping, **kwargs: Any
) -> MutableMapping:
"""Function to transform parameters into a business schema"""
if data.get("schema_name"):
logger.info(
Expand All @@ -115,9 +117,9 @@ def update_schema_name(self, data, **kwargs):
return data

@post_load
def update_response_id(
self, data, **kwargs
): # pylint: disable=no-self-use, unused-argument
def update_response_id( # pylint: disable=no-self-use, unused-argument
self, data: MutableMapping, **kwargs: Any
) -> MutableMapping:
"""
If response_id is present : return as it is
If response_id is not present : Build response_id from ru_ref,collection_exercise_sid,eq_id and form_type
Expand Down Expand Up @@ -147,4 +149,5 @@ def update_response_id(
def validate_runner_claims(claims: Mapping) -> dict:
"""Validate claims required for runner to function"""
runner_metadata_schema = RunnerMetadataSchema(unknown=EXCLUDE)
return runner_metadata_schema.load(claims)
# Type ignore: the load method in the Marshmallow parent schema class doesn't have type hints for return
return runner_metadata_schema.load(claims) # type: ignore
32 changes: 18 additions & 14 deletions app/utilities/metadata_parser_v2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
from datetime import datetime, timezone
from typing import Callable, Iterable, Mapping
from typing import Any, Callable, Iterable, Mapping, MutableMapping

from marshmallow import (
EXCLUDE,
Expand Down Expand Up @@ -34,9 +34,9 @@

class StripWhitespaceMixin:
@pre_load()
def strip_whitespace(
self, items, **kwargs
): # pylint: disable=no-self-use, unused-argument
def strip_whitespace( # pylint: disable=no-self-use, unused-argument
self, items: MutableMapping, **kwargs: Any
) -> MutableMapping:
for key, value in items.items():
if isinstance(value, str):
items[key] = value.strip()
Expand All @@ -52,8 +52,9 @@ class SurveyMetadata(Schema, StripWhitespaceMixin):
receipting_keys = fields.List(fields.String)

@validates_schema
def validate_receipting_keys(self, data, **kwargs):
# pylint: disable=no-self-use, unused-argument
def validate_receipting_keys( # pylint: disable=no-self-use, unused-argument
self, data: Mapping, **kwargs: Any
) -> None:
if data and (receipting_keys := data.get("receipting_keys", {})):
missing_receipting_keys = [
receipting_key
Expand Down Expand Up @@ -101,8 +102,9 @@ class RunnerMetadataSchema(Schema, StripWhitespaceMixin):
survey_metadata = fields.Nested(SurveyMetadata, required=False)

@validates_schema
def validate_schema_options(self, data, **kwargs):
# pylint: disable=no-self-use, unused-argument
def validate_schema_options( # pylint: disable=no-self-use, unused-argument
self, data: Mapping, **kwargs: Any
) -> None:
if data:
options = [
option
Expand All @@ -122,14 +124,14 @@ def validate_schema_options(self, data, **kwargs):
def validate_questionnaire_claims(
claims: Mapping,
questionnaire_specific_metadata: Iterable[Mapping],
unknown=EXCLUDE,
unknown: str = EXCLUDE,
) -> dict:
"""Validate any survey specific claims required for a questionnaire"""
dynamic_fields = {}
dynamic_fields: dict[str, fields.String | DateString] = {}

for metadata_field in questionnaire_specific_metadata:
field_arguments = {}
validators = []
field_arguments: dict[str, bool] = {}
validators: list[validate.Validator] = []

if metadata_field.get("optional"):
field_arguments["required"] = False
Expand All @@ -155,10 +157,12 @@ def validate_questionnaire_claims(
)(unknown=unknown)

# The load method performs validation.
return questionnaire_metadata_schema.load(claims)
# Type ignore: the load method in the Marshmallow parent schema class doesn't have type hints for return
return questionnaire_metadata_schema.load(claims) # type: ignore


def validate_runner_claims_v2(claims: Mapping) -> dict:
"""Validate claims required for runner to function"""
runner_metadata_schema = RunnerMetadataSchema(unknown=EXCLUDE)
return runner_metadata_schema.load(claims)
# Type ignore: the load method in the Marshmallow parent schema class doesn't have type hints for return
return runner_metadata_schema.load(claims) # type: ignore
16 changes: 10 additions & 6 deletions app/utilities/metadata_validators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from marshmallow import fields, validate


Expand All @@ -6,7 +8,7 @@ class RegionCode(validate.Regexp):
Currently, this does not validate the subdivision, but only checks length
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__("^GB-[A-Z]{3}$", *args, **kwargs)


Expand All @@ -16,7 +18,7 @@ class UUIDString(fields.UUID):
This custom field deserializes UUIDs to strings.
"""

def _deserialize(self, *args, **kwargs): # pylint: disable=arguments-differ
def _deserialize(self, *args: Any, **kwargs: Any) -> str: # type: ignore # pylint: disable=arguments-differ
return str(super()._deserialize(*args, **kwargs))


Expand All @@ -26,10 +28,12 @@ class DateString(fields.DateTime):
This custom field deserializes Dates to strings.
"""

def _deserialize(self, *args, **kwargs): # pylint: disable=arguments-differ
date = super()._deserialize(*args, **kwargs)
DEFAULT_FORMAT = "iso8601"

if self.format == "iso8601":
def _deserialize(self, *args: Any, **kwargs: Any) -> str: # type: ignore # pylint: disable=arguments-differ
date = super()._deserialize(*args, **kwargs)
date_format = self.format or self.DEFAULT_FORMAT
if date_format == "iso8601":
return date.isoformat()

return date.strftime(self.format)
return date.strftime(date_format)
4 changes: 3 additions & 1 deletion app/utilities/request_session.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Sequence

import requests
from requests.adapters import HTTPAdapter
from urllib3 import Retry


def get_retryable_session(
max_retries, retry_status_codes, backoff_factor
max_retries: int, retry_status_codes: Sequence[int], backoff_factor: float
) -> requests.Session:
session = requests.Session()

Expand Down
8 changes: 4 additions & 4 deletions app/utilities/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_schema_path(language_code: str, schema_name: str) -> str | None:
def get_schema_path_map(
include_test_schemas: bool = False,
) -> dict[str, dict[str, dict[str, str]]]:
schemas = {}
schemas: dict[str, dict[str, dict[str, str]]] = {}
for survey_type in os.listdir(SCHEMA_DIR):
if not include_test_schemas and survey_type == "test":
continue
Expand Down Expand Up @@ -99,7 +99,7 @@ def _schema_exists(language_code: str, schema_name: str) -> bool:
)


def get_allowed_languages(schema_name: str | None, launch_language: str):
def get_allowed_languages(schema_name: str | None, launch_language: str) -> list[str]:
if schema_name:
for language_combination in LANGUAGES_MAP.get(schema_name, []):
if launch_language in language_combination:
Expand Down Expand Up @@ -151,7 +151,7 @@ def _load_schema_from_name(schema_name: str, language_code: str) -> Questionnair
return QuestionnaireSchema(schema_json, language_code)


def get_schema_name_from_params(eq_id, form_type) -> str:
def get_schema_name_from_params(eq_id: str | None, form_type: str | None) -> str:
return f"{eq_id}_{form_type}"


Expand Down Expand Up @@ -243,7 +243,7 @@ def load_schema_from_url(url: str, *, language_code: str | None) -> Questionnair
raise SchemaRequestFailed


def cache_questionnaire_schemas():
def cache_questionnaire_schemas() -> None:
for schemas_by_language in get_schema_path_map().values():
for language_code, schemas in schemas_by_language.items():
for schema in schemas:
Expand Down
20 changes: 12 additions & 8 deletions app/utilities/supplementary_data_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Mapping
from typing import Any, Mapping

from marshmallow import (
INCLUDE,
Expand All @@ -18,8 +18,9 @@ class ItemsSchema(Schema):
identifier = fields.Field(required=True)

@validates("identifier")
def validate_identifier(self, identifier):
# pylint: disable=no-self-use
def validate_identifier( # pylint: disable=no-self-use
self, identifier: fields.Field
) -> None:
if not (isinstance(identifier, str) and identifier.strip()) and not (
isinstance(identifier, int) and identifier >= 0
):
Expand All @@ -40,8 +41,9 @@ class SupplementaryData(Schema, StripWhitespaceMixin):
items = fields.Nested(ItemsData, required=False, unknown=INCLUDE)

@validates_schema()
def validate_identifier(self, data, **kwargs):
# pylint: disable=no-self-use, unused-argument
def validate_identifier( # pylint: disable=no-self-use, unused-argument
self, data: Mapping, **kwargs: Any
) -> None:
if data and data["identifier"] != self.context["identifier"]:
raise ValidationError(
"Supplementary data did not return the specified Identifier"
Expand All @@ -59,8 +61,9 @@ class SupplementaryDataMetadataSchema(Schema, StripWhitespaceMixin):
)

@validates_schema()
def validate_dataset_and_survey_id(self, data, **kwargs):
# pylint: disable=no-self-use, unused-argument
def validate_dataset_and_survey_id( # pylint: disable=no-self-use, unused-argument
self, data: Mapping, **kwargs: Any
) -> None:
if data:
if data["dataset_id"] != self.context["dataset_id"]:
raise ValidationError(
Expand Down Expand Up @@ -97,4 +100,5 @@ def validate_supplementary_data_v1(
items = [ItemsSchema(unknown=INCLUDE).load(value) for value in values]
validated_supplementary_data["data"]["items"][key] = items

return validated_supplementary_data
# Type ignore: the load method in the Marshmallow parent schema class doesn't have type hints for return
return validated_supplementary_data # type: ignore
5 changes: 5 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,8 @@ no_implicit_optional = True
disallow_untyped_defs = True
warn_return_any = True
no_implicit_optional = True

[mypy-app.utilities.*]
disallow_untyped_defs = True
warn_return_any = True
no_implicit_optional = True
Loading

0 comments on commit a1a0b10

Please sign in to comment.