diff --git a/ape_solidity/compiler.py b/ape_solidity/compiler.py index 3886ac3..56aaa7e 100644 --- a/ape_solidity/compiler.py +++ b/ape_solidity/compiler.py @@ -54,9 +54,15 @@ # Define a regex pattern that matches import statements # Both single and multi-line imports will be matched IMPORTS_PATTERN = re.compile( - r"import\s+((.*?)(?=;)|[\s\S]*?from\s+(.*?)(?=;));\s", flags=re.MULTILINE + r"import\s+(([\s\S]*?)(?=;)|[\s\S]*?from\s+([^\s;]+));\s*", flags=re.MULTILINE ) LICENSES_PATTERN = re.compile(r"(// SPDX-License-Identifier:\s*([^\n]*)\s)") + +# Comment patterns +SINGLE_LINE_COMMENT_PATTERN = re.compile(r"^\s*//") +MULTI_LINE_COMMENT_START_PATTERN = re.compile(r"/\*") +MULTI_LINE_COMMENT_END_PATTERN = re.compile(r"\*/") + VERSION_PRAGMA_PATTERN = re.compile(r"pragma solidity[^;]*;") DEFAULT_OPTIMIZATION_RUNS = 200 @@ -142,7 +148,7 @@ class SolidityConfig(PluginConfig): def _get_flattened_source(path: Path, name: Optional[str] = None) -> str: name = name or path.name result = f"// File: {name}\n" - result += path.read_text() + "\n" + result += f"{path.read_text().rstrip()}\n" return result @@ -373,12 +379,15 @@ def _get_settings_from_imports( files_by_solc_version = self.get_version_map_from_imports( contract_filepaths, import_map, project=pm ) - return self._get_settings_from_version_map(files_by_solc_version, remappings, project=pm) + return self._get_settings_from_version_map( + files_by_solc_version, remappings, import_map=import_map, project=pm + ) def _get_settings_from_version_map( self, version_map: dict, import_remappings: dict[str, str], + import_map: Optional[dict[str, list[str]]] = None, project: Optional[ProjectManager] = None, **kwargs, ) -> dict[Version, dict]: @@ -397,7 +406,9 @@ def _get_settings_from_version_map( }, **kwargs, } - if remappings_used := self._get_used_remappings(sources, import_remappings, project=pm): + if remappings_used := self._get_used_remappings( + sources, import_remappings, import_map=import_map, project=pm + ): remappings_str = [f"{k}={v}" for k, v in remappings_used.items()] # Standard JSON input requires remappings to be sorted. @@ -421,6 +432,7 @@ def _get_used_remappings( self, sources: Iterable[Path], remappings: dict[str, str], + import_map: Optional[dict[str, list[str]]] = None, project: Optional[ProjectManager] = None, ) -> dict[str, str]: pm = project or self.local_project @@ -435,7 +447,8 @@ def _get_used_remappings( # Filter out unused import remapping. result = {} sources = list(sources) - imports = self.get_imports(sources, project=pm).values() + import_map = import_map or self.get_imports(sources, project=pm) + imports = import_map.values() for source_list in imports: for src in source_list: @@ -461,32 +474,20 @@ def get_standard_input_json( import_map = self.get_imports_from_remapping(paths, remapping, project=pm) version_map = self.get_version_map_from_imports(paths, import_map, project=pm) return self.get_standard_input_json_from_version_map( - version_map, remapping, project=pm, **overrides + version_map, remapping, project=pm, import_map=import_map, **overrides ) - def get_standard_input_json_from( - self, - version_map: dict[Version, set[Path]], - import_remappings: dict[str, str], - project: Optional[ProjectManager] = None, - **overrides, - ): - pm = project or self.local_project - settings = self._get_settings_from_version_map( - version_map, import_remappings, project=pm, **overrides - ) - return self.get_standard_input_json_from_settings(settings, version_map, project=pm) - def get_standard_input_json_from_version_map( self, version_map: dict[Version, set[Path]], import_remapping: dict[str, str], + import_map: Optional[dict[str, list[str]]] = None, project: Optional[ProjectManager] = None, **overrides, ): pm = project or self.local_project settings = self._get_settings_from_version_map( - version_map, import_remapping, project=pm, **overrides + version_map, import_remapping, import_map=import_map, project=pm, **overrides ) return self.get_standard_input_json_from_settings(settings, version_map, project=pm) @@ -571,8 +572,16 @@ def _compile( settings: Optional[dict] = None, ): pm = project or self.local_project - input_jsons = self.get_standard_input_json( - contract_filepaths, project=pm, **(settings or {}) + remapping = self.get_import_remapping(project=pm) + paths = list(contract_filepaths) # Handle if given generator= + import_map = self.get_imports_from_remapping(paths, remapping, project=pm) + version_map = self.get_version_map_from_imports(paths, import_map, project=pm) + input_jsons = self.get_standard_input_json_from_version_map( + version_map, + remapping, + project=pm, + import_map=import_map, + **(settings or {}), ) contract_versions: dict[str, Version] = {} contract_types: list[ContractType] = [] @@ -608,7 +617,7 @@ def _compile( for name, _ in contracts_out.items(): # Filter source files that the user did not ask for, such as # imported relative files that are not part of the input. - for input_file_path in contract_filepaths: + for input_file_path in paths: if source_id in str(input_file_path): input_contract_names.append(name) @@ -1096,14 +1105,17 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: def _flatten_source( self, - path: Path, + path: Union[Path, str], project: Optional[ProjectManager] = None, raw_import_name: Optional[str] = None, handled: Optional[set[str]] = None, ) -> str: pm = project or self.local_project handled = handled or set() - source_id = f"{get_relative_path(path, pm.path)}" + + path = Path(path) + source_id = f"{get_relative_path(path, pm.path)}" if path.is_absolute() else f"{path}" + handled.add(source_id) remapping = self.get_import_remapping(project=project) imports = self._get_imports((path,), remapping, pm, tracked=set(), include_raw=True) @@ -1116,26 +1128,36 @@ def _flatten_source( continue sub_import_name = import_str.replace("import ", "").strip(" \n\t;\"'") - final_source += self._flatten_source( + sub_source = self._flatten_source( pm.path / source_id, project=pm, raw_import_name=sub_import_name, handled=handled, ) + final_source += sub_source + + flattened_src = _get_flattened_source(path, name=raw_import_name) + if flattened_src and final_source.rstrip(): + final_source = f"{final_source.rstrip()}\n\n{flattened_src}" + elif flattened_src: + final_source = flattened_src - final_source += _get_flattened_source(path, name=raw_import_name) return final_source def flatten_contract( self, path: Path, project: Optional[ProjectManager] = None, **kwargs ) -> Content: - # try compiling in order to validate it works res = self._flatten_source(path, project=project) res = remove_imports(res) res = process_licenses(res) res = remove_version_pragmas(res) pragma = get_first_version_pragma(path.read_text()) res = "\n".join([pragma, res]) + + # Simple auto-format. + while "\n\n\n" in res: + res = res.replace("\n\n\n", "\n\n") + lines = res.splitlines() line_dict = {i + 1: line for i, line in enumerate(lines)} return Content(root=line_dict) @@ -1244,11 +1266,37 @@ def _import_str_to_source_id( return f"{get_relative_path(path.absolute(), pm.path)}" -def remove_imports(flattened_contract: str) -> str: - # Use regex.sub() to remove matched import statements - no_imports_contract = IMPORTS_PATTERN.sub("", flattened_contract) +def remove_imports(source_code: str) -> str: + in_multi_line_comment = False + result_lines = [] + + lines = source_code.splitlines() + for line in lines: + # Check if we're entering a multi-line comment + if MULTI_LINE_COMMENT_START_PATTERN.search(line): + in_multi_line_comment = True + + # If inside a multi-line comment, just add the line to the result + if in_multi_line_comment: + result_lines.append(line) + # Check if this line ends the multi-line comment + if MULTI_LINE_COMMENT_END_PATTERN.search(line): + in_multi_line_comment = False + continue + + # Skip single-line comments + if SINGLE_LINE_COMMENT_PATTERN.match(line): + result_lines.append(line) + continue + + # Skip import statements in non-comment lines + if IMPORTS_PATTERN.search(line): + continue + + # Add the line to the result if it's not an import statement + result_lines.append(line) - return no_imports_contract + return "\n".join(result_lines) def remove_version_pragmas(flattened_contract: str) -> str: @@ -1285,9 +1333,7 @@ def process_licenses(contract: str) -> str: license_line, root_license = extracted_licenses[-1] # Get the unique license identifiers. All licenses in a contract _should_ be the same. - unique_license_identifiers = { - license_identifier for _, license_identifier in extracted_licenses - } + unique_license_identifiers = {lid for _, lid in extracted_licenses} # If we have more than one unique license identifier, warn the user and use the root. if len(unique_license_identifiers) > 1: diff --git a/setup.py b/setup.py index 2a30d37..a987929 100644 --- a/setup.py +++ b/setup.py @@ -69,7 +69,7 @@ include_package_data=True, install_requires=[ "py-solc-x>=2.0.2,<3", - "eth-ape>=0.8.1,<0.9", + "eth-ape>=0.8.4,<0.9", "ethpm-types", # Use the version ape requires "eth-pydantic-types", # Use the version ape requires "packaging", # Use the version ape requires diff --git a/tests/contracts/Imports.sol b/tests/contracts/Imports.sol index 8687fd1..87514f6 100644 --- a/tests/contracts/Imports.sol +++ b/tests/contracts/Imports.sol @@ -26,6 +26,10 @@ import "@safe/contracts/common/Enum.sol"; // Purposely exclude the contracts folder to test older Ape-style project imports. import "@noncompilingdependency/subdir/SubCompilingContract.sol"; +// Showing sources with extra extensions are by default excluded, +// unless used as an import somewhere in a non-excluded source. +import "./Source.extra.ext.sol"; + contract Imports { function foo() pure public returns(bool) { return true; diff --git a/tests/contracts/Source.extra.ext.sol b/tests/contracts/Source.extra.ext.sol new file mode 100644 index 0000000..99a98ac --- /dev/null +++ b/tests/contracts/Source.extra.ext.sol @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.4; + +// Showing sources with extra extensions are by default excluded, +// unless used as an import somewhere in a non-excluded source. +contract SourceExtraExt { + function foo() pure public returns(bool) { + return true; + } +} diff --git a/tests/data/ImportingLessConstrainedVersionFlat.sol b/tests/data/ImportingLessConstrainedVersionFlat.sol index ddf7012..61a4644 100644 --- a/tests/data/ImportingLessConstrainedVersionFlat.sol +++ b/tests/data/ImportingLessConstrainedVersionFlat.sol @@ -3,8 +3,6 @@ pragma solidity =0.8.12; // File: ./SpecificVersionRange.sol - - contract SpecificVersionRange { function foo() pure public returns(bool) { return true; @@ -13,8 +11,6 @@ contract SpecificVersionRange { // File: ImportingLessConstrainedVersion.sol - - // The file we are importing specific range '>=0.8.12 <0.8.15'; // This means on its own, the plugin would use 0.8.14 if its installed. // However - it should use 0.8.12 because of this file's requirements. diff --git a/tests/data/ImportsFlattened.sol.txt b/tests/data/ImportsFlattened.sol similarity index 57% rename from tests/data/ImportsFlattened.sol.txt rename to tests/data/ImportsFlattened.sol index 32b1f59..9cfa235 100644 --- a/tests/data/ImportsFlattened.sol.txt +++ b/tests/data/ImportsFlattened.sol @@ -1,96 +1,85 @@ +pragma solidity ^0.8.4; // SPDX-License-Identifier: MIT -// File: @remapping_2_brownie/BrownieContract.sol - -pragma solidity ^0.8.4; +// File: @dependencyofdependency/contracts/DependencyOfDependency.sol -contract BrownieContract { +contract DependencyOfDependency { function foo() pure public returns(bool) { return true; } } -// File: @styleofbrownie/BrownieStyleDependency.sol +// File: * as Depend from "@dependency/contracts/Dependency.sol -pragma solidity ^0.8.4; +struct DependencyStruct { + string name; + uint value; +} -contract BrownieStyleDependency { +contract Dependency { function foo() pure public returns(bool) { return true; } } +// File: { Struct0, Struct1, Struct2, Struct3, Struct4, Struct5 } from "./NumerousDefinitions.sol -// File: @dependency_remapping/DependencyOfDependency.sol - -pragma solidity ^0.8.4; +struct Struct0 { + string name; + uint value; +} -contract DependencyOfDependency { - function foo() pure public returns(bool) { - return true; - } +struct Struct1 { + string name; + uint value; } -// File: @remapping/contracts/Dependency.sol +struct Struct2 { + string name; + uint value; +} -pragma solidity ^0.8.4; +struct Struct3 { + string name; + uint value; +} +struct Struct4 { + string name; + uint value; +} -struct DependencyStruct { +struct Struct5 { string name; uint value; } -contract Dependency { +contract NumerousDefinitions { function foo() pure public returns(bool) { return true; } } +// File: @noncompilingdependency/CompilingContract.sol -// File: @dependency_remapping/DependencyOfDependency.sol - -pragma solidity ^0.8.4; - -contract DependencyOfDependency { +contract BrownieStyleDependency { function foo() pure public returns(bool) { return true; } } +// File: @browniedependency/contracts/BrownieContract.sol -// File: @remapping_2/Dependency.sol - -pragma solidity ^0.8.4; - - -struct DependencyStruct { - string name; - uint value; -} - -contract Dependency { +contract CompilingContract { function foo() pure public returns(bool) { return true; } } +// File: ./subfolder/Relativecontract.sol -// File: CompilesOnce.sol - -pragma solidity >=0.8.0; - -struct MyStruct { - string name; - uint value; -} - -contract CompilesOnce { - // This contract tests the scenario when we have a contract with - // a similar compiler version to more than one other contract's. - // This ensures we don't compile the same contract more than once. +contract Relativecontract { function foo() pure public returns(bool) { return true; } } - // File: ./././././././././././././././././././././././././././././././././././MissingPragma.sol contract MissingPragma { @@ -98,53 +87,41 @@ contract MissingPragma { return true; } } +// File: @safe/contracts/common/Enum.sol -// File: ./NumerousDefinitions.sol - -pragma solidity >=0.8.0; - -struct Struct0 { - string name; - uint value; -} - -struct Struct1 { - string name; - uint value; -} - -struct Struct2 { - string name; - uint value; +/// @title Enum - Collection of enums +/// @author Richard Meissner - +contract Enum { + enum Operation {Call, DelegateCall} } +// File: ./Source.extra.ext.sol -struct Struct3 { - string name; - uint value; +// Showing sources with extra extensions are by default excluded, +// unless used as an import somewhere in a non-excluded source. +contract SourceExtraExt { + function foo() pure public returns(bool) { + return true; + } } +// File: { MyStruct } from "contracts/CompilesOnce.sol -struct Struct4 { +struct MyStruct { string name; uint value; } -struct Struct5 { - string name; - uint value; -} +contract CompilesOnce { + // This contract tests the scenario when we have a contract with + // a similar compiler version to more than one other contract's. + // This ensures we don't compile the same contract more than once. -contract NumerousDefinitions { function foo() pure public returns(bool) { return true; } } +// File: @noncompilingdependency/subdir/SubCompilingContract.sol -// File: ./subfolder/Relativecontract.sol - -pragma solidity >=0.8.0; - -contract Relativecontract { - +contract SubCompilingContract { function foo() pure public returns(bool) { return true; } @@ -152,8 +129,22 @@ contract Relativecontract { // File: Imports.sol -pragma solidity ^0.8.4; - +import + "./././././././././././././././././././././././././././././././././././MissingPragma.sol"; +import { + Struct0, + Struct1, + Struct2, + Struct3, + Struct4, + Struct5 +} from "./NumerousDefinitions.sol"; +// Purposely repeat an import to test how the plugin handles that. + +// Purposely exclude the contracts folder to test older Ape-style project imports. + +// Showing sources with extra extensions are by default excluded, +// unless used as an import somewhere in a non-excluded source. contract Imports { function foo() pure public returns(bool) { diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 65f8a40..bf79a76 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -4,7 +4,7 @@ import solcx from ape import Project, reverts from ape.exceptions import CompilerError -from ape.logging import LogLevel +from ape.utils import get_full_extension from ethpm_types import ContractType from packaging.version import Version @@ -126,6 +126,7 @@ def test_get_imports_complex(project, compiler): "contracts/CompilesOnce.sol", "contracts/MissingPragma.sol", "contracts/NumerousDefinitions.sol", + "contracts/Source.extra.ext.sol", "contracts/subfolder/Relativecontract.sol", ], "contracts/MissingPragma.sol": [], @@ -403,6 +404,7 @@ def test_get_compiler_settings(project, compiler): "contracts/Imports.sol", "contracts/MissingPragma.sol", "contracts/NumerousDefinitions.sol", + "contracts/Source.extra.ext.sol", "contracts/subfolder/Relativecontract.sol", ] assert actual_files == expected_files @@ -623,7 +625,7 @@ def test_compile_project(project, compiler): """ Simple test showing the full project indeed compiles. """ - paths = [x for x in project.sources.paths if x.suffix == ".sol"] + paths = [x for x in project.sources.paths if get_full_extension(x) == ".sol"] actual = [c for c in compiler.compile(paths, project=project)] assert len(actual) > 0 @@ -684,23 +686,33 @@ def test_enrich_error_when_builtin(project, owner, connection): contract.checkIndexOutOfBounds(sender=owner) -def test_flatten(project, compiler, caplog): - path = project.sources.lookup("contracts/Imports.sol") - with caplog.at_level(LogLevel.WARNING): - compiler.flatten_contract(path, project=project) - actual = caplog.messages[-1] - expected = ( - "Conflicting licenses found: 'LGPL-3.0-only, MIT'. " - "Using the root file's license 'MIT'." - ) - assert actual == expected - - path = project.sources.lookup("contracts/ImportingLessConstrainedVersion.sol") - flattened_source = compiler.flatten_contract(path, project=project) - flattened_source_path = ( - Path(__file__).parent / "data" / "ImportingLessConstrainedVersionFlat.sol" +def test_flatten(mocker, project, compiler): + path = project.contracts_folder / "Imports.sol" + base_expected = Path(__file__).parent / "data" + + # NOTE: caplog for some reason is inconsistent and causes flakey tests. + # Thus, we are using our own "logger_spy". + logger_spy = mocker.patch("ape_solidity.compiler.logger") + + res = compiler.flatten_contract(path, project=project) + call_args = logger_spy.warning.call_args + actual_logs = call_args[0] if call_args else () + assert actual_logs, f"Missing warning logs from dup-licenses, res: {res}" + actual = actual_logs[-1] + # NOTE: MIT coming from Imports.sol and LGPL-3.0-only coming from + # @safe/contracts/common/Enum.sol. + expected = ( + "Conflicting licenses found: 'LGPL-3.0-only, MIT'. Using the root file's license 'MIT'." ) - assert str(flattened_source) == str(flattened_source_path.read_text()) + assert actual == expected + + path = project.contracts_folder / "ImportingLessConstrainedVersion.sol" + flattened_source = compiler.flatten_contract(path, project=project) + flattened_source_path = base_expected / "ImportingLessConstrainedVersionFlat.sol" + + actual = str(flattened_source) + expected = str(flattened_source_path.read_text()) + assert actual == expected def test_compile_code(project, compiler):