Skip to content

Commit

Permalink
refactor: Implement msgspec encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Jul 17, 2024
1 parent 32c059b commit 8b0e66f
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 2 deletions.
54 changes: 53 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ duckdb = ">=0.8.0"
duckdb-engine = { version = ">=0.9.4", python = "<4" }

fastjsonschema = ">=2.19.1"
msgspec = ">=0.18.6"
pytest-benchmark = ">=4.0.0"
pytest-snapshot = ">=0.9.0"
pytz = ">=2022.2.1"
Expand Down
172 changes: 172 additions & 0 deletions singer_sdk/_singerlib/encoding/_msgspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from __future__ import annotations

import datetime
import decimal
import logging
import sys
import typing as t

import msgspec

from singer_sdk._singerlib.exceptions import InvalidInputLine

from ._base import GenericSingerReader, GenericSingerWriter

logger = logging.getLogger(__name__)


class Message(msgspec.Struct, tag_field="type", tag=str.upper):
"""Singer base message."""

def to_dict(self): # noqa: ANN202
return {f: getattr(self, f) for f in self.__struct_fields__}


class RecordMessage(Message, tag="RECORD"):
"""Singer RECORD message."""

stream: str
record: t.Dict[str, t.Any] # noqa: UP006
version: t.Union[int, None] = None # noqa: UP007
time_extracted: t.Union[datetime.datetime, None] = None # noqa: UP007

def __post_init__(self) -> None:
"""Post-init processing.
Raises:
ValueError: If the time_extracted is not timezone-aware.
"""
if self.time_extracted and not self.time_extracted.tzinfo:
msg = (
"'time_extracted' must be either None or an aware datetime (with a "
"time zone)"
)
raise ValueError(msg)

if self.time_extracted:
self.time_extracted = self.time_extracted.astimezone(datetime.timezone.utc)


class SchemaMessage(Message, tag="SCHEMA"):
"""Singer SCHEMA message."""

stream: str
schema: t.Dict[str, t.Any] # noqa: UP006
key_properties: list[str]
bookmark_properties: t.Union[list[str], None] = None # noqa: UP007

def __post_init__(self) -> None:
"""Post-init processing.
Raises:
ValueError: If bookmark_properties is not a string or list of strings.
"""
if isinstance(self.bookmark_properties, (str, bytes)):
self.bookmark_properties = [self.bookmark_properties]
if self.bookmark_properties and not isinstance(self.bookmark_properties, list):
msg = "bookmark_properties must be a string or list of strings"
raise ValueError(msg)


class StateMessage(Message, tag="STATE"):
"""Singer state message."""

value: t.Dict[str, t.Any] # noqa: UP006
"""The state value."""


class ActivateVersionMessage(Message, tag="ACTIVATE_VERSION"):
"""Singer activate version message."""

stream: str
"""The stream name."""

version: int
"""The version to activate."""


def enc_hook(obj: t.Any) -> t.Any: # noqa: ANN401
"""Encoding type helper for non native types.
Args:
obj: the item to be encoded
Returns:
The object converted to the appropriate type, default is str
"""
return obj.isoformat(sep="T") if isinstance(obj, datetime.datetime) else str(obj)


def dec_hook(type: type, obj: t.Any) -> t.Any: # noqa: ARG001, A002, ANN401
"""Decoding type helper for non native types.
Args:
type: the type given
obj: the item to be decoded
Returns:
The object converted to the appropriate type, default is str.
"""
return str(obj)


encoder = msgspec.json.Encoder(enc_hook=enc_hook, decimal_format="number")
decoder = msgspec.json.Decoder(
t.Union[
RecordMessage,
SchemaMessage,
StateMessage,
ActivateVersionMessage,
],
dec_hook=dec_hook,
float_hook=decimal.Decimal,
)


class MsgSpecReader(GenericSingerReader[str]):
"""Base class for all plugins reading Singer messages as strings from stdin."""

default_input = sys.stdin

def deserialize_json(self, line: str) -> dict: # noqa: PLR6301
"""Deserialize a line of json.
Args:
line: A single line of json.
Returns:
A dictionary of the deserialized json.
Raises:
InvalidInputLine: If the line cannot be parsed
"""
try:
return decoder.decode(line).to_dict()
except msgspec.DecodeError as exc:
logger.exception("Unable to parse:\n%s", line)
msg = f"Unable to parse line as JSON: {line}"
raise InvalidInputLine(msg) from exc


class MsgSpecWriter(GenericSingerWriter[bytes, Message]):
"""Interface for all plugins writing Singer messages to stdout."""

def serialize_message(self, message: Message) -> bytes: # noqa: PLR6301
"""Serialize a dictionary into a line of json.
Args:
message: A Singer message object.
Returns:
A string of serialized json.
"""
return encoder.encode(message)

def write_message(self, message: Message) -> None:
"""Write a message to stdout.
Args:
message: The message to write.
"""
sys.stdout.buffer.write(self.format_message(message) + b"\n")
sys.stdout.flush()
41 changes: 41 additions & 0 deletions tests/_singerlib/_encoding/test_msgspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations # noqa: INP001

import pytest

from singer_sdk._singerlib.encoding._msgspec import dec_hook, enc_hook


@pytest.mark.parametrize(
"test_type,test_value,expected_value,expected_type",
[
pytest.param(
int,
1,
"1",
str,
id="int-to-str",
),
],
)
def test_dec_hook(test_type, test_value, expected_value, expected_type):
returned = dec_hook(type=test_type, obj=test_value)
returned_type = type(returned)

assert returned == expected_value
assert returned_type == expected_type


@pytest.mark.parametrize(
"test_value,expected_value",
[
pytest.param(
1,
"1",
id="int-to-str",
),
],
)
def test_enc_hook(test_value, expected_value):
returned = enc_hook(obj=test_value)

assert returned == expected_value
20 changes: 19 additions & 1 deletion tests/core/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest

from singer_sdk._singerlib import RecordMessage
from singer_sdk._singerlib.encoding._msgspec import MsgSpecReader, MsgSpecWriter
from singer_sdk._singerlib.exceptions import InvalidInputLine
from singer_sdk.io_base import SingerReader, SingerWriter

Expand Down Expand Up @@ -104,6 +105,7 @@ def test_write_message():
def bench_record():
return {
"stream": "users",
"type": "RECORD",
"record": {
"Id": 1,
"created_at": "2021-01-01T00:08:00-07:00",
Expand Down Expand Up @@ -131,7 +133,7 @@ def test_bench_format_message(benchmark, bench_record_message):
"""Run benchmark for Sink._validator method validate."""
number_of_runs = 1000

writer = SingerWriter()
writer = MsgSpecWriter()

def run_format_message():
for record in itertools.repeat(bench_record_message, number_of_runs):
Expand All @@ -144,6 +146,22 @@ def test_bench_deserialize_json(benchmark, bench_encoded_record):
"""Run benchmark for Sink._validator method validate."""
number_of_runs = 1000

class DummyReader(MsgSpecReader):
def _process_activate_version_message(self, message_dict: dict) -> None:
pass

def _process_batch_message(self, message_dict: dict) -> None:
pass

def _process_record_message(self, message_dict: dict) -> None:
pass

def _process_schema_message(self, message_dict: dict) -> None:
pass

def _process_state_message(self, message_dict: dict) -> None:
pass

reader = DummyReader()

def run_deserialize_json():
Expand Down

0 comments on commit 8b0e66f

Please sign in to comment.