Skip to content

Commit

Permalink
Merge pull request #219 from networktocode/typing-overhaul
Browse files Browse the repository at this point in the history
Type hinting overhaul.
  • Loading branch information
Kircheneer authored Aug 18, 2023
2 parents 6a0949e + 4fa39b5 commit 2c1577f
Show file tree
Hide file tree
Showing 12 changed files with 302 additions and 228 deletions.
189 changes: 98 additions & 91 deletions diffsync/__init__.py

Large diffs are not rendered by default.

83 changes: 42 additions & 41 deletions diffsync/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,44 @@
"""

from functools import total_ordering
from typing import Any, Iterator, Iterable, Mapping, Optional, Text, Type
from typing import Any, Iterator, Optional, Type, List, Dict, Iterable

from .exceptions import ObjectAlreadyExists
from .utils import intersection, OrderedDefaultDict
from .enum import DiffSyncActions

# This workaround is used because we are defining a method called `str` in our class definition, which therefore renders
# the builtin `str` type unusable.
StrType = str


class Diff:
"""Diff Object, designed to store multiple DiffElement object and organize them in a group."""

def __init__(self):
def __init__(self) -> None:
"""Initialize a new, empty Diff object."""
self.children = OrderedDefaultDict(dict)
self.children = OrderedDefaultDict[StrType, Dict[StrType, DiffElement]](dict)
"""DefaultDict for storing DiffElement objects.
`self.children[group][unique_id] == DiffElement(...)`
"""
self.models_processed = 0

def __len__(self):
def __len__(self) -> int:
"""Total number of DiffElements stored herein."""
total = 0
for child in self.get_children():
total += len(child)
return total

def complete(self):
def complete(self) -> None:
"""Method to call when this Diff has been fully populated with data and is "complete".
The default implementation does nothing, but a subclass could use this, for example, to save
the completed Diff to a file or database record.
"""

def add(self, element: "DiffElement"):
def add(self, element: "DiffElement") -> None:
"""Add a new DiffElement to the changeset of this Diff.
Raises:
Expand All @@ -61,15 +65,15 @@ def add(self, element: "DiffElement"):

self.children[element.type][element.name] = element

def groups(self):
def groups(self) -> List[StrType]:
"""Get the list of all group keys in self.children."""
return self.children.keys()
return list(self.children.keys())

def has_diffs(self) -> bool:
"""Indicate if at least one of the child elements contains some diff.
Returns:
bool: True if at least one child element contains some diff
True if at least one child element contains some diff
"""
for group in self.groups():
for child in self.children[group].values():
Expand All @@ -96,15 +100,15 @@ def get_children(self) -> Iterator["DiffElement"]:
yield from order_method(self.children[group])

@classmethod
def order_children_default(cls, children: Mapping) -> Iterator["DiffElement"]:
def order_children_default(cls, children: Dict[StrType, "DiffElement"]) -> Iterator["DiffElement"]:
"""Default method to an Iterator for children.
Since children is already an OrderedDefaultDict, this method is not doing anything special.
"""
for child in children.values():
yield child

def summary(self) -> Mapping[Text, int]:
def summary(self) -> Dict[StrType, int]:
"""Build a dict summary of this Diff and its child DiffElements."""
summary = {
DiffSyncActions.CREATE: 0,
Expand All @@ -127,7 +131,7 @@ def summary(self) -> Mapping[Text, int]:
)
return summary

def str(self, indent: int = 0):
def str(self, indent: int = 0) -> StrType:
"""Build a detailed string representation of this Diff and its child DiffElements."""
margin = " " * indent
output = []
Expand All @@ -144,9 +148,9 @@ def str(self, indent: int = 0):
result = "(no diffs)"
return result

def dict(self) -> Mapping[Text, Mapping[Text, Mapping]]:
def dict(self) -> Dict[StrType, Dict[StrType, Dict]]:
"""Build a dictionary representation of this Diff."""
result = OrderedDefaultDict(dict)
result = OrderedDefaultDict[str, Dict](dict)
for child in self.get_children():
if child.has_diffs(include_children=True):
result[child.type][child.name] = child.dict()
Expand All @@ -159,11 +163,11 @@ class DiffElement: # pylint: disable=too-many-instance-attributes

