Skip to content

Commit

Permalink
✨ Accept Wake contracts path as parameter to config
Browse files Browse the repository at this point in the history
  • Loading branch information
michprev committed Sep 24, 2024
1 parent 2ae57db commit bf00281
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 24 deletions.
3 changes: 3 additions & 0 deletions wake/cli/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@

def export_json(config: WakeConfig, compiler: SolidityCompiler, build: ProjectBuild, build_info: ProjectBuildInfo):
import json
import platform

config_dict = config.todict(mode="json")
del config_dict["subconfigs"]
del config_dict["api_keys"]

out = {
"version": build_info.wake_version,
"system": platform.system(),
"project_root": str(config.project_root_path),
"wake_contracts_path": str(config.wake_contracts_path),
"config": config_dict,
"sources": {},
}
Expand Down
9 changes: 7 additions & 2 deletions wake/cli/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import sys
import time
from pathlib import Path
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import (
TYPE_CHECKING,
AbstractSet,
Expand Down Expand Up @@ -650,7 +650,12 @@ def process_detection(detection: Detection) -> Dict[str, Any]:
if loaded["version"] != get_package_version("eth-wake"):
raise click.BadParameter(f"JSON file was created with version {loaded['version']} of eth-wake, while the current version is {get_package_version('eth-wake')}")

config = WakeConfig.fromdict(loaded["config"])
if loaded["system"] == "Windows":
wake_contracts_path = PureWindowsPath(loaded["wake_contracts_path"])
else:
wake_contracts_path = PurePosixPath(loaded["wake_contracts_path"])

config = WakeConfig.fromdict(loaded["config"], wake_contracts_path=wake_contracts_path)
original_project_root = Path(loaded["project_root"])

# add project root as an include path (for solc to resolve imports correctly)
Expand Down
8 changes: 4 additions & 4 deletions wake/cli/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def run_init(
from ..compiler import SolcOutputSelectionEnum, SolidityCompiler
from ..compiler.solc_frontend import SolcOutputErrorSeverityEnum
from ..development.pytypes_generator import TypeGenerator
from ..utils.file_utils import copy_dir, is_relative_to, wake_contracts_path
from ..utils.file_utils import copy_dir, is_relative_to

if example is None:
# create tests directory
Expand Down Expand Up @@ -184,7 +184,7 @@ def run_init(
and file.is_file()
):
sol_files.add(file)
for file in wake_contracts_path.rglob("**/*.sol"):
for file in Path(config.wake_contracts_path).rglob("**/*.sol"):
sol_files.add(file)
end = time.perf_counter()
console.log(
Expand Down Expand Up @@ -295,7 +295,7 @@ async def run_init_pytypes(
from ..compiler.compiler import CompilationFileSystemEventHandler
from ..compiler.solc_frontend import SolcOutputErrorSeverityEnum
from ..development.pytypes_generator import TypeGenerator
from ..utils.file_utils import is_relative_to, wake_contracts_path
from ..utils.file_utils import is_relative_to

def callback(build: ProjectBuild, build_info: ProjectBuildInfo):
start = time.perf_counter()
Expand All @@ -319,7 +319,7 @@ def callback(build: ProjectBuild, build_info: ProjectBuildInfo):
and file.is_file()
):
sol_files.add(file)
for file in wake_contracts_path.rglob("**/*.sol"):
for file in Path(config.wake_contracts_path).rglob("**/*.sol"):
sol_files.add(file)
end = time.perf_counter()
console.log(
Expand Down
3 changes: 1 addition & 2 deletions wake/compiler/solc_frontend/solc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from wake.core import get_logger
from wake.core.solidity_version import SolidityVersion
from wake.svm import SolcVersionManager
from wake.utils import wake_contracts_path

from .exceptions import SolcCompilationError
from .input_data_model import (
Expand Down Expand Up @@ -76,7 +75,7 @@ async def __run_solc(
args.append("--base-path=.")
for include_path in self.__config.compiler.solc.include_paths:
args.append(f"--include-path={include_path}")
args.append(f"--include-path={wake_contracts_path}")
args.append(f"--include-path={self.__config.wake_contracts_path}")

logger.debug(f"Running solc: {' '.join(args)}")

Expand Down
7 changes: 3 additions & 4 deletions wake/compiler/source_path_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from wake.config import WakeConfig

from ..utils import wake_contracts_path
from .exceptions import CompilationResolveError


Expand All @@ -31,9 +30,9 @@ def resolve(
for include_path in itertools.chain(
[self.__config.project_root_path],
self.__config.compiler.solc.include_paths,
[wake_contracts_path],
[self.__config.wake_contracts_path],
):
path = include_path / source_unit_name
path = Path(include_path / source_unit_name)
if path.is_file():
matching_paths.append(path)
elif path in virtual_files:
Expand All @@ -58,7 +57,7 @@ def matches(self, source_unit_name: str, file: Path) -> bool:
for include_path in itertools.chain(
[self.__config.project_root_path],
self.__config.compiler.solc.include_paths,
[wake_contracts_path],
[self.__config.wake_contracts_path],
):
path = include_path / source_unit_name
if path == file:
Expand Down
3 changes: 1 addition & 2 deletions wake/compiler/source_unit_name_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from wake.config import WakeConfig
from wake.config.data_model import SolcRemapping
from wake.utils import wake_contracts_path


class SourceUnitNameResolver:
Expand Down Expand Up @@ -121,7 +120,7 @@ def resolve_cmdline_arg(self, arg: str) -> str:
for include_path in itertools.chain(
[self.__config.project_root_path],
self.__config.compiler.solc.include_paths,
[wake_contracts_path],
[self.__config.wake_contracts_path],
):
try:
return str(PurePosixPath(pure_path.relative_to(include_path)))
Expand Down
21 changes: 19 additions & 2 deletions wake/config/wake_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import platform
import reprlib
from copy import deepcopy
from pathlib import Path
from pathlib import Path, PurePath
from typing import Any, Dict, FrozenSet, Iterable, Optional, Set, Tuple, Union

import networkx as nx
import tomli
from typing_extensions import Literal

import wake.utils.file_utils
from wake.core import get_logger
from wake.utils import change_cwd

Expand Down Expand Up @@ -44,6 +45,7 @@ class WakeConfig:

__local_config_path: Path
__project_root_path: Path
__wake_contracts_path: PurePath
__global_config_path: Path
__global_data_path: Path
__global_cache_path: Path
Expand All @@ -56,6 +58,7 @@ def __init__(
*_,
local_config_path: Optional[Union[str, Path]] = None,
project_root_path: Optional[Union[str, Path]] = None,
wake_contracts_path: Optional[PurePath] = None,
):
"""
Initialize the `WakeConfig` class. If `project_root_path` is not provided, the current working directory is used.
Expand Down Expand Up @@ -112,6 +115,11 @@ def __init__(
else:
self.__local_config_path = Path(local_config_path).resolve()

if wake_contracts_path is None:
self.__wake_contracts_path = wake.utils.file_utils.wake_contracts_path
else:
self.__wake_contracts_path = wake_contracts_path

if not self.__project_root_path.is_dir():
raise ValueError(
f"Project root path '{self.__project_root_path}' is not a directory."
Expand Down Expand Up @@ -219,6 +227,7 @@ def fromdict(
config_dict: Dict[str, Any],
*,
project_root_path: Optional[Union[str, Path]] = None,
wake_contracts_path: Optional[PurePath] = None,
) -> "WakeConfig":
"""
Args:
Expand All @@ -228,7 +237,7 @@ def fromdict(
Returns:
Instance of the `WakeConfig` class with the provided config options.
"""
instance = cls(project_root_path=project_root_path)
instance = cls(project_root_path=project_root_path, wake_contracts_path=wake_contracts_path)
with change_cwd(instance.project_root_path):
parsed_config = TopLevelConfig.model_validate(config_dict)
instance.__config_raw = parsed_config.model_dump(
Expand Down Expand Up @@ -436,6 +445,14 @@ def project_root_path(self) -> Path:
"""
return self.__project_root_path

@property
def wake_contracts_path(self) -> PurePath:
"""
Returns:
System path to the Wake contracts directory.
"""
return self.__wake_contracts_path

@property
def min_solidity_version(self) -> SolidityVersion:
"""
Expand Down
6 changes: 3 additions & 3 deletions wake/detectors/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from wake.core.visitor import Visitor, group_map, visit_map
from wake.core.wake_comments import WakeComment, error_commented_out
from wake.utils import StrEnum, get_class_that_defined_method
from wake.utils.file_utils import is_relative_to, wake_contracts_path
from wake.utils.file_utils import is_relative_to
from wake.utils.keyed_default_dict import KeyedDefaultDict

if TYPE_CHECKING:
Expand Down Expand Up @@ -220,7 +220,7 @@ def _strip_excluded_subdetections(
for d in detection.subdetections:
if not any(
is_relative_to(d.ir_node.source_unit.file, p)
for p in chain(config.detectors.exclude_paths, [wake_contracts_path])
for p in chain(config.detectors.exclude_paths, [config.wake_contracts_path])
):
subdetections.append(d)
continue
Expand Down Expand Up @@ -325,7 +325,7 @@ def _filter_detections(

if any(
is_relative_to(detection.detection.ir_node.source_unit.file, p)
for p in chain(config.detectors.exclude_paths, [wake_contracts_path])
for p in chain(config.detectors.exclude_paths, [config.wake_contracts_path])
):
detection = DetectorResult(
_strip_excluded_subdetections(detection.detection, config),
Expand Down
7 changes: 3 additions & 4 deletions wake/lsp/features/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from wake.core import get_logger

from ...compiler.source_unit_name_resolver import SourceUnitNameResolver
from ...utils import wake_contracts_path
from ..common_structures import (
Command,
MarkupContent,
Expand Down Expand Up @@ -541,7 +540,7 @@ async def completion(
this_source_unit_name = None
for include_path in chain(
context.config.compiler.solc.include_paths,
[context.config.project_root_path, wake_contracts_path],
[context.config.project_root_path, context.config.wake_contracts_path],
):
try:
rel_path = str(path.relative_to(include_path).as_posix())
Expand Down Expand Up @@ -575,7 +574,7 @@ async def completion(

for include_path in chain(
context.config.compiler.solc.include_paths,
[context.config.project_root_path, wake_contracts_path],
[context.config.project_root_path, Path(context.config.wake_contracts_path)],
):
if include_path.is_dir():
for p in include_path.iterdir():
Expand All @@ -601,7 +600,7 @@ async def completion(
else:
for include_path in chain(
context.config.compiler.solc.include_paths,
[context.config.project_root_path, wake_contracts_path],
[context.config.project_root_path, Path(context.config.wake_contracts_path)],
):
if include_path.is_dir():
dir = include_path / parent
Expand Down
2 changes: 1 addition & 1 deletion wake/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .context_managers import change_cwd, recursion_guard
from .decorators import cached_return_on_recursion, return_on_recursion
from .enums import StrEnum
from .file_utils import is_relative_to, wake_contracts_path
from .file_utils import is_relative_to
from .general import get_class_that_defined_method
from .version import get_package_version

0 comments on commit bf00281

Please sign in to comment.