Skip to content

Commit

Permalink
fix: bug with flattening
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Jun 13, 2024
1 parent 0e04df7 commit fc46688
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 120 deletions.
84 changes: 57 additions & 27 deletions ape_solidity/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,15 @@
# Define a regex pattern that matches import statements
# Both single and multi-line imports will be matched
IMPORTS_PATTERN = re.compile(
r"import\s+((.*?)(?=;)|[\s\S]*?from\s+(.*?)(?=;));\s", flags=re.MULTILINE
r"import\s+(([\s\S]*?)(?=;)|[\s\S]*?from\s+([^\s;]+));\s*", flags=re.MULTILINE
)
LICENSES_PATTERN = re.compile(r"(// SPDX-License-Identifier:\s*([^\n]*)\s)")

# Comment patterns
SINGLE_LINE_COMMENT_PATTERN = re.compile(r"^\s*//")
MULTI_LINE_COMMENT_START_PATTERN = re.compile(r"/\*")
MULTI_LINE_COMMENT_END_PATTERN = re.compile(r"\*/")

VERSION_PRAGMA_PATTERN = re.compile(r"pragma solidity[^;]*;")
DEFAULT_OPTIMIZATION_RUNS = 200

Expand Down Expand Up @@ -142,7 +148,7 @@ class SolidityConfig(PluginConfig):
def _get_flattened_source(path: Path, name: Optional[str] = None) -> str:
name = name or path.name
result = f"// File: {name}\n"
result += path.read_text() + "\n"
result += f"{path.read_text().rstrip()}\n"
return result


Expand Down Expand Up @@ -471,19 +477,6 @@ def get_standard_input_json(
version_map, remapping, project=pm, import_map=import_map, **overrides
)

def get_standard_input_json_from(
self,
version_map: dict[Version, set[Path]],
import_remappings: dict[str, str],
project: Optional[ProjectManager] = None,
**overrides,
):
pm = project or self.local_project
settings = self._get_settings_from_version_map(
version_map, import_remappings, project=pm, **overrides
)
return self.get_standard_input_json_from_settings(settings, version_map, project=pm)

def get_standard_input_json_from_version_map(
self,
version_map: dict[Version, set[Path]],
Expand Down Expand Up @@ -1112,14 +1105,17 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError:

def _flatten_source(
self,
path: Path,
path: Union[Path, str],
project: Optional[ProjectManager] = None,
raw_import_name: Optional[str] = None,
handled: Optional[set[str]] = None,
) -> str:
pm = project or self.local_project
handled = handled or set()
source_id = f"{get_relative_path(path, pm.path)}"

path = Path(path)
source_id = f"{get_relative_path(path, pm.path)}" if path.is_absolute() else f"{path}"

handled.add(source_id)
remapping = self.get_import_remapping(project=project)
imports = self._get_imports((path,), remapping, pm, tracked=set(), include_raw=True)
Expand All @@ -1132,26 +1128,36 @@ def _flatten_source(
continue

sub_import_name = import_str.replace("import ", "").strip(" \n\t;\"'")
final_source += self._flatten_source(
sub_source = self._flatten_source(
pm.path / source_id,
project=pm,
raw_import_name=sub_import_name,
handled=handled,
)
final_source += sub_source

flattened_src = _get_flattened_source(path, name=raw_import_name)
if flattened_src and final_source.rstrip():
final_source = f"{final_source.rstrip()}\n\n{flattened_src}"
elif flattened_src:
final_source = flattened_src

final_source += _get_flattened_source(path, name=raw_import_name)
return final_source

def flatten_contract(
self, path: Path, project: Optional[ProjectManager] = None, **kwargs
) -> Content:
# try compiling in order to validate it works
res = self._flatten_source(path, project=project)
res = remove_imports(res)
res = process_licenses(res)
res = remove_version_pragmas(res)
pragma = get_first_version_pragma(path.read_text())
res = "\n".join([pragma, res])

# Simple auto-format.
while "\n\n\n" in res:
res = res.replace("\n\n\n", "\n\n")

lines = res.splitlines()
line_dict = {i + 1: line for i, line in enumerate(lines)}
return Content(root=line_dict)
Expand Down Expand Up @@ -1260,11 +1266,37 @@ def _import_str_to_source_id(
return f"{get_relative_path(path.absolute(), pm.path)}"


def remove_imports(flattened_contract: str) -> str:
# Use regex.sub() to remove matched import statements
no_imports_contract = IMPORTS_PATTERN.sub("", flattened_contract)
def remove_imports(source_code: str) -> str:
in_multi_line_comment = False
result_lines = []

lines = source_code.splitlines()
for line in lines:
# Check if we're entering a multi-line comment
if MULTI_LINE_COMMENT_START_PATTERN.search(line):
in_multi_line_comment = True

# If inside a multi-line comment, just add the line to the result
if in_multi_line_comment:
result_lines.append(line)
# Check if this line ends the multi-line comment
if MULTI_LINE_COMMENT_END_PATTERN.search(line):
in_multi_line_comment = False
continue

# Skip single-line comments
if SINGLE_LINE_COMMENT_PATTERN.match(line):
result_lines.append(line)
continue

# Skip import statements in non-comment lines
if IMPORTS_PATTERN.search(line):
continue

# Add the line to the result if it's not an import statement
result_lines.append(line)

return no_imports_contract
return "\n".join(result_lines)


def remove_version_pragmas(flattened_contract: str) -> str:
Expand Down Expand Up @@ -1301,9 +1333,7 @@ def process_licenses(contract: str) -> str:
license_line, root_license = extracted_licenses[-1]

# Get the unique license identifiers. All licenses in a contract _should_ be the same.
unique_license_identifiers = {
license_identifier for _, license_identifier in extracted_licenses
}
unique_license_identifiers = {lid for _, lid in extracted_licenses}

# If we have more than one unique license identifier, warn the user and use the root.
if len(unique_license_identifiers) > 1:
Expand Down
4 changes: 0 additions & 4 deletions tests/data/ImportingLessConstrainedVersionFlat.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ pragma solidity =0.8.12;

// File: ./SpecificVersionRange.sol



contract SpecificVersionRange {
function foo() pure public returns(bool) {
return true;
Expand All @@ -13,8 +11,6 @@ contract SpecificVersionRange {

// File: ImportingLessConstrainedVersion.sol



// The file we are importing specific range '>=0.8.12 <0.8.15';
// This means on its own, the plugin would use 0.8.14 if its installed.
// However - it should use 0.8.12 because of this file's requirements.
Expand Down
Loading

0 comments on commit fc46688

Please sign in to comment.