def __init__(
self,
obj_type: Text,
name: Text,
keys: Mapping,
source_name: Text = "source",
dest_name: Text = "dest",
obj_type: StrType,
name: StrType,
keys: Dict,
source_name: StrType = "source",
dest_name: StrType = "dest",
diff_class: Type[Diff] = Diff,
): # pylint: disable=too-many-arguments
"""Instantiate a DiffElement.
Expand All @@ -177,10 +181,10 @@ def __init__(
dest_name: Name of the destination DiffSync object
diff_class: Diff or subclass thereof to use to calculate the diffs to use for synchronization
"""
if not isinstance(obj_type, str):
if not isinstance(obj_type, StrType):
raise ValueError(f"obj_type must be a string (not {type(obj_type)})")

if not isinstance(name, str):
if not isinstance(name, StrType):
raise ValueError(f"name must be a string (not {type(name)})")

self.type = obj_type
Expand All @@ -189,18 +193,18 @@ def __init__(
self.source_name = source_name
self.dest_name = dest_name
# Note: *_attrs == None if no target object exists; it'll be an empty dict if it exists but has no _attributes
self.source_attrs: Optional[Mapping] = None
self.dest_attrs: Optional[Mapping] = None
self.source_attrs: Optional[Dict] = None
self.dest_attrs: Optional[Dict] = None
self.child_diff = diff_class()

def __lt__(self, other):
def __lt__(self, other: "DiffElement") -> bool:
"""Logical ordering of DiffElements.
Other comparison methods (__gt__, __le__, __ge__, etc.) are created by our use of the @total_ordering decorator.
"""
return (self.type, self.name) < (other.type, other.name)

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
"""Logical equality of DiffElements.
Other comparison methods (__gt__, __le__, __ge__, etc.) are created by our use of the @total_ordering decorator.
Expand All @@ -216,26 +220,26 @@ def __eq__(self, other):
# TODO also check that self.child_diff == other.child_diff, needs Diff to implement __eq__().
)

def __str__(self):
def __str__(self) -> StrType:
"""Basic string representation of a DiffElement."""
return (
f'{self.type} "{self.name}" : {self.keys} : '
f"{self.source_name}{self.dest_name} : {self.get_attrs_diffs()}"
)

def __len__(self):
def __len__(self) -> int:
"""Total number of DiffElements in this one, including itself."""
total = 1 # self
for child in self.get_children():
total += len(child)
return total

@property
def action(self) -> Optional[Text]:
def action(self) -> Optional[StrType]:
"""Action, if any, that should be taken to remediate the diffs described by this element.
Returns:
str: DiffSyncActions ("create", "update", "delete", or None)
"create", "update", "delete", or None)
"""
if self.source_attrs is not None and self.dest_attrs is None:
return DiffSyncActions.CREATE
Expand All @@ -251,7 +255,7 @@ def action(self) -> Optional[Text]:
return None

# TODO: separate into set_source_attrs() and set_dest_attrs() methods, or just use direct property access instead?
def add_attrs(self, source: Optional[Mapping] = None, dest: Optional[Mapping] = None):
def add_attrs(self, source: Optional[Dict] = None, dest: Optional[Dict] = None) -> None:
"""Set additional attributes of a source and/or destination item that may result in diffs."""
# TODO: should source_attrs and dest_attrs be "write-once" properties, or is it OK to overwrite them once set?
if source is not None:
Expand All @@ -260,26 +264,26 @@ def add_attrs(self, source: Optional[Mapping] = None, dest: Optional[Mapping] =
if dest is not None:
self.dest_attrs = dest

def get_attrs_keys(self) -> Iterable[Text]:
def get_attrs_keys(self) -> Iterable[StrType]:
"""Get the list of shared attrs between source and dest, or the attrs of source or dest if only one is present.
- If source_attrs is not set, return the keys of dest_attrs
- If dest_attrs is not set, return the keys of source_attrs
- If both are defined, return the intersection of both keys
"""
if self.source_attrs is not None and self.dest_attrs is not None:
return intersection(self.dest_attrs.keys(), self.source_attrs.keys())
return intersection(list(self.dest_attrs.keys()), list(self.source_attrs.keys()))
if self.source_attrs is None and self.dest_attrs is not None:
return self.dest_attrs.keys()
if self.source_attrs is not None and self.dest_attrs is None:
return self.source_attrs.keys()
return []

def get_attrs_diffs(self) -> Mapping[Text, Mapping[Text, Any]]:
def get_attrs_diffs(self) -> Dict[StrType, Dict[StrType, Any]]:
"""Get the dict of actual attribute diffs between source_attrs and dest_attrs.
Returns:
dict: of the form `{"-": {key1: <value>, key2: ...}, "+": {key1: <value>, key2: ...}}`,
Dictionary of the form `{"-": {key1: <value>, key2: ...}, "+": {key1: <value>, key2: ...}}`,
where the `"-"` or `"+"` dicts may be absent.
"""
if self.source_attrs is not None and self.dest_attrs is not None:
Expand All @@ -301,13 +305,10 @@ def get_attrs_diffs(self) -> Mapping[Text, Mapping[Text, Any]]:
return {"+": {key: self.source_attrs[key] for key in self.get_attrs_keys()}}
return {}

def add_child(self, element: "DiffElement"):
def add_child(self, element: "DiffElement") -> None:
"""Attach a child object of type DiffElement.
Childs are saved in a Diff object and are organized by type and name.
Args:
element: DiffElement
"""
self.child_diff.add(element)

Expand Down Expand Up @@ -336,7 +337,7 @@ def has_diffs(self, include_children: bool = True) -> bool:

return False

def summary(self) -> Mapping[Text, int]:
def summary(self) -> Dict[StrType, int]:
"""Build a summary of this DiffElement and its children."""
summary = {
DiffSyncActions.CREATE: 0,
Expand All @@ -353,7 +354,7 @@ def summary(self) -> Mapping[Text, int]:
summary[key] += child_summary[key]
return summary

def str(self, indent: int = 0):
def str(self, indent: int = 0) -> StrType:
"""Build a detailed string representation of this DiffElement and its children."""
margin = " " * indent
result = f"{margin}{self.type}: {self.name}"
Expand All @@ -377,7 +378,7 @@ def str(self, indent: int = 0):
result += " (no diffs)"
return result

def dict(self) -> Mapping[Text, Mapping[Text, Any]]:
def dict(self) -> Dict[StrType, Dict[StrType, Any]]:
"""Build a dictionary representation of this DiffElement and its children."""
attrs_diffs = self.get_attrs_diffs()
result = {}
Expand Down
7 changes: 6 additions & 1 deletion diffsync/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import TYPE_CHECKING, Union, Any

if TYPE_CHECKING:
from diffsync import DiffSyncModel
from diffsync.diff import DiffElement


class ObjectCrudException(Exception):
Expand All @@ -39,7 +44,7 @@ class ObjectStoreException(Exception):
class ObjectAlreadyExists(ObjectStoreException):
"""Exception raised when trying to store a DiffSyncModel or DiffElement that is already being stored."""

def __init__(self, message, existing_object, *args, **kwargs):
def __init__(self, message: str, existing_object: Union["DiffSyncModel", "DiffElement"], *args: Any, **kwargs: Any):
"""Add existing_object to the exception to provide user with existing object."""
self.existing_object = existing_object
super().__init__(message, existing_object, *args, **kwargs)
Expand Down
22 changes: 12 additions & 10 deletions diffsync/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
limitations under the License.
"""
from collections.abc import Iterable as ABCIterable, Mapping as ABCMapping
from typing import Callable, Iterable, List, Mapping, Optional, Tuple, Type, TYPE_CHECKING
from typing import Callable, List, Optional, Tuple, Type, TYPE_CHECKING, Dict, Iterable

