Skip to content

Commit

Permalink
Merge pull request #15 from canonical/pydantic-v2
Browse files Browse the repository at this point in the history
fixed endpoint determination from input_state
  • Loading branch information
PietroPasotti authored Feb 9, 2024
2 parents 50cd086 + 8f61287 commit 6de05b2
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 95 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,5 @@ You can customize name and location of the fixture, but you will need to include
```
## Upgrading from v1
As `pytest-interface-tester` v2 is using `pydantic` v2 that introduces breaking changes to their API, you might need to adjust your tested charm to also support v2. See [migration guide](https://docs.pydantic.dev/latest/migration/) for more information.
`pytest-interface-tester` supports both pydantic v1 and v2, but using v2 is recommended.
You might need to adjust your tested charm to also support v2. See [migration guide](https://docs.pydantic.dev/latest/migration/) for more information.
7 changes: 7 additions & 0 deletions interface_tester/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def load_schema_module(schema_path: Path) -> types.ModuleType:
if module_name in sys.modules:
del sys.modules[module_name]

if pydantic.version.VERSION.split(".") <= ["2"]:
# in pydantic v1 it's necessary; in v2 it isn't.

# Otherwise we'll get an error when we re-run @validator
logger.debug("Clearing pydantic.class_validators._FUNCS")
pydantic.class_validators._FUNCS.clear() # noqa

try:
module = importlib.import_module(module_name)
except ImportError:
Expand Down
156 changes: 94 additions & 62 deletions interface_tester/interface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from enum import Enum
from typing import Any, Callable, List, Literal, Optional, Union

import pydantic
from ops.testing import CharmType
from pydantic import ValidationError
from scenario import Context, Event, Relation, State
from scenario import Context, Event, Relation, State, state
from scenario.state import _EventPath

from interface_tester.errors import InvalidTestCaseError, SchemaValidationError

Expand All @@ -28,6 +30,20 @@

logger = logging.getLogger(__name__)

_has_pydantic_v1 = pydantic.version.VERSION.split(".") <= ["2"]
if _has_pydantic_v1:
logger.warning(
"You seem to be using pydantic v1. "
"Please upgrade to v2, as compatibility may be dropped in a future version of "
"pytest-interface-tester."
)


def _validate(model: pydantic.BaseModel, obj: dict):
if _has_pydantic_v1:
return model.validate(obj)
return model.model_validate(obj)


class InvalidTestCase(RuntimeError):
"""Raised if a function decorated with interface_test_case is invalid."""
Expand Down Expand Up @@ -283,14 +299,15 @@ def assert_schema_valid(self, schema: Optional["DataBagSchema"] = None):
errors = []
for relation in self._relations:
try:
databag_schema.model_validate(
_validate(
databag_schema,
{
"unit": relation.local_unit_data,
"app": relation.local_app_data,
}
},
)
except ValidationError as e:
errors.append(e.errors()[0])
errors.append(str(e))
if errors:
raise SchemaValidationError(errors)

Expand All @@ -306,9 +323,16 @@ def assert_relation_data_empty(self):
raise SchemaValidationError(
f"test {self._test_id}: local app databag not empty for {relation}"
)
if relation.local_unit_data:

# remove the default unit databag keys or we'll get false positives.
local_unit_data_keys = set(relation.local_unit_data).difference(
set(state.DEFAULT_JUJU_DATABAG.keys())
)

if local_unit_data_keys:
raise SchemaValidationError(
f"test {self._test_id}: local unit databag not empty for {relation}"
f"test {self._test_id}: local unit databag not empty for {relation}: "
f"found {local_unit_data_keys!r} keys set"
)
self._has_checked_schema = True

Expand Down Expand Up @@ -365,8 +389,7 @@ def _run(self, event: Union[str, Event]):

# the Relation instance this test is about:
relation = next(filter(lambda r: r.interface == self.ctx.interface_name, relations))
# test.EVENT might be a string or an Event. Cast to Event.
evt: Event = self._coerce_event(event, relation)
evt: Event = self._cast_event(event, relation)

logger.info("collected test for %s with %s" % (self.ctx.interface_name, evt.name))
return self._run_scenario(evt, modified_state)
Expand All @@ -387,45 +410,50 @@ def _run_scenario(self, event: Event, state: State):
)
return ctx.run(event, state)

def _coerce_event(self, raw_event: Union[str, Event], relation: Relation) -> Event:
# if the event being tested is a relation event, we need to inject some metadata
# or scenario.Runtime won't be able to guess what envvars need setting before ops.main
# takes over
if isinstance(raw_event, str):
ep_name, _, evt_kind = raw_event.rpartition("-relation-")
if ep_name and evt_kind:
# this is a relation event.
# we inject the relation metadata
# todo: if the user passes a relation event that is NOT about the relation
# interface that this test is about, at this point we are injecting the wrong
# Relation instance.
# e.g. if in interfaces/foo one wants to test that if 'bar-relation-joined' is
# fired... then one would have to pass an Event instance already with its
# own Relation.
return Event(
raw_event,
relation=relation.replace(endpoint=ep_name),
)

else:
return Event(raw_event)

elif isinstance(raw_event, Event):
if raw_event._is_relation_event and not raw_event.relation:
raise InvalidTestCaseError(
"This test case was passed an Event representing a relation event."
"However it does not have a Relation. Please pass it to the Event like so: "
"evt = Event('my_relation_changed', relation=Relation(...))"
)

return raw_event
def _cast_event(self, raw_event: Union[str, Event], relation: Relation):
# test.EVENT might be a string or an Event. Cast to Event.
event = Event(raw_event) if isinstance(raw_event, str) else raw_event

else:
if not isinstance(event, Event):
raise InvalidTestCaseError(
f"Expected Event or str, not {type(raw_event)}. "
f"Invalid test case: {self} cannot cast {raw_event} to Event."
)

if not event._is_relation_event:
raise InvalidTestCaseError(
f"Bad interface test specification: event {raw_event} " "is not a relation event."
)

# todo: if the user passes a relation event that is NOT about the relation
# interface that this test is about, at this point we are injecting the wrong
# Relation instance.
# e.g. if in interfaces/foo one wants to test that if 'bar-relation-joined' is
# fired... then one would have to pass an Event instance already with its
# own Relation.

# next we need to ensure that the event's .relation is our relation, and that the endpoint
# in the relation and the event path match that of the charm we're testing.
charm_event = event.replace(
relation=relation, path=relation.endpoint + typing.cast(_EventPath, event.path).suffix
)

return charm_event

@staticmethod
def _get_endpoint(supported_endpoints: dict, role: Role, interface_name: str):
endpoints_for_interface = supported_endpoints[role]

if len(endpoints_for_interface) < 1:
raise ValueError(f"no endpoint found for {role}/{interface_name}.")
elif len(endpoints_for_interface) > 1:
raise ValueError(
f"Multiple endpoints found for {role}/{interface_name}: "
f"{endpoints_for_interface}: cannot guess which one it is "
f"we're supposed to be testing"
)
return endpoints_for_interface[0]

def _generate_relations_state(
self, state_template: State, input_state: State, supported_endpoints, role: Role
) -> List[Relation]:
Expand All @@ -437,6 +465,9 @@ def _generate_relations_state(
"""
interface_name = self.ctx.interface_name

# determine what charm endpoint we're testing.
endpoint = self._get_endpoint(supported_endpoints, role, interface_name=interface_name)

for rel in state_template.relations:
if rel.interface == interface_name:
logger.warning(
Expand All @@ -448,18 +479,31 @@ def _generate_relations_state(
def filter_relations(rels: List[Relation], op: Callable):
return [r for r in rels if op(r.interface, interface_name)]

# the baseline is: all relations whose interface IS NOT the interface we're testing.
# the baseline is: all relations provided by the charm in the state_template,
# whose interface IS NOT the interface we're testing. We assume the test (input_state) is
# the ultimate owner of the state when it comes to the interface we're testing.
# We don't allow the charm to mess with it.
relations = filter_relations(state_template.relations, op=operator.ne)

if input_state:
# if the charm we're testing specified some relations in its input state, we add those
# whose interface IS the same as the one we're testing. If other relation interfaces
# were specified, they will be ignored.
relations.extend(filter_relations(input_state.relations, op=operator.eq))

if ignored := filter_relations(input_state.relations, op=operator.eq):
# if the interface test we're running specified some relations in its input_state,
# we add those whose interface IS the same as the one we're testing.
# If other relation interfaces were specified (for whatever reason?),
# they will be ignored.
relations_from_input_state = filter_relations(input_state.relations, op=operator.eq)

# relations that come from the state_template presumably have the right endpoint,
# but those that we get from interface tests cannot.
relations_with_endpoint = [
r.replace(endpoint=endpoint) for r in relations_from_input_state
]

relations.extend(relations_with_endpoint)

if ignored := filter_relations(input_state.relations, op=operator.ne):
# this is a sign of a bad test.
logger.warning(
"irrelevant relations specified in input state for %s/%s."
"irrelevant relations specified in input_state for %s/%s."
"These will be ignored. details: %s" % (interface_name, role, ignored)
)

Expand All @@ -468,19 +512,6 @@ def filter_relations(rels: List[Relation], op: Callable):
if not filter_relations(relations, op=operator.eq):
# if neither the charm nor the interface specified any custom relation spec for
# the interface we're testing, we will provide one.
endpoints_for_interface = supported_endpoints[role]

if len(endpoints_for_interface) < 1:
raise ValueError(f"no endpoint found for {role}/{interface_name}.")
elif len(endpoints_for_interface) > 1:
raise ValueError(
f"Multiple endpoints found for {role}/{interface_name}: "
f"{endpoints_for_interface}: cannot guess which one it is "
f"we're supposed to be testing"
)
else:
endpoint = endpoints_for_interface[0]

relations.append(
Relation(
interface=interface_name,
Expand All @@ -491,4 +522,5 @@ def filter_relations(rels: List[Relation], op: Callable):
"%s: merged %s and %s --> relations=%s"
% (self, input_state, state_template, relations)
)

return relations
6 changes: 4 additions & 2 deletions interface_tester/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type

from ops.testing import CharmType
from scenario.state import Event, MetadataNotFoundError, State, _CharmSpec
from scenario.state import MetadataNotFoundError, State, _CharmSpec

from interface_tester.collector import InterfaceTestSpec, gather_test_spec_for_version
from interface_tester.errors import (
Expand All @@ -22,7 +22,6 @@
)
from interface_tester.schema_base import DataBagSchema

Callback = Callable[[State, Event], None]
ROLE_TO_ROLE_META = {"provider": "provides", "requirer": "requires"}

logger = logging.getLogger("pytest_interface_tester")
Expand Down Expand Up @@ -328,8 +327,11 @@ def run(self) -> bool:
with tester_context(ctx):
test_fn()
except Exception as e:
logger.exception(f"Interface tester plugin failed with {e}")

if self._RAISE_IMMEDIATELY:
raise e

errors.append((ctx, e))
ran_some = True

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "pytest-interface-tester"

version = "2.0.0"
version = "2.0.1"
authors = [
{ name = "Pietro Pasotti", email = "pietro.pasotti@canonical.com" },
]
Expand All @@ -20,7 +20,7 @@ license.text = "Apache-2.0"
keywords = ["juju", "relation interfaces"]

dependencies = [
"pydantic>=2",
"pydantic>= 1.10.7",
"typer==0.7.0",
"ops-scenario>=5.2",
"pytest"
Expand Down
14 changes: 12 additions & 2 deletions tests/unit/test_collect_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
get_schema_from_module,
load_schema_module,
)
from interface_tester.interface_test import _has_pydantic_v1


def test_load_schema_module(tmp_path):
Expand Down Expand Up @@ -70,8 +71,17 @@ class RequirerSchema(DataBagSchema):
)

tests = collect_tests(root)
assert tests["mytestinterfacea"]["v0"]["requirer"]["schema"].model_fields["foo"].default == 1
assert tests["mytestinterfaceb"]["v0"]["requirer"]["schema"].model_fields["foo"].default == 2
if _has_pydantic_v1:
assert tests["mytestinterfacea"]["v0"]["requirer"]["schema"].__fields__["foo"].default == 1
assert tests["mytestinterfaceb"]["v0"]["requirer"]["schema"].__fields__["foo"].default == 2

else:
assert (
tests["mytestinterfacea"]["v0"]["requirer"]["schema"].model_fields["foo"].default == 1
)
assert (
tests["mytestinterfaceb"]["v0"]["requirer"]["schema"].model_fields["foo"].default == 2
)


def test_collect_invalid_schemas(tmp_path):
Expand Down
Loading

0 comments on commit 6de05b2

Please sign in to comment.