Skip to content

Commit

Permalink
feat: support easy context
Browse files Browse the repository at this point in the history
  • Loading branch information
FlickerSoul committed Nov 7, 2023
1 parent 26e1df9 commit 3556723
Show file tree
Hide file tree
Showing 20 changed files with 218 additions and 9 deletions.
13 changes: 11 additions & 2 deletions src/gapper/core/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ class ProblemConfig:
:param mock_input: Whether to mock the input of the solution.
:param captured_context: The context to capture from the submission.
:param is_script: Whether this problem is a script.
:param easy_context: Whether to use context directly in gap override tests.
"""

check_stdout: bool = False
mock_input: bool = False
captured_context: Iterable[str] = ()
easy_context: bool = False
is_script: bool = False


Expand Down Expand Up @@ -165,7 +167,7 @@ def from_path(cls, path: Path) -> Problem:

@overload
def problem(
*, is_script: bool = False, context: Iterable[str] = ()
*, is_script: bool = False, context: Iterable[str] = (), easy_context: bool = False
) -> Callable[
[Callable[ProbInputType, ProbOutputType]], Problem[ProbInputType, ProbOutputType]
]:
Expand All @@ -175,7 +177,11 @@ def problem(

@overload
def problem(
*, check_stdout: bool = False, mock_input: bool = False, context: Iterable[str] = ()
*,
check_stdout: bool = False,
mock_input: bool = False,
context: Iterable[str] = (),
easy_context: bool = False,
) -> Callable[
[Callable[ProbInputType, ProbOutputType]], Problem[ProbInputType, ProbOutputType]
]:
Expand All @@ -189,6 +195,7 @@ def problem(
check_stdout: Optional[bool] = None,
mock_input: Optional[bool] = None,
context: Iterable[str] = (),
easy_context: bool = False,
) -> Callable[
[Callable[ProbInputType, ProbOutputType]], Problem[ProbInputType, ProbOutputType]
]:
Expand All @@ -198,6 +205,7 @@ def problem(
:param check_stdout: Whether to check the stdout of the solution.
:param mock_input: Whether to mock the input of the solution.
:param context: The context to capture from the submission.
:param easy_context: Whether to use context directly in gap override tests.
"""

if is_script:
Expand All @@ -215,6 +223,7 @@ def problem(
mock_input=mock_input,
captured_context=context,
is_script=is_script,
easy_context=easy_context,
)

