From b37033822ee3131552d345ed6e768fa1515df626 Mon Sep 17 00:00:00 2001 From: Juliya Smith Date: Mon, 8 Jul 2024 17:08:53 -0500 Subject: [PATCH] fix: issue with flattening --- ape_solidity/compiler.py | 48 ++++++++----- tests/test_cli.py | 151 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 177 insertions(+), 22 deletions(-) diff --git a/ape_solidity/compiler.py b/ape_solidity/compiler.py index 8d5d095..da01c7b 100644 --- a/ape_solidity/compiler.py +++ b/ape_solidity/compiler.py @@ -51,11 +51,6 @@ SolcInstallError, ) -# Define a regex pattern that matches import statements -# Both single and multi-line imports will be matched -IMPORTS_PATTERN = re.compile( - r"import\s+(([\s\S]*?)(?=;)|[\s\S]*?from\s+([^\s;]+));\s*", flags=re.MULTILINE -) LICENSES_PATTERN = re.compile(r"(// SPDX-License-Identifier:\s*([^\n]*)\s)") # Comment patterns @@ -1126,7 +1121,11 @@ def _flatten_source( final_source = "" - for import_str, source_id in relevant_imports.items(): # type: ignore + # type-ignore note: we know it is a dict because of `include_raw=True`. + import_items = relevant_imports.items() # type: ignore + + import_iter = sorted(import_items, key=lambda x: f"{x[1]}{x[0]}") + for import_str, source_id in import_iter: if source_id in handled: continue @@ -1270,8 +1269,30 @@ def _import_str_to_source_id( def remove_imports(source_code: str) -> str: + code = remove_comments(source_code) + result_lines: list[str] = [] + in_multiline_import = False + for line in code.splitlines(): + if line.lstrip().startswith("import ") or line.strip() == "import": + if not line.rstrip().endswith(";"): + in_multiline_import = True + + continue + + elif in_multiline_import: + if line.rstrip().endswith(";"): + in_multiline_import = False + + continue + + result_lines.append(line) + + return "\n".join(result_lines) + + +def remove_comments(source_code: str) -> str: in_multi_line_comment = False - result_lines = [] + result_lines: list[str] = [] lines = source_code.splitlines() for line in lines: @@ -1292,21 +1313,10 @@ def remove_imports(source_code: str) -> str: result_lines.append(line) continue - # Skip import statements in non-comment lines. - # NOTE: multi-line imports not handled until after loop. - if IMPORTS_PATTERN.search(line): - continue - # Add the line to the result if it's not an import statement result_lines.append(line) - result = "\n".join(result_lines) - - # Remove multi-line imports. - while IMPORTS_PATTERN.search(result): - result = IMPORTS_PATTERN.sub("", result) - - return result + return "\n".join(result_lines) def remove_version_pragmas(flattened_contract: str) -> str: diff --git a/tests/test_cli.py b/tests/test_cli.py index 6c8e7ed..b45da1c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,6 +4,152 @@ from ape_solidity._cli import cli +EXPECTED_FLATTENED_CONTRACT = """ +pragma solidity ^0.8.4; +// SPDX-License-Identifier: MIT + +// File: @browniedependency/contracts/BrownieContract.sol + +contract CompilingContract { + function foo() pure public returns(bool) { + return true; + } +} +// File: @dependencyofdependency/contracts/DependencyOfDependency.sol + +contract DependencyOfDependency { + function foo() pure public returns(bool) { + return true; + } +} + +// File: @dependency/contracts/Dependency.sol" as Depend2 + +struct DependencyStruct { + string name; + uint value; +} + +contract Dependency { + function foo() pure public returns(bool) { + return true; + } +} +// File: @noncompilingdependency/CompilingContract.sol + +contract BrownieStyleDependency { + function foo() pure public returns(bool) { + return true; + } +} +// File: @noncompilingdependency/subdir/SubCompilingContract.sol + +contract SubCompilingContract { + function foo() pure public returns(bool) { + return true; + } +} +// File: @safe/contracts/common/Enum.sol + +/// @title Enum - Collection of enums +/// @author Richard Meissner - +contract Enum { + enum Operation {Call, DelegateCall} +} +// File: { MyStruct } from "contracts/CompilesOnce.sol + +struct MyStruct { + string name; + uint value; +} + +contract CompilesOnce { + // This contract tests the scenario when we have a contract with + // a similar compiler version to more than one other contract's. + // This ensures we don't compile the same contract more than once. + + function foo() pure public returns(bool) { + return true; + } +} +// File: ./././././././././././././././././././././././././././././././././././MissingPragma.sol + +contract MissingPragma { + function foo() pure public returns(bool) { + return true; + } +} +// File: { Struct0, Struct1, Struct2, Struct3, Struct4, Struct5 } from "./NumerousDefinitions.sol + +struct Struct0 { + string name; + uint value; +} + +struct Struct1 { + string name; + uint value; +} + +struct Struct2 { + string name; + uint value; +} + +struct Struct3 { + string name; + uint value; +} + +struct Struct4 { + string name; + uint value; +} + +struct Struct5 { + string name; + uint value; +} + +contract NumerousDefinitions { + function foo() pure public returns(bool) { + return true; + } +} +// File: ./Source.extra.ext.sol + +// Showing sources with extra extensions are by default excluded, +// unless used as an import somewhere in a non-excluded source. +contract SourceExtraExt { + function foo() pure public returns(bool) { + return true; + } +} +// File: ./subfolder/Relativecontract.sol + +contract Relativecontract { + + function foo() pure public returns(bool) { + return true; + } +} + +// File: Imports.sol + +// Purposely repeat an import to test how the plugin handles that. + +// Purposely exclude the contracts folder to test older Ape-style project imports. + +// Showing sources with extra extensions are by default excluded, +// unless used as an import somewhere in a non-excluded source. + +contract Imports { + function foo() pure public returns(bool) { + return true; + } +} +""".strip() + def test_cli_flatten(project, cli_runner): path = project.contracts_folder / "Imports.sol" @@ -14,9 +160,8 @@ def test_cli_flatten(project, cli_runner): arguments.extend([str(file), *end]) result = cli_runner.invoke(cli, arguments, catch_exceptions=False) assert result.exit_code == 0, result.stderr_bytes - output = file.read_text(encoding="utf8") - breakpoint() - x = "" + output = file.read_text(encoding="utf8").strip() + assert output == EXPECTED_FLATTENED_CONTRACT def test_compile():