Skip to content

Commit

Permalink
chore: flake8 type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Oct 29, 2024
1 parent 7fe0acf commit 99bcdb3
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
rev: v1.13.0
hooks:
- id: mypy
additional_dependencies: [types-requests, types-setuptools, pydantic]
additional_dependencies: [types-requests, types-setuptools, flake8-pydantic, flake8-type-checking]

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.18
Expand Down
34 changes: 19 additions & 15 deletions ape_solidity/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from typing import TYPE_CHECKING, Optional

from ape.exceptions import CompilerError, ProjectError
from ape.managers import ProjectManager
from ape.utils.basemodel import BaseModel, ManagerAccessMixin, classproperty
from ape.utils.os import get_relative_path
from pydantic import field_serializer

from ape_solidity._utils import get_single_import_lines

if TYPE_CHECKING:
from ape.managers.project import ProjectManager

from ape_solidity.compiler import SolidityCompiler


Expand All @@ -26,7 +27,7 @@ class ApeSolidityModel(BaseModel, ApeSolidityMixin):
pass


def _create_import_remapping(project: ProjectManager) -> dict[str, str]:
def _create_import_remapping(project: "ProjectManager") -> dict[str, str]:
prefix = f"{get_relative_path(project.contracts_folder, project.path)}"
specified = project.dependencies.install()

Expand Down Expand Up @@ -102,22 +103,22 @@ def __init__(self):
# Cache project paths to import remapping.
self._cache: dict[str, dict[str, str]] = {}

def __getitem__(self, project: ProjectManager) -> dict[str, str]:
def __getitem__(self, project: "ProjectManager") -> dict[str, str]:
if remapping := self._cache.get(f"{project.path}"):
return remapping

return self.add_project(project)

def add_project(self, project: ProjectManager) -> dict[str, str]:
def add_project(self, project: "ProjectManager") -> dict[str, str]:
remapping = _create_import_remapping(project)
return self.add(project, remapping)

def add(self, project: ProjectManager, remapping: dict[str, str]):
def add(self, project: "ProjectManager", remapping: dict[str, str]):
self._cache[f"{project.path}"] = remapping
return remapping

@classmethod
def get_import_remapping(cls, project: ProjectManager):
def get_import_remapping(cls, project: "ProjectManager"):
return _create_import_remapping(project)


Expand Down Expand Up @@ -147,7 +148,7 @@ def value(self) -> str:
return self.raw_value

@property
def dependency(self) -> Optional[ProjectManager]:
def dependency(self) -> Optional["ProjectManager"]:
if name := self.dependency_name:
if version := self.dependency_version:
return self.local_project.dependencies[name][version]
Expand All @@ -159,8 +160,8 @@ def parse_line(
cls,
value: str,
reference: Path,
project: ProjectManager,
dependency: Optional[ProjectManager] = None,
project: "ProjectManager",
dependency: Optional["ProjectManager"] = None,
) -> "ImportStatementMetadata":
quote = '"' if '"' in value else "'"
sep = "\\" if "\\" in value else "/"
Expand All @@ -186,14 +187,17 @@ def __hash__(self) -> int:
return hash(path)

def _resolve_source(
self, reference: Path, project: ProjectManager, dependency: Optional[ProjectManager] = None
self,
reference: Path,
project: "ProjectManager",
dependency: Optional["ProjectManager"] = None,
):
if not self._resolve_dependency(project, dependency=dependency):
# Handle non-dependencies.
self._resolve_import_remapping(project)
self._resolve_path(reference, project)

def _resolve_import_remapping(self, project: ProjectManager):
def _resolve_import_remapping(self, project: "ProjectManager"):
if self.value.startswith("."):
# Relative paths should not use import-remappings.
return
Expand All @@ -213,7 +217,7 @@ def _resolve_import_remapping(self, project: ProjectManager):
valid_matches, key=lambda x: len(x[0])
)

def _resolve_path(self, reference: Path, project: ProjectManager):
def _resolve_path(self, reference: Path, project: "ProjectManager"):
base_path = None
if self.value.startswith("."):
base_path = reference.parent
Expand All @@ -236,7 +240,7 @@ def _resolve_path(self, reference: Path, project: ProjectManager):
self.source_id = f"{get_relative_path(self.path, project.path)}"

def _resolve_dependency(
self, project: ProjectManager, dependency: Optional[ProjectManager] = None
self, project: "ProjectManager", dependency: Optional["ProjectManager"] = None
) -> bool:
config_project = dependency or project
# NOTE: Dependency is set if we are getting dependencies of dependencies.
Expand Down Expand Up @@ -340,9 +344,9 @@ def _serialize_import_statements(self, statements, info):
def from_source_files(
cls,
source_files: Iterable[Path],
project: ProjectManager,
project: "ProjectManager",
statements: Optional[dict[tuple[Path, str], set[ImportStatementMetadata]]] = None,
dependency: Optional[ProjectManager] = None,
dependency: Optional["ProjectManager"] = None,
) -> "SourceTree":
statements = statements or {}
for path in source_files:
Expand Down
15 changes: 9 additions & 6 deletions ape_solidity/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
from collections.abc import Iterable
from enum import Enum
from pathlib import Path
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

from ape.exceptions import CompilerError
from ape.utils import pragma_str_to_specifier_set
from packaging.specifiers import SpecifierSet
from packaging.version import Version
from solcx.install import get_executable
from solcx.wrapper import get_solc_version as get_solc_version_from_binary

