diff --git a/wake/cli/test.py b/wake/cli/test.py
index ba23c662..ad3a5b32 100644
--- a/wake/cli/test.py
+++ b/wake/cli/test.py
@@ -240,6 +240,14 @@ def shell_complete(
     count=True,
     help="Increase verbosity. Can be specified multiple times.",
 )
+@click.option(
+    "--shrink",
+    type=click.Path(exists=True, dir_okay=False, readable=True),  # must point to an existing file
+    help="Path to the crash log file to replay and shrink.",
+    is_flag=False,
+    flag_value=-1,
+    default=None,
+)
 @click.argument("paths_or_pytest_args", nargs=-1, type=FileAndPassParamType())
 @click.pass_context
@@ -255,6 +263,7 @@ def run_test(
     attach_first: bool,
     dist: str,
     verbosity: int,
+    shrink: Optional[str],
     paths_or_pytest_args: Tuple[str, ...],
 ) -> None:
     """Execute Wake tests using pytest."""
@@ -361,6 +370,48 @@ def run_test(
         )
     else:
         from wake.testing.pytest_plugin_single import PytestWakePluginSingle
+        from wake.development.globals import set_fuzz_mode, set_sequence_initial_internal_state, set_error_flow_num
+
+        def extract_executed_flow_number(crash_log_file_path):
+            if crash_log_file_path is not None:
+                with open(crash_log_file_path, 'r') as file:
+                    for line in file:
+                        if "executed flow number" in line:
+                            # Extract the number after the colon
+                            parts = line.split(":")
+                            if len(parts) == 2:
+                                try:
+                                    executed_flow_number = int(parts[1].strip())
+                                    return executed_flow_number
+                                except ValueError:
+                                    pass  # the value after ":" is not an integer
+            return None
+
+        def extract_internal_state(crash_log_file_path):
+            if crash_log_file_path is not None:
+                with open(crash_log_file_path, 'r') as file:
+                    for line in file:
+                        if "Internal state of beginning of sequence" in line:
+                            # Extract the part after the colon
+                            parts = line.split(":")
+                            if len(parts) == 2:
+                                hex_string = parts[1].strip()
+                                try:
+                                    # Convert the hex string back to bytes
+                                    internal_state_bytes = bytes.fromhex(hex_string)
+                                    return internal_state_bytes
+                                except ValueError:
+                                    pass  # the value after ":" is not a valid hex string
+            return None
+
+        if shrink is not None:
+            number = extract_executed_flow_number(shrink)
+            assert number is not None, "Unexpected file format"
+            set_fuzz_mode(1)
+            set_error_flow_num(number)
+            beginning_random_state_bytes = extract_internal_state(shrink)
+            assert beginning_random_state_bytes is not None, "Unexpected file format"
+            set_sequence_initial_internal_state(beginning_random_state_bytes)
 
         sys.exit(
             pytest.main(
diff --git a/wake/development/globals.py b/wake/development/globals.py
index fb71320e..bfa0fad7 100644
--- a/wake/development/globals.py
+++ b/wake/development/globals.py
@@ -57,6 +57,11 @@
 _config: Optional[WakeConfig] = None
 _verbosity: int = 0
 
+_fuzz_mode: int = 0
+
+_error_flow_num: int = 0
+
+
 
 def attach_debugger(
     e_type: Optional[Type[BaseException]],
@@ -106,6 +111,20 @@ def attach_debugger(
         p.reset()
         p.interaction(None, tb)
 
+def get_fuzz_mode() -> int:
+    return _fuzz_mode
+
+def set_fuzz_mode(fuzz_mode: int):
+    global _fuzz_mode
+    _fuzz_mode = fuzz_mode
+
+def get_error_flow_num() -> int:
+    return _error_flow_num
+
+def set_error_flow_num(error_flow_num: int):
+    global _error_flow_num
+    _error_flow_num = error_flow_num
+
 
 def get_exception_handler() -> Optional[
     Callable[
@@ -132,11 +151,11 @@ def set_exception_handler(
 ):
     global _exception_handler
     _exception_handler = handler
-    
+
 def set_sequence_initial_internal_state(intenral_state: bytes):
     global _initial_internal_state
     _initial_internal_state = intenral_state
-    
+
 def get_sequence_initial_internal_state() -> bytes:
     return _initial_internal_state
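The two parsers added to wake/cli/test.py only rely on the two tagged lines that PytestWakePluginSingle appends to the crash log (see the last diff below): the pickled random state recorded at the start of the failing sequence and the index of the flow that raised. A minimal, self-contained sketch of that round trip, assuming a log of that shape (the log text here is fabricated for illustration and is not Wake output):

# Illustration only: the crash-log content is made up, but the two tagged lines
# have the same shape as the ones written by pytest_exception_interact.
import pickle
import random

state_hex = pickle.dumps(random.getstate()).hex()
crash_log = (
    "... traceback rendered by rich ...\n"
    f"Internal state of beginning of sequence : {state_hex}\n"
    "executed flow number : 7\n"
)

error_flow_num = None
for line in crash_log.splitlines():
    if "Internal state of beginning of sequence" in line:
        # hex -> bytes -> unpickled tuple accepted by random.setstate()
        random.setstate(pickle.loads(bytes.fromhex(line.split(":", 1)[1].strip())))
    elif "executed flow number" in line:
        error_flow_num = int(line.split(":", 1)[1].strip())

assert error_flow_num == 7

Passing these two values into the globals (set_error_flow_num, set_sequence_initial_internal_state) together with set_fuzz_mode(1) is what switches the fuzzer from normal fuzzing into the replay-and-shrink path shown next.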
diff --git a/wake/testing/fuzzing/fuzz_test.py b/wake/testing/fuzzing/fuzz_test.py
index 642d9cc0..cc751c6b 100644
--- a/wake/testing/fuzzing/fuzz_test.py
+++ b/wake/testing/fuzzing/fuzz_test.py
@@ -5,7 +5,7 @@
 
 from typing_extensions import get_type_hints
 
-from wake.development.globals import random, set_sequence_initial_internal_state
+from wake.development.globals import random, set_sequence_initial_internal_state, get_fuzz_mode, get_sequence_initial_internal_state, set_error_flow_num, get_error_flow_num
 from ..core import get_connected_chains
 from .generators import generate
 
@@ -71,85 +71,317 @@ def run(
         flows: List[Callable] = self.__get_methods("flow")
         invariants: List[Callable] = self.__get_methods("invariant")
 
+        fuzz_mode = get_fuzz_mode()
+        if fuzz_mode == 0:
+            for i in range(sequences_count):
+                flows_counter: DefaultDict[Callable, int] = defaultdict(int)
+                invariant_periods: DefaultDict[Callable[[None], None], int] = defaultdict(
+                    int
+                )
+
+                snapshots = [chain.snapshot() for chain in chains]
+
+                set_sequence_initial_internal_state(
+                    pickle.dumps(
+                        random.getstate()
+                    )
+                )
+
+                self._flow_num = 0
+                self._sequence_num = i
+                self.pre_sequence()
+
+                for j in range(flows_count):
+                    valid_flows = [
+                        f
+                        for f in flows
+                        if (
+                            not hasattr(f, "max_times")
+                            or flows_counter[f] < getattr(f, "max_times")
+                        )
+                        and (
+                            not hasattr(f, "precondition")
+                            or getattr(f, "precondition")(self)
+                        )
+                    ]
+                    weights = [getattr(f, "weight") for f in valid_flows]
+                    if len(valid_flows) == 0:
+                        max_times_flows = [
+                            f
+                            for f in flows
+                            if hasattr(f, "max_times")
+                            and flows_counter[f] >= getattr(f, "max_times")
+                        ]
+                        precondition_flows = [
+                            f
+                            for f in flows
+                            if hasattr(f, "precondition")
+                            and not getattr(f, "precondition")(self)
+                        ]
+                        raise Exception(
+                            f"Could not find a valid flow to run.\nFlows that have reached their max_times: {max_times_flows}\nFlows that do not satisfy their precondition: {precondition_flows}"
+                        )
+                    flow = random.choices(valid_flows, weights=weights)[0]
+                    flow_params = [
+                        generate(v)
+                        for k, v in get_type_hints(flow, include_extras=True).items()
+                        if k != "return"
+                    ]
+
+                    self._flow_num = j
+                    set_error_flow_num(j)
+                    self.pre_flow(flow)
+                    flow(self, *flow_params)
+                    flows_counter[flow] += 1
+                    self.post_flow(flow)
+
+                    if not dry_run:
+                        self.pre_invariants()
+                        for inv in invariants:
+                            if invariant_periods[inv] == 0:
+                                self.pre_invariant(inv)
+                                inv(self)
+                                self.post_invariant(inv)
+
+                            invariant_periods[inv] += 1
+                            if invariant_periods[inv] == getattr(inv, "period"):
+                                invariant_periods[inv] = 0
+                        self.post_invariants()
+
+                self.post_sequence()
+
+                for snapshot, chain in zip(snapshots, chains):
+                    chain.revert(snapshot)
+
+        elif fuzz_mode == 1:
+
+            error_flow_num = get_error_flow_num()
+            random_stored_states: List[bytes] = []
-        for i in range(sequences_count):
             flows_counter: DefaultDict[Callable, int] = defaultdict(int)
             invariant_periods: DefaultDict[Callable[[None], None], int] = defaultdict(
                 int
             )
 
             snapshots = [chain.snapshot() for chain in chains]
-
-            set_sequence_initial_internal_state(
-                pickle.dumps(
-                    random.getstate()
-                )
-            )
-
+
+            state = get_sequence_initial_internal_state()
+            random.setstate(pickle.loads(state))
+
             self._flow_num = 0
-            self._sequence_num = i
+            self._sequence_num = 0
             self.pre_sequence()
 
-            for j in range(flows_count):
-                valid_flows = [
-                    f
-                    for f in flows
-                    if (
-                        not hasattr(f, "max_times")
-                        or flows_counter[f] < getattr(f, "max_times")
-                    )
-                    and (
-                        not hasattr(f, "precondition")
-                        or getattr(f, "precondition")(self)
-                    )
-                ]
"weight") for f in valid_flows] - if len(valid_flows) == 0: - max_times_flows = [ + exception = False + try: + for j in range(flows_count): + valid_flows = [ f for f in flows - if hasattr(f, "max_times") - and flows_counter[f] >= getattr(f, "max_times") + if ( + not hasattr(f, "max_times") + or flows_counter[f] < getattr(f, "max_times") + ) + and ( + not hasattr(f, "precondition") + or getattr(f, "precondition")(self) + ) ] - precondition_flows = [ - f - for f in flows - if hasattr(f, "precondition") - and not getattr(f, "precondition")(self) + weights = [getattr(f, "weight") for f in valid_flows] + if len(valid_flows) == 0: + max_times_flows = [ + f + for f in flows + if hasattr(f, "max_times") + and flows_counter[f] >= getattr(f, "max_times") + ] + precondition_flows = [ + f + for f in flows + if hasattr(f, "precondition") + and not getattr(f, "precondition")(self) + ] + raise Exception( + f"Could not find a valid flow to run.\nFlows that have reached their max_times: {max_times_flows}\nFlows that do not satisfy their precondition: {precondition_flows}" + ) + random_stored_states.append(pickle.dumps(random.getstate())) + flow = random.choices(valid_flows, weights=weights)[0] + flow_params = [ + generate(v) + for k, v in get_type_hints(flow, include_extras=True).items() + if k != "return" ] - raise Exception( - f"Could not find a valid flow to run.\nFlows that have reached their max_times: {max_times_flows}\nFlows that do not satisfy their precondition: {precondition_flows}" + + self._flow_num = j + self.pre_flow(flow) + flow(self, *flow_params) + flows_counter[flow] += 1 + self.post_flow(flow) + + if not dry_run: + self.pre_invariants() + for inv in invariants: + if invariant_periods[inv] == 0: + self.pre_invariant(inv) + inv(self) + self.post_invariant(inv) + + invariant_periods[inv] += 1 + if invariant_periods[inv] == getattr(inv, "period"): + invariant_periods[inv] = 0 + self.post_invariants() + self.post_sequence() + + for snapshot, chain in zip(snapshots, chains): + chain.revert(snapshot) + except Exception: + exception = True + + for snapshot, chain in zip(snapshots, chains): + chain.revert(snapshot) + + assert self._flow_num == error_flow_num, "Unexpected failing flow" + if exception == False: + raise Exception("Exception not raised unexpected state changes") + + + + print("Random state corrected: ", error_flow_num) + print("Starting shrinking") + + run_flows: List[bool] = [True] * (error_flow_num+1) + + curr = 0 # current testing flow index + + class OverRunException(Exception): + def __init__(self): + super().__init__("Overrun") + + while curr <= error_flow_num: + run_flows[curr] = False + + flows_counter: DefaultDict[Callable, int] = defaultdict(int) + invariant_periods: DefaultDict[Callable[[None], None], int] = defaultdict( + int + ) + snapshots = [chain.snapshot() for chain in chains] + + set_sequence_initial_internal_state( + pickle.dumps( + random.getstate() ) - flow = random.choices(valid_flows, weights=weights)[0] - flow_params = [ - generate(v) - for k, v in get_type_hints(flow, include_extras=True).items() - if k != "return" - ] - - self._flow_num = j - self.pre_flow(flow) - flow(self, *flow_params) - flows_counter[flow] += 1 - self.post_flow(flow) - - if not dry_run: - self.pre_invariants() - for inv in invariants: - if invariant_periods[inv] == 0: - self.pre_invariant(inv) - inv(self) - self.post_invariant(inv) - - invariant_periods[inv] += 1 - if invariant_periods[inv] == getattr(inv, "period"): - invariant_periods[inv] = 0 - self.post_invariants() - - 
-            self.post_sequence()
-
-            for snapshot, chain in zip(snapshots, chains):
-                chain.revert(snapshot)
+                )
+                self._flow_num = 0
+                self._sequence_num = 0
+                self.pre_sequence()
+                exception = False
+                try:
+
+                    for j in range(flows_count):
+                        if self._flow_num > error_flow_num:
+                            raise OverRunException()
+
+                        valid_flows = [
+                            f
+                            for f in flows
+                            if (
+                                not hasattr(f, "max_times")
+                                or flows_counter[f] < getattr(f, "max_times")
+                            )
+                            and (
+                                not hasattr(f, "precondition")
+                                or getattr(f, "precondition")(self)
+                            )
+                        ]
+                        weights = [getattr(f, "weight") for f in valid_flows]
+                        if len(valid_flows) == 0:
+                            max_times_flows = [
+                                f
+                                for f in flows
+                                if hasattr(f, "max_times")
+                                and flows_counter[f] >= getattr(f, "max_times")
+                            ]
+                            precondition_flows = [
+                                f
+                                for f in flows
+                                if hasattr(f, "precondition")
+                                and not getattr(f, "precondition")(self)
+                            ]
+                            raise Exception(
+                                f"Could not find a valid flow to run.\nFlows that have reached their max_times: {max_times_flows}\nFlows that do not satisfy their precondition: {precondition_flows}"
+                            )
+
+                        random.setstate(pickle.loads(random_stored_states[j]))
+                        flow = random.choices(valid_flows, weights=weights)[0]
+                        flow_params = [
+                            generate(v)
+                            for k, v in get_type_hints(flow, include_extras=True).items()
+                            if k != "return"
+                        ]
+                        # print(j)
+                        if run_flows[j]:
+                            self._flow_num = j
+                            self.pre_flow(flow)
+                            flow(self, *flow_params)
+                            flows_counter[flow] += 1
+                            self.post_flow(flow)
+
+                            if not dry_run:
+                                self.pre_invariants()
+                                for inv in invariants:
+                                    if invariant_periods[inv] == 0:
+                                        self.pre_invariant(inv)
+                                        inv(self)
+                                        self.post_invariant(inv)
+
+                                    invariant_periods[inv] += 1
+                                    if invariant_periods[inv] == getattr(inv, "period"):
+                                        invariant_periods[inv] = 0
+                                self.post_invariants()
+                    self.post_sequence()
+                except OverRunException:
+                    exception = False  # not a test failure, only the end of the replayed flows
+                except Exception:
+                    exception = True
+                    for snapshot, chain in zip(snapshots, chains):
+                        chain.revert(snapshot)
+
+                if self._flow_num == error_flow_num:
+                    # The removed flow is not required to reproduce the error; try removing the next one.
+                    print("Flow removal succeeded: ", curr)
+                    assert run_flows[curr] == False
+                    if curr == error_flow_num:
+                        # The final flow is required, since it is the one that raised the error.
+                        run_flows[curr] = True
+                        # run_flows[curr] = False
+                        pass
+                else:
+                    # Removing this flow caused a different error; restore it and try removing the next one.
+                    run_flows[curr] = True
+                    pass
+
+                if exception == False:
+                    for snapshot, chain in zip(snapshots, chains):
+                        chain.revert(snapshot)
+
+                    run_flows[curr] = True
+                    # The removed flow is required to reproduce the error; restore it and try removing the next one.
+
+                    print("Flow restored: ", run_flows[curr], curr)
+                curr += 1
+
+                print("Shrinking flow: ", curr)
+                print("total flows: ", error_flow_num)
+
+            print("Shrinking completed")
+            print(run_flows)
+            print(len(run_flows))
+            print(sum(run_flows))
+
+        else:
+            raise Exception("Invalid fuzz mode")
 
     def pre_sequence(self) -> None:
         pass
diff --git a/wake/testing/pytest_plugin_single.py b/wake/testing/pytest_plugin_single.py
index b40423b0..3b9e65e8 100644
--- a/wake/testing/pytest_plugin_single.py
+++ b/wake/testing/pytest_plugin_single.py
@@ -17,7 +17,8 @@
     reset_exception_handled,
     set_coverage_handler,
     set_exception_handler,
-    get_sequence_initial_internal_state
+    get_sequence_initial_internal_state,
+    get_error_flow_num,
 )
 from wake.testing.coverage import (
     CoverageHandler,
@@ -75,8 +76,9 @@ def pytest_exception_interact(self, node, call, report):
                 )
                 file_console = Console(file=f, force_terminal=False)
                 file_console.print(rich_tb)
-                f.write(f"\nInternal state of beginning of sequence : \n{state.hex()}")
-
+                f.write(f"\nInternal state of beginning of sequence : {state.hex()}\n")
+                f.write(f"executed flow number : {get_error_flow_num()}\n")
+
     def pytest_runtestloop(self, session: Session):
         if (
             session.testsfailed
@@ -92,7 +94,7 @@ def pytest_runtestloop(self, session: Session):
 
         coverage = self._cov_proc_count == 1 or self._cov_proc_count == -1
 
-        
+
         if len(self._random_states) > 0:
             assert self._random_states[0] is not None
             random.setstate(pickle.loads(self._random_states[0]))
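A note on why random_stored_states is captured right before every flow selection during the replay pass: restoring that exact PRNG state before re-selecting flow j makes both the weighted choice and the generated parameters deterministic, so flow j stays identical during shrinking even when earlier flows are skipped. A small standalone sketch of the idea, assuming nothing beyond the standard library (the names below are illustrative, not Wake APIs):

# Standalone sketch: per-step PRNG snapshots make each step replayable in isolation.
import pickle
import random

random.seed(1234)

saved_states = []   # plays the role of random_stored_states
original = []
for j in range(5):
    saved_states.append(pickle.dumps(random.getstate()))
    # stand-ins for "pick a flow" and "generate its parameters"
    original.append((random.choice("abcde"), random.random()))

# Replaying step 3 alone reproduces the same choice and the same parameters.
random.setstate(pickle.loads(saved_states[3]))
assert (random.choice("abcde"), random.random()) == original[3]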
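The shrinking loop itself is a greedy, one-flow-at-a-time reduction: disable flow curr, replay the recorded prefix, and keep the removal only if the failure still occurs at the recorded error_flow_num; otherwise restore the flow and move on. The sketch below is a simplified model of that strategy, not the patch itself: it abstracts chain snapshot/revert and flow execution behind a hypothetical replay callback that returns the index of the failing flow (or None when nothing fails).

from typing import Callable, List, Optional

# Hypothetical helper type: re-runs the sequence, skipping flows whose entry in
# `active` is False, and returns the index of the flow that raised (or None).
Replay = Callable[[List[bool]], Optional[int]]

def shrink(error_flow_num: int, replay: Replay) -> List[bool]:
    active = [True] * (error_flow_num + 1)
    for curr in range(error_flow_num):  # the failing flow itself is always kept
        active[curr] = False            # tentatively drop this flow
        if replay(active) != error_flow_num:
            active[curr] = True         # different outcome -> the flow is needed
    return active

Like the patch, this keeps a flow whenever its removal changes which flow fails; checking the failing flow index is a cheap proxy for "same error" that avoids comparing exception types or messages.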