From 3ed4f32ca0984181e5d1f4656032c10b98e5eaaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?= Date: Tue, 16 Jul 2024 18:48:03 -0600 Subject: [PATCH] refactor: Implement abstract `serialize_message` for Singer writers --- singer_sdk/_singerlib/_encoding/base.py | 194 +--------------------- singer_sdk/_singerlib/_encoding/simple.py | 189 ++++++++++++++++++++- singer_sdk/_singerlib/messages.py | 6 +- 3 files changed, 195 insertions(+), 194 deletions(-) diff --git a/singer_sdk/_singerlib/_encoding/base.py b/singer_sdk/_singerlib/_encoding/base.py index 798a4f6dd..b3f85c7ec 100644 --- a/singer_sdk/_singerlib/_encoding/base.py +++ b/singer_sdk/_singerlib/_encoding/base.py @@ -8,8 +8,6 @@ import sys import typing as t from collections import Counter, defaultdict -from dataclasses import asdict, dataclass, field -from datetime import datetime, timezone from singer_sdk._singerlib import exceptions @@ -24,6 +22,7 @@ # TODO: Use to default to 'str' here # https://peps.python.org/pep-0696/ T = t.TypeVar("T", str, bytes) +M = t.TypeVar("M") class SingerMessageType(str, enum.Enum): @@ -36,186 +35,7 @@ class SingerMessageType(str, enum.Enum): BATCH = "BATCH" -def exclude_null_dict(pairs: list[tuple[str, t.Any]]) -> dict[str, t.Any]: - """Exclude null values from a dictionary. - - Args: - pairs: The dictionary key-value pairs. - - Returns: - The filtered key-value pairs. - """ - return {key: value for key, value in pairs if value is not None} - - -@dataclass -class Message: - """Singer base message.""" - - type: SingerMessageType = field(init=False) - """The message type.""" - - def to_dict(self) -> dict[str, t.Any]: - """Return a dictionary representation of the message. - - Returns: - A dictionary with the defined message fields. - """ - return asdict(self, dict_factory=exclude_null_dict) - - @classmethod - def from_dict( - cls: t.Type[Message], # noqa: UP006 - data: dict[str, t.Any], - ) -> Message: - """Create an encoding from a dictionary. - - Args: - data: The dictionary to create the message from. - - Returns: - The created message. - """ - data.pop("type") - return cls(**data) - - -@dataclass -class RecordMessage(Message): - """Singer record message.""" - - stream: str - """The stream name.""" - - record: dict[str, t.Any] - """The record data.""" - - version: int | None = None - """The record version.""" - - time_extracted: datetime | None = None - """The time the record was extracted.""" - - @classmethod - def from_dict(cls: type[RecordMessage], data: dict[str, t.Any]) -> RecordMessage: - """Create a record message from a dictionary. - - This overrides the default conversion logic, since it uses unnecessary - deep copying and is very slow. - - Args: - data: The dictionary to create the message from. - - Returns: - The created message. - """ - time_extracted = data.get("time_extracted") - return cls( - stream=data["stream"], - record=data["record"], - version=data.get("version"), - time_extracted=datetime.fromisoformat(time_extracted) - if time_extracted - else None, - ) - - def to_dict(self) -> dict[str, t.Any]: - """Return a dictionary representation of the message. - - This overrides the default conversion logic, since it uses unnecessary - deep copying and is very slow. - - Returns: - A dictionary with the defined message fields. - """ - result: dict[str, t.Any] = { - "type": "RECORD", - "stream": self.stream, - "record": self.record, - } - if self.version is not None: - result["version"] = self.version - if self.time_extracted is not None: - result["time_extracted"] = self.time_extracted - return result - - def __post_init__(self) -> None: - """Post-init processing. - - Raises: - ValueError: If the time_extracted is not timezone-aware. - """ - self.type = SingerMessageType.RECORD - if self.time_extracted and not self.time_extracted.tzinfo: - msg = ( - "'time_extracted' must be either None or an aware datetime (with a " - "time zone)" - ) - raise ValueError(msg) - - if self.time_extracted: - self.time_extracted = self.time_extracted.astimezone(timezone.utc) - - -@dataclass -class SchemaMessage(Message): - """Singer schema message.""" - - stream: str - """The stream name.""" - - schema: dict[str, t.Any] - """The schema definition.""" - - key_properties: t.Sequence[str] | None = None - """The key properties.""" - - bookmark_properties: list[str] | None = None - """The bookmark properties.""" - - def __post_init__(self) -> None: - """Post-init processing. - - Raises: - ValueError: If bookmark_properties is not a string or list of strings. - """ - self.type = SingerMessageType.SCHEMA - - if isinstance(self.bookmark_properties, (str, bytes)): - self.bookmark_properties = [self.bookmark_properties] - if self.bookmark_properties and not isinstance(self.bookmark_properties, list): - msg = "bookmark_properties must be a string or list of strings" - raise ValueError(msg) - - -@dataclass -class StateMessage(Message): - """Singer state message.""" - - value: dict[str, t.Any] - """The state value.""" - - def __post_init__(self) -> None: - """Post-init processing.""" - self.type = SingerMessageType.STATE - - -@dataclass -class ActivateVersionMessage(Message): - """Singer activate version message.""" - - stream: str - """The stream name.""" - - version: int - """The version to activate.""" - - def __post_init__(self) -> None: - """Post-init processing.""" - self.type = SingerMessageType.ACTIVATE_VERSION - - -class GenericSingerReader(t.Generic[T], metaclass=abc.ABCMeta): +class GenericSingerReader(t.Generic[T, M], metaclass=abc.ABCMeta): """Interface for all plugins reading Singer messages as strings or bytes.""" @t.final @@ -320,10 +140,10 @@ def _process_endofpipe(self) -> None: # noqa: PLR6301 logger.debug("End of pipe reached") -class GenericSingerWriter(t.Generic[T], metaclass=abc.ABCMeta): +class GenericSingerWriter(t.Generic[T, M], metaclass=abc.ABCMeta): """Interface for all plugins writing Singer messages as strings or bytes.""" - def format_message(self, message: Message) -> T: + def format_message(self, message: M) -> T: """Format a message as a JSON string. Args: @@ -332,10 +152,10 @@ def format_message(self, message: Message) -> T: Returns: The formatted message. """ - return self.serialize_json(message.to_dict()) + return self.serialize_message(message) @abc.abstractmethod - def serialize_json(self, obj: object) -> T: ... + def serialize_message(self, message: M) -> T: ... @abc.abstractmethod - def write_message(self, message: Message) -> None: ... + def write_message(self, message: M) -> None: ... diff --git a/singer_sdk/_singerlib/_encoding/simple.py b/singer_sdk/_singerlib/_encoding/simple.py index 636a8f981..690ccfbbf 100644 --- a/singer_sdk/_singerlib/_encoding/simple.py +++ b/singer_sdk/_singerlib/_encoding/simple.py @@ -4,11 +4,13 @@ import logging import sys import typing as t +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone from singer_sdk._singerlib.exceptions import InvalidInputLine from singer_sdk._singerlib.json import deserialize_json, serialize_json -from .base import GenericSingerReader, GenericSingerWriter +from .base import GenericSingerReader, GenericSingerWriter, SingerMessageType if t.TYPE_CHECKING: from singer_sdk._singerlib.messages import Message @@ -16,6 +18,185 @@ logger = logging.getLogger(__name__) +def exclude_null_dict(pairs: list[tuple[str, t.Any]]) -> dict[str, t.Any]: + """Exclude null values from a dictionary. + + Args: + pairs: The dictionary key-value pairs. + + Returns: + The filtered key-value pairs. + """ + return {key: value for key, value in pairs if value is not None} + + +@dataclass +class Message: + """Singer base message.""" + + type: SingerMessageType = field(init=False) + """The message type.""" + + def to_dict(self) -> dict[str, t.Any]: + """Return a dictionary representation of the message. + + Returns: + A dictionary with the defined message fields. + """ + return asdict(self, dict_factory=exclude_null_dict) + + @classmethod + def from_dict( + cls: t.Type[Message], # noqa: UP006 + data: dict[str, t.Any], + ) -> Message: + """Create an encoding from a dictionary. + + Args: + data: The dictionary to create the message from. + + Returns: + The created message. + """ + data.pop("type") + return cls(**data) + + +@dataclass +class RecordMessage(Message): + """Singer record message.""" + + stream: str + """The stream name.""" + + record: dict[str, t.Any] + """The record data.""" + + version: int | None = None + """The record version.""" + + time_extracted: datetime | None = None + """The time the record was extracted.""" + + @classmethod + def from_dict(cls: type[RecordMessage], data: dict[str, t.Any]) -> RecordMessage: + """Create a record message from a dictionary. + + This overrides the default conversion logic, since it uses unnecessary + deep copying and is very slow. + + Args: + data: The dictionary to create the message from. + + Returns: + The created message. + """ + time_extracted = data.get("time_extracted") + return cls( + stream=data["stream"], + record=data["record"], + version=data.get("version"), + time_extracted=datetime.fromisoformat(time_extracted) + if time_extracted + else None, + ) + + def to_dict(self) -> dict[str, t.Any]: + """Return a dictionary representation of the message. + + This overrides the default conversion logic, since it uses unnecessary + deep copying and is very slow. + + Returns: + A dictionary with the defined message fields. + """ + result: dict[str, t.Any] = { + "type": "RECORD", + "stream": self.stream, + "record": self.record, + } + if self.version is not None: + result["version"] = self.version + if self.time_extracted is not None: + result["time_extracted"] = self.time_extracted + return result + + def __post_init__(self) -> None: + """Post-init processing. + + Raises: + ValueError: If the time_extracted is not timezone-aware. + """ + self.type = SingerMessageType.RECORD + if self.time_extracted and not self.time_extracted.tzinfo: + msg = ( + "'time_extracted' must be either None or an aware datetime (with a " + "time zone)" + ) + raise ValueError(msg) + + if self.time_extracted: + self.time_extracted = self.time_extracted.astimezone(timezone.utc) + + +@dataclass +class SchemaMessage(Message): + """Singer schema message.""" + + stream: str + """The stream name.""" + + schema: dict[str, t.Any] + """The schema definition.""" + + key_properties: t.Sequence[str] | None = None + """The key properties.""" + + bookmark_properties: list[str] | None = None + """The bookmark properties.""" + + def __post_init__(self) -> None: + """Post-init processing. + + Raises: + ValueError: If bookmark_properties is not a string or list of strings. + """ + self.type = SingerMessageType.SCHEMA + + if isinstance(self.bookmark_properties, (str, bytes)): + self.bookmark_properties = [self.bookmark_properties] + if self.bookmark_properties and not isinstance(self.bookmark_properties, list): + msg = "bookmark_properties must be a string or list of strings" + raise ValueError(msg) + + +@dataclass +class StateMessage(Message): + """Singer state message.""" + + value: dict[str, t.Any] + """The state value.""" + + def __post_init__(self) -> None: + """Post-init processing.""" + self.type = SingerMessageType.STATE + + +@dataclass +class ActivateVersionMessage(Message): + """Singer activate version message.""" + + stream: str + """The stream name.""" + + version: int + """The version to activate.""" + + def __post_init__(self) -> None: + """Post-init processing.""" + self.type = SingerMessageType.ACTIVATE_VERSION + + class SingerReader(GenericSingerReader[str]): """Base class for all plugins reading Singer messages as strings from stdin.""" @@ -44,16 +225,16 @@ def deserialize_json(self, line: str) -> dict: # noqa: PLR6301 class SingerWriter(GenericSingerWriter[str]): """Interface for all plugins writing Singer messages to stdout.""" - def serialize_json(self, obj: object) -> str: # noqa: PLR6301 + def serialize_message(self, message: Message) -> str: # noqa: PLR6301 """Serialize a dictionary into a line of json. Args: - obj: A Python object usually a dict. + message: A Singer message object. Returns: A string of serialized json. """ - return serialize_json(obj) + return serialize_json(message.to_dict()) def write_message(self, message: Message) -> None: """Write a message to stdout. diff --git a/singer_sdk/_singerlib/messages.py b/singer_sdk/_singerlib/messages.py index b08739126..4b340f72a 100644 --- a/singer_sdk/_singerlib/messages.py +++ b/singer_sdk/_singerlib/messages.py @@ -2,13 +2,13 @@ from __future__ import annotations -from singer_sdk._singerlib._encoding import SingerWriter -from singer_sdk._singerlib._encoding.base import ( +from ._encoding import SingerWriter +from ._encoding.base import SingerMessageType +from ._encoding.simple import ( ActivateVersionMessage, Message, RecordMessage, SchemaMessage, - SingerMessageType, StateMessage, exclude_null_dict, )