if TYPE_CHECKING:
from packaging.specifiers import SpecifierSet


OUTPUT_SELECTION = [
"abi",
"bin-runtime",
Expand Down Expand Up @@ -62,7 +65,7 @@ def get_single_import_lines(source_path: Path) -> list[str]:
return list(import_set)


def get_pragma_spec_from_path(source_file_path: Union[Path, str]) -> Optional[SpecifierSet]:
def get_pragma_spec_from_path(source_file_path: Union[Path, str]) -> Optional["SpecifierSet"]:
"""
Extracts pragma information from Solidity source code.
Expand All @@ -80,7 +83,7 @@ def get_pragma_spec_from_path(source_file_path: Union[Path, str]) -> Optional[Sp
return get_pragma_spec_from_str(source_str)


def get_pragma_spec_from_str(source_str: str) -> Optional[SpecifierSet]:
def get_pragma_spec_from_str(source_str: str) -> Optional["SpecifierSet"]:
if not (
pragma_match := next(
re.finditer(r"(?:\n|^)\s*pragma\s*solidity\s*([^;\n]*)", source_str), None
Expand All @@ -106,11 +109,11 @@ def add_commit_hash(version: Union[str, Version]) -> Version:
return get_solc_version_from_binary(solc, with_commit_hash=True)


def get_versions_can_use(pragma_spec: SpecifierSet, options: Iterable[Version]) -> list[Version]:
def get_versions_can_use(pragma_spec: "SpecifierSet", options: Iterable[Version]) -> list[Version]:
return sorted(list(pragma_spec.filter(options)), reverse=True)


def select_version(pragma_spec: SpecifierSet, options: Iterable[Version]) -> Optional[Version]:
def select_version(pragma_spec: "SpecifierSet", options: Iterable[Version]) -> Optional[Version]:
choices = get_versions_can_use(pragma_spec, options)
return choices[0] if choices else None

Expand Down
13 changes: 8 additions & 5 deletions ape_solidity/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from pathlib import Path
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

from ape.api import CompilerAPI, PluginConfig
from ape.contracts import ContractInstance
from ape.exceptions import CompilerError, ConfigError, ContractLogicError
from ape.logging import logger
from ape.managers.project import LocalProject, ProjectManager
Expand All @@ -15,7 +14,6 @@
from eth_pydantic_types import HexBytes
from eth_utils import add_0x_prefix, is_0x_prefixed
from ethpm_types.source import Compiler, Content
from packaging.specifiers import SpecifierSet
from packaging.version import Version
from pydantic import model_validator
from requests.exceptions import ConnectionError
Expand Down Expand Up @@ -50,6 +48,11 @@
SolcInstallError,
)

if TYPE_CHECKING:
from ape.contracts import ContractInstance
from packaging.specifiers import SpecifierSet


LICENSES_PATTERN = re.compile(r"(// SPDX-License-Identifier:\s*([^\n]*)\s)")

# Comment patterns
Expand Down Expand Up @@ -234,7 +237,7 @@ def _get_configured_version(
def _ape_version(self) -> Version:
return Version(version.split(".dev")[0].strip())

def add_library(self, *contracts: ContractInstance, project: Optional[ProjectManager] = None):
def add_library(self, *contracts: "ContractInstance", project: Optional[ProjectManager] = None):
"""
Set a library contract type address. This is useful when deploying a library
in a local network and then adding the address afterward. Now, when
Expand Down Expand Up @@ -782,7 +785,7 @@ def get_version_map_from_imports(
# is more predictable. Also, remove any lingering empties.
return {k: result[k] for k in sorted(result) if result[k]}

def _get_pramga_spec_from_str(self, source_str: str) -> Optional[SpecifierSet]:
def _get_pramga_spec_from_str(self, source_str: str) -> Optional["SpecifierSet"]:
if not (pragma_spec := get_pragma_spec_from_str(source_str)):
return None

Expand Down
8 changes: 5 additions & 3 deletions ape_solidity/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from enum import IntEnum
from typing import Union
from typing import TYPE_CHECKING, Union

from ape.exceptions import CompilerError, ConfigError, ContractLogicError
from ape.logging import LogLevel, logger
from solcx.exceptions import SolcError

if TYPE_CHECKING:
from solcx.exceptions import SolcError


class SolcInstallError(CompilerError):
Expand All @@ -25,7 +27,7 @@ class SolcCompileError(CompilerError):
account Ape's logging verbosity.
"""

def __init__(self, solc_error: SolcError):
def __init__(self, solc_error: "SolcError"):
self.solc_error = solc_error

def __str__(self) -> str:
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
[flake8]
max-line-length = 100
ignore = E704,W503,PYD002,TC003,TC006
exclude =
venv*
docs
build
tests/node_modules
type-checking-pydantic-enabled = True
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"types-requests", # Needed for mypy type shed
"types-setuptools", # Needed for mypy type shed
"flake8>=7.1.1,<8", # Style linter
"flake8-pydantic", # For detecting issues with Pydantic models
"flake8-type-checking", # Detect imports to move in/out of type-checking blocks
"isort>=5.13.2,<6", # Import sorting linter
"mdformat>=0.7.18", # Auto-formatter for markdown
"mdformat-gfm>=0.3.5", # Needed for formatting GitHub-flavored markdown
Expand Down

0 comments on commit 99bcdb3

Please sign in to comment.