Skip to content

Commit

Permalink
move AWST validation step into puya core
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Sep 5, 2024
1 parent c7d43c2 commit 26e66e7
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 55 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Iterator

import attrs

from puya import log
from puya.awst import (
nodes as awst_nodes,
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from puya.awst import nodes as awst_nodes

from puyapy.awst_build.validation.arc4_copy import ARC4CopyValidator
from puyapy.awst_build.validation.base_invoker import BaseInvokerValidator
from puyapy.awst_build.validation.inner_transactions import (
from puya.awst.validation.arc4_copy import ARC4CopyValidator
from puya.awst.validation.base_invoker import BaseInvokerValidator
from puya.awst.validation.inner_transactions import (
InnerTransactionsValidator,
InnerTransactionUsedInALoopValidator,
StaleInnerTransactionsValidator,
)
from puyapy.awst_build.validation.labels import LabelsValidator
from puyapy.awst_build.validation.scratch_slots import ScratchSlotReservationValidator
from puyapy.awst_build.validation.storage import StorageTypesValidator
from puya.awst.validation.labels import LabelsValidator
from puya.awst.validation.scratch_slots import ScratchSlotReservationValidator
from puya.awst.validation.storage import StorageTypesValidator


def validate_awst(module: awst_nodes.AWST) -> None:
Expand Down
File renamed without changes.
File renamed without changes.
3 changes: 3 additions & 0 deletions src/puya/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from puya.arc32 import create_arc32_json
from puya.artifact_sorter import ArtifactCompilationSorter
from puya.awst.nodes import AWST
from puya.awst.validation.main import validate_awst
from puya.context import CompileContext
from puya.errors import CodeError, InternalError
from puya.ir.main import awst_to_ir, optimize_and_destructure_ir
Expand Down Expand Up @@ -52,6 +53,8 @@ def awst_to_teal(
*,
write: bool = True,
) -> list[CompilationArtifact]:
validate_awst(awst)
log_ctx.exit_if_errors()
context = CompileContext(
options=options,
compilation_set=compilation_set,
Expand Down
2 changes: 0 additions & 2 deletions src/puyapy/awst_build/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
get_decorators_by_fullname,
get_unaliased_fullname,
)
from puyapy.awst_build.validation.main import validate_awst

logger = log.get_logger(__name__)

Expand Down Expand Up @@ -67,7 +66,6 @@ def convert(self) -> AWST:
for deferred in deferrals:
awst_node = deferred(self.context)
awst.append(awst_node)
validate_awst(awst) # TODO: move/split this to/with puya core
return awst

# Supported Statements
Expand Down
88 changes: 42 additions & 46 deletions tests/test_expected_output/data.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import contextlib
import difflib
import tempfile
import typing as t
from collections.abc import Iterator, Sequence
from pathlib import Path

import _pytest._code.code
import attrs
import pytest
from puya.awst.nodes import AWST
from puya.awst.to_code_visitor import ToCodeVisitor
from puya.compile import awst_to_teal
from puya.errors import PuyaError, log_exceptions
Expand All @@ -19,12 +20,6 @@

from tests.utils import narrowed_compile_context

if t.TYPE_CHECKING:
from collections.abc import Sequence

import _pytest._code.code
from puya.awst.nodes import AWST

THIS_DIR = Path(__file__).parent
REPO_DIR = THIS_DIR.parent.parent
CASE_COMMENT = "##"
Expand Down Expand Up @@ -300,49 +295,52 @@ def compile_and_update_cases(cases: list[TestCase]) -> None:
awst, compilation_set = transform_ast(parse_result)
# lower each case further if possible and process
for case in cases:
if case_has_awst_errors(awst_log_ctx.logs, case):
case_logs = []
else:
# lower awst for each case individually to order to get any output
# from lower layers
# this needs a new logging context so AWST errors from other cases
# are not seen
case_options = attrs.evolve(
puyapy_options, cli_template_definitions=case.template_vars
)
case_sources_by_path, case_compilation_set = narrowed_compile_context(
parse_result,
case_path[case],
awst,
compilation_set,
case_awst = [
n
for n in awst
if n.source_location.file == case_path[case]
# hacky way to keep "framework" sources included, good enough for now
# the real solution here is to remove mypy, so we don't need to do this special
# combine+split of sources to achieve decent mypy parsing speed
or n.source_location.line < 0
]
# lower awst for each case individually to order to get any output
# from lower layers
# this needs a new logging context so AWST errors from other cases
# are not seen
case_options = attrs.evolve(
puyapy_options, cli_template_definitions=case.template_vars
)
case_sources_by_path, case_compilation_set = narrowed_compile_context(
parse_result,
case_path[case],
awst,
compilation_set,
case_options,
)
with (
contextlib.suppress(SystemExit),
logging_context() as case_log_ctx,
log_exceptions(),
):
case_log_ctx.logs.extend(filter_logs(awst_log_ctx.logs, case))
awst_to_teal(
case_log_ctx,
case_options,
case_compilation_set,
case_sources_by_path,
case_awst,
write=False,
)
with (
contextlib.suppress(SystemExit),
logging_context() as case_log_ctx,
log_exceptions(),
):
awst_to_teal(
case_log_ctx,
case_options,
case_compilation_set,
case_sources_by_path,
awst,
write=False,
)
case_logs = case_log_ctx.logs
process_test_case(case, awst_log_ctx.logs + case_logs, awst)
process_test_case(case, case_log_ctx.logs, case_awst)


def case_has_awst_errors(captured_logs: list[Log], case: TestCase) -> bool:
def filter_logs(captured_logs: list[Log], case: TestCase) -> Iterator[Log]:
for file in case.files:
path = file.src_path
assert path is not None
abs_path = path.resolve()
path_records = [record for record in captured_logs if record.file == abs_path]
if any(r.level == LogLevel.error and r.line is not None for r in path_records):
return True
return False
yield from (record for record in captured_logs if record.file == abs_path)


def get_python_file_name(name: str) -> str:
Expand All @@ -369,13 +367,11 @@ def process_test_case(case: TestCase, captured_logs: Sequence[Log], awst: AWST)
for file in case.files:
path = file.src_path
assert path is not None
abs_path = path.resolve()
expected_output = {
(line, message)
for line, messages in file.expected_output.items()
for message in messages
}
path_records = [record for record in captured_logs if record.file == abs_path]
seen_output = {
(
record.line,
Expand All @@ -384,7 +380,7 @@ def process_test_case(case: TestCase, captured_logs: Sequence[Log], awst: AWST)
output=record.message.strip(),
),
)
for record in path_records
for record in captured_logs
if record.line is not None and record.level >= MIN_LEVEL_TO_REPORT
}
file_missing_output = expected_output - seen_output
Expand Down

0 comments on commit 26e66e7

Please sign in to comment.