import structlog # type: ignore

Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__( # pylint: disable=too-many-arguments
self.total_models = len(src_diffsync) + len(dst_diffsync)
self.logger.debug(f"Diff calculation between these two datasets will involve {self.total_models} models")

def incr_models_processed(self, delta: int = 1):
def incr_models_processed(self, delta: int = 1) -> None:
"""Increment self.models_processed, then call self.callback if present."""
if delta:
self.models_processed += delta
Expand Down Expand Up @@ -136,7 +136,9 @@ def diff_object_list(self, src: List["DiffSyncModel"], dst: List["DiffSyncModel"
return diff_elements

@staticmethod
def validate_objects_for_diff(object_pairs: Iterable[Tuple[Optional["DiffSyncModel"], Optional["DiffSyncModel"]]]):
def validate_objects_for_diff(
object_pairs: Iterable[Tuple[Optional["DiffSyncModel"], Optional["DiffSyncModel"]]]
) -> None:
"""Check whether all DiffSyncModels in the given dictionary are valid for comparison to one another.
Helper method for `diff_object_list`.
Expand Down Expand Up @@ -234,15 +236,15 @@ def diff_child_objects(
diff_element: DiffElement,
src_obj: Optional["DiffSyncModel"],
dst_obj: Optional["DiffSyncModel"],
):
) -> DiffElement:
"""For all children of the given DiffSyncModel pair, diff recursively, adding diffs to the given diff_element.
Helper method to `calculate_diffs`, usually doesn't need to be called directly.
These helper methods work in a recursive cycle:
diff_object_list -> diff_object_pair -> diff_child_objects -> diff_object_list -> etc.
"""
children_mapping: Mapping[str, str]
children_mapping: Dict[str, str]
if src_obj and dst_obj:
# Get the subset of child types common to both src_obj and dst_obj
src_mapping = src_obj.get_children_mapping()
Expand Down Expand Up @@ -308,7 +310,7 @@ def __init__( # pylint: disable=too-many-arguments
self.model_class: Type["DiffSyncModel"]
self.action: Optional[str] = None

def incr_elements_processed(self, delta: int = 1):
def incr_elements_processed(self, delta: int = 1) -> None:
"""Increment self.elements_processed, then call self.callback if present."""
if delta:
self.elements_processed += delta
Expand All @@ -319,7 +321,7 @@ def perform_sync(self) -> bool:
"""Perform data synchronization based on the provided diff.
Returns:
bool: True if any changes were actually performed, else False.
True if any changes were actually performed, else False.
"""
changed = False
self.base_logger.info("Beginning sync")
Expand Down Expand Up @@ -401,14 +403,14 @@ def sync_diff_element(self, element: DiffElement, parent_model: Optional["DiffSy
return changed

def sync_model( # pylint: disable=too-many-branches, unused-argument
self, src_model: Optional["DiffSyncModel"], dst_model: Optional["DiffSyncModel"], ids: Mapping, attrs: Mapping
self, src_model: Optional["DiffSyncModel"], dst_model: Optional["DiffSyncModel"], ids: Dict, attrs: Dict
) -> Tuple[bool, Optional["DiffSyncModel"]]:
"""Create/update/delete the current DiffSyncModel with current ids/attrs, and update self.status and self.message.
Helper method to `sync_diff_element`.
Returns:
tuple: (changed, model) where model may be None if an error occurred
(changed, model) where model may be None if an error occurred
"""
if self.action is None:
status = DiffSyncStatus.SUCCESS
Expand Down Expand Up @@ -451,7 +453,7 @@ def sync_model( # pylint: disable=too-many-branches, unused-argument

return (True, dst_model)

def log_sync_status(self, action: Optional[str], status: DiffSyncStatus, message: str):
def log_sync_status(self, action: Optional[str], status: DiffSyncStatus, message: str) -> None:
"""Log the current sync status at the appropriate verbosity with appropriate context.
Helper method to `sync_diff_element`/`sync_model`.
Expand Down
Loading

0 comments on commit 2c1577f

Please sign in to comment.