diff --git a/README.md b/README.md index 7374b7e..5bdc9d4 100644 --- a/README.md +++ b/README.md @@ -84,4 +84,5 @@ You can customize name and location of the fixture, but you will need to include ``` ## Upgrading from v1 -As `pytest-interface-tester` v2 is using `pydantic` v2 that introduces breaking changes to their API, you might need to adjust your tested charm to also support v2. See [migration guide](https://docs.pydantic.dev/latest/migration/) for more information. \ No newline at end of file +`pytest-interface-tester` supports both pydantic v1 and v2, but using v2 is recommended. +You might need to adjust your tested charm to also support v2. See [migration guide](https://docs.pydantic.dev/latest/migration/) for more information. \ No newline at end of file diff --git a/interface_tester/collector.py b/interface_tester/collector.py index 8f62a99..d297455 100644 --- a/interface_tester/collector.py +++ b/interface_tester/collector.py @@ -105,6 +105,13 @@ def load_schema_module(schema_path: Path) -> types.ModuleType: if module_name in sys.modules: del sys.modules[module_name] + if pydantic.version.VERSION.split(".") <= ["2"]: + # in pydantic v1 it's necessary; in v2 it isn't. + + # Otherwise we'll get an error when we re-run @validator + logger.debug("Clearing pydantic.class_validators._FUNCS") + pydantic.class_validators._FUNCS.clear() # noqa + try: module = importlib.import_module(module_name) except ImportError: diff --git a/interface_tester/interface_test.py b/interface_tester/interface_test.py index 4def3de..daa008c 100644 --- a/interface_tester/interface_test.py +++ b/interface_tester/interface_test.py @@ -10,9 +10,11 @@ from enum import Enum from typing import Any, Callable, List, Literal, Optional, Union +import pydantic from ops.testing import CharmType from pydantic import ValidationError -from scenario import Context, Event, Relation, State +from scenario import Context, Event, Relation, State, state +from scenario.state import _EventPath from interface_tester.errors import InvalidTestCaseError, SchemaValidationError @@ -28,6 +30,20 @@ logger = logging.getLogger(__name__) +_has_pydantic_v1 = pydantic.version.VERSION.split(".") <= ["2"] +if _has_pydantic_v1: + logger.warning( + "You seem to be using pydantic v1. " + "Please upgrade to v2, as compatibility may be dropped in a future version of " + "pytest-interface-tester." + ) + + +def _validate(model: pydantic.BaseModel, obj: dict): + if _has_pydantic_v1: + return model.validate(obj) + return model.model_validate(obj) + class InvalidTestCase(RuntimeError): """Raised if a function decorated with interface_test_case is invalid.""" @@ -283,14 +299,15 @@ def assert_schema_valid(self, schema: Optional["DataBagSchema"] = None): errors = [] for relation in self._relations: try: - databag_schema.model_validate( + _validate( + databag_schema, { "unit": relation.local_unit_data, "app": relation.local_app_data, - } + }, ) except ValidationError as e: - errors.append(e.errors()[0]) + errors.append(str(e)) if errors: raise SchemaValidationError(errors) @@ -306,9 +323,16 @@ def assert_relation_data_empty(self): raise SchemaValidationError( f"test {self._test_id}: local app databag not empty for {relation}" ) - if relation.local_unit_data: + + # remove the default unit databag keys or we'll get false positives. + local_unit_data_keys = set(relation.local_unit_data).difference( + set(state.DEFAULT_JUJU_DATABAG.keys()) + ) + + if local_unit_data_keys: raise SchemaValidationError( - f"test {self._test_id}: local unit databag not empty for {relation}" + f"test {self._test_id}: local unit databag not empty for {relation}: " + f"found {local_unit_data_keys!r} keys set" ) self._has_checked_schema = True @@ -365,8 +389,7 @@ def _run(self, event: Union[str, Event]): # the Relation instance this test is about: relation = next(filter(lambda r: r.interface == self.ctx.interface_name, relations)) - # test.EVENT might be a string or an Event. Cast to Event. - evt: Event = self._coerce_event(event, relation) + evt: Event = self._cast_event(event, relation) logger.info("collected test for %s with %s" % (self.ctx.interface_name, evt.name)) return self._run_scenario(evt, modified_state) @@ -387,45 +410,50 @@ def _run_scenario(self, event: Event, state: State): ) return ctx.run(event, state) - def _coerce_event(self, raw_event: Union[str, Event], relation: Relation) -> Event: - # if the event being tested is a relation event, we need to inject some metadata - # or scenario.Runtime won't be able to guess what envvars need setting before ops.main - # takes over - if isinstance(raw_event, str): - ep_name, _, evt_kind = raw_event.rpartition("-relation-") - if ep_name and evt_kind: - # this is a relation event. - # we inject the relation metadata - # todo: if the user passes a relation event that is NOT about the relation - # interface that this test is about, at this point we are injecting the wrong - # Relation instance. - # e.g. if in interfaces/foo one wants to test that if 'bar-relation-joined' is - # fired... then one would have to pass an Event instance already with its - # own Relation. - return Event( - raw_event, - relation=relation.replace(endpoint=ep_name), - ) - - else: - return Event(raw_event) - - elif isinstance(raw_event, Event): - if raw_event._is_relation_event and not raw_event.relation: - raise InvalidTestCaseError( - "This test case was passed an Event representing a relation event." - "However it does not have a Relation. Please pass it to the Event like so: " - "evt = Event('my_relation_changed', relation=Relation(...))" - ) - - return raw_event + def _cast_event(self, raw_event: Union[str, Event], relation: Relation): + # test.EVENT might be a string or an Event. Cast to Event. + event = Event(raw_event) if isinstance(raw_event, str) else raw_event - else: + if not isinstance(event, Event): raise InvalidTestCaseError( f"Expected Event or str, not {type(raw_event)}. " f"Invalid test case: {self} cannot cast {raw_event} to Event." ) + if not event._is_relation_event: + raise InvalidTestCaseError( + f"Bad interface test specification: event {raw_event} " "is not a relation event." + ) + + # todo: if the user passes a relation event that is NOT about the relation + # interface that this test is about, at this point we are injecting the wrong + # Relation instance. + # e.g. if in interfaces/foo one wants to test that if 'bar-relation-joined' is + # fired... then one would have to pass an Event instance already with its + # own Relation. + + # next we need to ensure that the event's .relation is our relation, and that the endpoint + # in the relation and the event path match that of the charm we're testing. + charm_event = event.replace( + relation=relation, path=relation.endpoint + typing.cast(_EventPath, event.path).suffix + ) + + return charm_event + + @staticmethod + def _get_endpoint(supported_endpoints: dict, role: Role, interface_name: str): + endpoints_for_interface = supported_endpoints[role] + + if len(endpoints_for_interface) < 1: + raise ValueError(f"no endpoint found for {role}/{interface_name}.") + elif len(endpoints_for_interface) > 1: + raise ValueError( + f"Multiple endpoints found for {role}/{interface_name}: " + f"{endpoints_for_interface}: cannot guess which one it is " + f"we're supposed to be testing" + ) + return endpoints_for_interface[0] + def _generate_relations_state( self, state_template: State, input_state: State, supported_endpoints, role: Role ) -> List[Relation]: @@ -437,6 +465,9 @@ def _generate_relations_state( """ interface_name = self.ctx.interface_name + # determine what charm endpoint we're testing. + endpoint = self._get_endpoint(supported_endpoints, role, interface_name=interface_name) + for rel in state_template.relations: if rel.interface == interface_name: logger.warning( @@ -448,18 +479,31 @@ def _generate_relations_state( def filter_relations(rels: List[Relation], op: Callable): return [r for r in rels if op(r.interface, interface_name)] - # the baseline is: all relations whose interface IS NOT the interface we're testing. + # the baseline is: all relations provided by the charm in the state_template, + # whose interface IS NOT the interface we're testing. We assume the test (input_state) is + # the ultimate owner of the state when it comes to the interface we're testing. + # We don't allow the charm to mess with it. relations = filter_relations(state_template.relations, op=operator.ne) if input_state: - # if the charm we're testing specified some relations in its input state, we add those - # whose interface IS the same as the one we're testing. If other relation interfaces - # were specified, they will be ignored. - relations.extend(filter_relations(input_state.relations, op=operator.eq)) - - if ignored := filter_relations(input_state.relations, op=operator.eq): + # if the interface test we're running specified some relations in its input_state, + # we add those whose interface IS the same as the one we're testing. + # If other relation interfaces were specified (for whatever reason?), + # they will be ignored. + relations_from_input_state = filter_relations(input_state.relations, op=operator.eq) + + # relations that come from the state_template presumably have the right endpoint, + # but those that we get from interface tests cannot. + relations_with_endpoint = [ + r.replace(endpoint=endpoint) for r in relations_from_input_state + ] + + relations.extend(relations_with_endpoint) + + if ignored := filter_relations(input_state.relations, op=operator.ne): + # this is a sign of a bad test. logger.warning( - "irrelevant relations specified in input state for %s/%s." + "irrelevant relations specified in input_state for %s/%s." "These will be ignored. details: %s" % (interface_name, role, ignored) ) @@ -468,19 +512,6 @@ def filter_relations(rels: List[Relation], op: Callable): if not filter_relations(relations, op=operator.eq): # if neither the charm nor the interface specified any custom relation spec for # the interface we're testing, we will provide one. - endpoints_for_interface = supported_endpoints[role] - - if len(endpoints_for_interface) < 1: - raise ValueError(f"no endpoint found for {role}/{interface_name}.") - elif len(endpoints_for_interface) > 1: - raise ValueError( - f"Multiple endpoints found for {role}/{interface_name}: " - f"{endpoints_for_interface}: cannot guess which one it is " - f"we're supposed to be testing" - ) - else: - endpoint = endpoints_for_interface[0] - relations.append( Relation( interface=interface_name, @@ -491,4 +522,5 @@ def filter_relations(rels: List[Relation], op: Callable): "%s: merged %s and %s --> relations=%s" % (self, input_state, state_template, relations) ) + return relations diff --git a/interface_tester/plugin.py b/interface_tester/plugin.py index 4dbb598..26331cd 100644 --- a/interface_tester/plugin.py +++ b/interface_tester/plugin.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type from ops.testing import CharmType -from scenario.state import Event, MetadataNotFoundError, State, _CharmSpec +from scenario.state import MetadataNotFoundError, State, _CharmSpec from interface_tester.collector import InterfaceTestSpec, gather_test_spec_for_version from interface_tester.errors import ( @@ -22,7 +22,6 @@ ) from interface_tester.schema_base import DataBagSchema -Callback = Callable[[State, Event], None] ROLE_TO_ROLE_META = {"provider": "provides", "requirer": "requires"} logger = logging.getLogger("pytest_interface_tester") @@ -328,8 +327,11 @@ def run(self) -> bool: with tester_context(ctx): test_fn() except Exception as e: + logger.exception(f"Interface tester plugin failed with {e}") + if self._RAISE_IMMEDIATELY: raise e + errors.append((ctx, e)) ran_some = True diff --git a/pyproject.toml b/pyproject.toml index f5ad5e3..d06c1c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ build-backend = "setuptools.build_meta" [project] name = "pytest-interface-tester" -version = "2.0.0" +version = "2.0.1" authors = [ { name = "Pietro Pasotti", email = "pietro.pasotti@canonical.com" }, ] @@ -20,7 +20,7 @@ license.text = "Apache-2.0" keywords = ["juju", "relation interfaces"] dependencies = [ - "pydantic>=2", + "pydantic>= 1.10.7", "typer==0.7.0", "ops-scenario>=5.2", "pytest" diff --git a/tests/unit/test_collect_schemas.py b/tests/unit/test_collect_schemas.py index 9500021..7357da4 100644 --- a/tests/unit/test_collect_schemas.py +++ b/tests/unit/test_collect_schemas.py @@ -8,6 +8,7 @@ get_schema_from_module, load_schema_module, ) +from interface_tester.interface_test import _has_pydantic_v1 def test_load_schema_module(tmp_path): @@ -70,8 +71,17 @@ class RequirerSchema(DataBagSchema): ) tests = collect_tests(root) - assert tests["mytestinterfacea"]["v0"]["requirer"]["schema"].model_fields["foo"].default == 1 - assert tests["mytestinterfaceb"]["v0"]["requirer"]["schema"].model_fields["foo"].default == 2 + if _has_pydantic_v1: + assert tests["mytestinterfacea"]["v0"]["requirer"]["schema"].__fields__["foo"].default == 1 + assert tests["mytestinterfaceb"]["v0"]["requirer"]["schema"].__fields__["foo"].default == 2 + + else: + assert ( + tests["mytestinterfacea"]["v0"]["requirer"]["schema"].model_fields["foo"].default == 1 + ) + assert ( + tests["mytestinterfaceb"]["v0"]["requirer"]["schema"].model_fields["foo"].default == 2 + ) def test_collect_invalid_schemas(tmp_path): diff --git a/tests/unit/test_e2e.py b/tests/unit/test_e2e.py index 5a021ed..e6f9f12 100644 --- a/tests/unit/test_e2e.py +++ b/tests/unit/test_e2e.py @@ -9,7 +9,7 @@ from interface_tester import InterfaceTester from interface_tester.collector import gather_test_spec_for_version -from interface_tester.errors import SchemaValidationError +from interface_tester.errors import InvalidTestCaseError, SchemaValidationError from interface_tester.interface_test import ( InvalidTesterRunError, NoSchemaError, @@ -40,8 +40,9 @@ def interface_tester(): charm_type=DummiCharm, meta={ "name": "dummi", - "provides": {"tracing": {"interface": "tracing"}}, - "requires": {"tracing-req": {"interface": "tracing"}}, + # interface tests should be agnostic to endpoint names + "provides": {"dead": {"interface": "tracing"}}, + "requires": {"beef-req": {"interface": "tracing"}}, }, state_template=State(leader=True), ) @@ -87,8 +88,9 @@ def _collect_interface_test_specs(self): charm_type=DummiCharm, meta={ "name": "dummi", - "provides": {"tracing": {"interface": "tracing"}}, - "requires": {"tracing-req": {"interface": "tracing"}}, + # interface tests should be agnostic to endpoint names + "provides": {"dead": {"interface": "tracing"}}, + "requires": {"beef-req": {"interface": "tracing"}}, }, state_template=State(leader=True), ) @@ -107,7 +109,7 @@ def test_error_if_skip_schema_before_run(): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='foobadooble', # should not matter interface='tracing', remote_app_name='remote', local_app_data={} @@ -122,6 +124,37 @@ def test_data_on_changed(): tester.run() +def test_error_if_not_relation_event(): + tester = _setup_with_test_file( + dedent( + """ +from scenario import State, Relation + +from interface_tester.interface_test import Tester + +def test_data_on_changed(): + t = Tester(State( + relations=[Relation( + endpoint='foobadooble', # should not matter + interface='tracing', + remote_app_name='remote', + local_app_data={} + )] + )) + t.run("foobadooble-changed") + t.skip_schema_validation() +""" + ) + ) + + with pytest.raises(InvalidTestCaseError) as e: + tester.run() + + assert e.match( + "Bad interface test specification: event foobadooble-changed is not a relation event." + ) + + def test_error_if_assert_relation_data_empty_before_run(): tester = _setup_with_test_file( dedent( @@ -133,7 +166,7 @@ def test_error_if_assert_relation_data_empty_before_run(): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='foobadooble', # should not matter interface='tracing', remote_app_name='remote', local_app_data={} @@ -160,7 +193,7 @@ def test_error_if_assert_schema_valid_before_run(): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='foobadooble', # should not matter interface='tracing', remote_app_name='remote', local_app_data={} @@ -186,13 +219,13 @@ def test_error_if_assert_schema_without_schema(): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='foobadooble', # should not matter interface='tracing', remote_app_name='remote', local_app_data={} )] )) - state_out = t.run("tracing-relation-changed") + state_out = t.run("axolotl-relation-changed") t.assert_schema_valid() """ ) @@ -213,13 +246,13 @@ def test_error_if_return_before_schema_call(): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='foobadooble', # should not matter interface='tracing', remote_app_name='remote', local_app_data={} )] )) - state_out = t.run("tracing-relation-changed") + state_out = t.run("axolotl-relation-changed") """ ) ) @@ -239,7 +272,7 @@ def test_error_if_return_without_run(): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='foobadooble', # should not matter interface='tracing', remote_app_name='remote', local_app_data={} @@ -273,10 +306,14 @@ def test_data_on_changed(): tester.run() -def test_valid_run(): +@pytest.mark.parametrize( + "endpoint", ("foo-one", "prometheus-scrape", "foobadoodle", "foo-one-two") +) +@pytest.mark.parametrize("evt_type", ("changed", "created", "joined", "departed", "broken")) +def test_valid_run(endpoint, evt_type): tester = _setup_with_test_file( dedent( - """ + f""" from scenario import State, Relation from interface_tester.interface_test import Tester @@ -285,13 +322,13 @@ def test_valid_run(): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='{endpoint}', # should not matter interface='tracing', remote_app_name='remote', - local_app_data={} + local_app_data={{}} )] )) - state_out = t.run("tracing-relation-changed") + state_out = t.run("{endpoint}-relation-{evt_type}") t.assert_schema_valid(schema=DataBagSchema()) """ ) @@ -312,14 +349,14 @@ def test_valid_run_default_schema(): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='foobadooble', # should not matter interface='tracing', remote_app_name='remote', local_app_data={"foo":"1"}, local_unit_data={"bar": "smackbeef"} )] )) - state_out = t.run("tracing-relation-changed") + state_out = t.run("axolotl-relation-changed") t.assert_schema_valid() """ ), @@ -355,14 +392,14 @@ def test_default_schema_validation_failure(): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='foobadooble', # should not matter interface='tracing', remote_app_name='remote', local_app_data={"foo":"abc"}, local_unit_data={"bar": "smackbeef"} )] )) - state_out = t.run("tracing-relation-changed") + state_out = t.run("axolotl-relation-changed") t.assert_schema_valid() """ ), @@ -408,14 +445,14 @@ class FooBarSchema(DataBagSchema): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='foobadooble', # should not matter interface='tracing', remote_app_name='remote', local_app_data={"foo":"1"}, local_unit_data={"bar": "smackbeef"} )] )) - state_out = t.run("tracing-relation-changed") + state_out = t.run("axolotl-relation-changed") t.assert_schema_valid(schema=FooBarSchema) """ ) @@ -445,14 +482,14 @@ class FooBarSchema(DataBagSchema): def test_data_on_changed(): t = Tester(State( relations=[Relation( - endpoint='tracing', + endpoint='foobadooble', # should not matter interface='tracing', remote_app_name='remote', local_app_data={"foo":"abc"}, local_unit_data={"bar": "smackbeef"} )] )) - state_out = t.run("tracing-relation-changed") + state_out = t.run("axolotl-relation-changed") t.assert_schema_valid(schema=FooBarSchema) """ )