def _wrapper(
Expand Down
8 changes: 8 additions & 0 deletions src/gapper/core/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class GapReservedKeywords(Enum):
gap_max_score = "gap_max_score"
gap_extra_points = "gap_extra_points"
gap_override_check = "gap_override_check"
gap_easy_context = "gap_easy_context"
gap_override_test = "gap_override_test"
gap_post_checks = "gap_post_checks"
gap_description = "gap_description"
Expand All @@ -63,6 +64,7 @@ class ParamInfo:
gap_name: str | None = None
gap_extra_points: float | None = None
gap_override_check: CustomEqualityCheckFn | None = None
gap_easy_context: bool = False
gap_override_test: CustomTestFn | None = None
gap_post_checks: List[PostChecksFn] | PostChecksFn | None = None
gap_description: str | Iterable[str] | None = None
Expand Down Expand Up @@ -161,6 +163,7 @@ def __init__(
gap_name: str | None = None,
gap_extra_points: float | None = None,
gap_override_check: CustomEqualityCheckFn | None = None,
gap_easy_context: bool = False,
gap_override_test: CustomTestFn | None = None,
gap_post_checks: List[PostChecksFn] | PostChecksFn | None = None,
gap_description: str | Iterable[str] | None = None,
Expand All @@ -177,6 +180,7 @@ def __init__(
:param gap_name: The name of the test case.
:param gap_extra_points: The extra credit of the test case.
:param gap_override_check: The custom equality check function.
:param gap_easy_context: Whether to use context directly in gap override tests.
:param gap_override_test: The custom test function.
:param gap_post_checks: The custom post check functions.
:param gap_description: The description of the test case.
Expand All @@ -195,6 +199,7 @@ def __init__(
gap_name: str | None = None,
gap_extra_points: float | None = None,
gap_override_check: CustomEqualityCheckFn | None = None,
gap_easy_context: bool = False,
gap_override_test: CustomTestFn | None = None,
gap_post_checks: List[PostChecksFn] | PostChecksFn | None = None,
gap_description: str | Iterable[str] | None = None,
Expand All @@ -211,6 +216,7 @@ def __init__(
:param gap_name: The name of the test case.
:param gap_extra_points: The extra credit of the test case.
:param gap_override_check: The custom equality check function.
:param gap_easy_context: Whether to use context directly in gap override tests.
:param gap_override_test: The custom test function.
:param gap_post_checks: The custom post check functions.
:param gap_description: The description of the test case.
Expand Down Expand Up @@ -331,6 +337,7 @@ def __init__[
gap_override_check: CustomEqualityCheckFn
| Sequence[CustomEqualityCheckFn]
| None = None,
gap_easy_context: bool | Sequence[bool] = False,
gap_override_test: CustomTestFn | Sequence[CustomTestFn] | None = None,
gap_post_checks: List[List[PostChecksFn]]
| List[PostChecksFn]
Expand Down Expand Up @@ -361,6 +368,7 @@ def __init__[
gap_override_check: CustomEqualityCheckFn
| Sequence[CustomEqualityCheckFn]
| None = None,
gap_easy_context: bool | Sequence[bool] = False,
gap_override_test: CustomTestFn | Sequence[CustomTestFn] | None = None,
gap_post_checks: List[List[PostChecksFn]]
| List[PostChecksFn]
Expand Down
28 changes: 24 additions & 4 deletions src/gapper/core/unittest_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from gapper.core.errors import InternalError, SubmissionSyntaxError, TestFailedError
from gapper.core.pipeline_support import PipelineBase
from gapper.core.test_result import TestResult
from gapper.core.utils import CaptureStdout, generate_custom_input
from gapper.core.utils import (
CaptureStdout,
CustomTestFn,
apply_context_on_fn,
generate_custom_input,
)

if TYPE_CHECKING:
from gapper.core.problem import Problem
Expand Down Expand Up @@ -250,9 +255,18 @@ def _run_test(self, submission: Any, result: TestResult) -> TestResult:

if self.test_param.param_info.gap_override_test is not None:
self._logger.debug("Handing testing to gap_override_test")
self.test_param.param_info.gap_override_test(
self, result, self.problem.solution, submission
)
if (
self.problem.config.easy_context
or self.test_param.param_info.gap_easy_context
):
self._logger.debug("Using easy context")
self.gap_override_test_with_context(
self, result, self.problem.solution, submission
)
else:
self.test_param.param_info.gap_override_test(
self, result, self.problem.solution, submission
)
else:
if self.test_param.param_info.gap_override_check:
check_fn: CustomEqualityCheckFn = (
Expand Down Expand Up @@ -298,6 +312,12 @@ def _run_test(self, submission: Any, result: TestResult) -> TestResult:

return result

@property
def gap_override_test_with_context(self) -> CustomTestFn:
return apply_context_on_fn(
self.test_param.param_info.gap_override_test, self.context
)

def load_context(self, context: ContextManager) -> Self:
"""Load the submission context into the test case.
Expand Down
77 changes: 75 additions & 2 deletions src/gapper/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,35 @@
from __future__ import annotations

import importlib.util
import logging
from contextlib import redirect_stdout
from copy import copy
from functools import update_wrapper
from importlib.machinery import ModuleSpec
from io import StringIO
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Protocol, Self, Tuple
from types import FunctionType, ModuleType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Protocol,
Self,
Tuple,
)

if TYPE_CHECKING:
from gapper.core.test_result import TestResult
from gapper.core.unittest_wrapper import TestCaseWrapper
from gapper.gradescope.datatypes.gradescope_meta import GradescopeSubmissionMetadata


_util_logger = logging.getLogger("gapper.core.utils")


class CustomTestFn(Protocol):
"""The custom test function protocol."""

Expand Down Expand Up @@ -194,3 +210,60 @@ def _load_module_spec_and_module(
@staticmethod
def _load_symbol_from_module(md: ModuleType, symbol: str) -> Any:
return getattr(md, symbol)
def apply_context_on_fn[T: FunctionType](f: T, context: dict[str, Any]) -> T:
"""Apply a context on a function.

:param f: The function to apply context on.
:param context: The context to be applied.
"""
if isinstance(f, FunctionType):
_util_logger.debug(f"Applying context {context} on function {f}")

# update closure with context
_util_logger.debug(f"Gathering closure with context")
closure_mod: Dict[str, int] = {}
if f.__closure__ is not None:
for context_var_name in context.keys():
try:
closure_pos = f.__code__.co_freevars.index(context_var_name)
_util_logger.debug(
f"Found closure variable {context_var_name} at position {closure_pos}"
)
closure_mod[context_var_name] = closure_pos
except ValueError:
_util_logger.debug(
f'Cannot find closure variable "{context_var_name}, skipped"'
)

g = FunctionType(
f.__code__,
{
**f.__globals__,
**{
c_name: c_val
for c_name, c_val in context.items()
if c_name not in closure_mod
},
}, # copy globals and update with context
name=f.__name__,
argdefs=f.__defaults__,
closure=f.__closure__,
)
g = update_wrapper(g, f)
g.__kwdefaults__ = copy(f.__kwdefaults__)

_util_logger.debug(f"Function {f} copied")

for c_name, c_pos in closure_mod.items():
_util_logger.debug(
f"Updating closure variable {c_name} at position {c_pos}"
)
g.__closure__[c_pos].cell_contents = context[c_name]

_util_logger.debug("Closure updated")

return g
else:
raise TypeError(f"Cannot apply context on {f} because it is not a function")
1 change: 1 addition & 0 deletions tests/assets/problems/add_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ def add_numbers(a: int, b: int) -> int:
"check_stdout": False,
"mock_input": False,
"captured_context": (),
"easy_context": False,
}
1 change: 1 addition & 0 deletions tests/assets/problems/assess_post_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ def fib(n: int) -> int:
"check_stdout": False,
"mock_input": False,
"captured_context": (),
"easy_context": False,
}
1 change: 1 addition & 0 deletions tests/assets/problems/assess_post_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ def square(x: int | float) -> int | float:
"check_stdout": False,
"mock_input": False,
"captured_context": (),
"easy_context": False,
}
3 changes: 2 additions & 1 deletion tests/assets/problems/capture_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Type
from typing import Type

from gapper import problem, test_case
from gapper.core.test_result import TestResult
Expand Down Expand Up @@ -32,4 +32,5 @@ def __init__(self, gas_station: GasStation):
"check_stdout": False,
"mock_input": False,
"captured_context": ("GasStation",),
"easy_context": False,
}
35 changes: 35 additions & 0 deletions tests/assets/problems/easy_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Callable

from gapper import problem, test_case, test_cases
from gapper.core.test_result import TestResult
from gapper.core.unittest_wrapper import TestCaseWrapper


def my_adder(a, b) -> int:
return a + b % 10


adder: Callable[[int, int], int]


def custom_test(param: TestCaseWrapper, result_proxy: TestResult, solution, submission):
assert solution(*param.test_param.args, my_adder) == submission(
*param.test_param.args, adder
)
assert my_adder(*param.test_param.args) == adder(*param.test_param.args)


@test_cases.param_iter(([i, i + 1] for i in range(10)), gap_override_test=custom_test)
@test_case(1, 2, gap_override_test=custom_test)
@problem(context=["adder"], easy_context=True)
def add(a: int, b: int, the_adder) -> int:
return the_adder(a, b)


__problem_config__ = {
"is_script": False,
"check_stdout": False,
"mock_input": False,
"captured_context": ["adder"],
"easy_context": True,
}
40 changes: 40 additions & 0 deletions tests/assets/problems/easy_context_with_locals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Callable

from gapper import problem, test_case, test_cases
from gapper.core.test_result import TestResult
from gapper.core.unittest_wrapper import TestCaseWrapper
from gapper.core.utils import CustomTestFn


def factory() -> CustomTestFn:
def my_adder(a, b) -> int:
return a + b % 10

adder: Callable[[int, int], int]

def custom_test(
param: TestCaseWrapper, result_proxy: TestResult, solution, submission
):
nonlocal adder
assert solution(*param.test_param.args, my_adder) == submission(
*param.test_param.args, adder
)
assert my_adder(*param.test_param.args) == adder(*param.test_param.args)

return custom_test


@test_cases.param_iter(([i, i + 1] for i in range(10)), gap_override_test=factory())
@test_case(1, 2, gap_override_test=factory())
@problem(context=["adder"], easy_context=True)
def add(a: int, b: int, the_adder) -> int:
return the_adder(a, b)


__problem_config__ = {
"is_script": False,
"check_stdout": False,
"mock_input": False,
"captured_context": ["adder"],
"easy_context": True,
}
1 change: 1 addition & 0 deletions tests/assets/problems/gap_kwargs_auto_populate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ def mul(a: int, b: int) -> int:
"check_stdout": False,
"mock_input": False,
"captured_context": (),
"easy_context": False,
}
1 change: 1 addition & 0 deletions tests/assets/problems/get_output_and_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ def check_output_and_stdout(a: float, b: float) -> int:
"check_stdout": True,
"mock_input": False,
"captured_context": (),
"easy_context": False,
}
1 change: 1 addition & 0 deletions tests/assets/problems/input_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ def input_script() -> None:
"check_stdout": True,
"mock_input": True,
"captured_context": (),
"easy_context": False,
}
Loading

0 comments on commit 3556723

Please sign in to comment.