Skip to content

Commit

Permalink
refactor: Implement msgspec encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Jul 17, 2024
1 parent 32c059b commit 3931f1d
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 2 deletions.
54 changes: 53 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ duckdb = ">=0.8.0"
duckdb-engine = { version = ">=0.9.4", python = "<4" }

fastjsonschema = ">=2.19.1"
msgspec = ">=0.18.6"
pytest-benchmark = ">=4.0.0"
pytest-snapshot = ">=0.9.0"
pytz = ">=2022.2.1"
Expand Down
180 changes: 180 additions & 0 deletions singer_sdk/_singerlib/encoding/_msgspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from __future__ import annotations

import datetime
import decimal
import logging
import sys
import typing as t

import msgspec

from singer_sdk._singerlib.exceptions import InvalidInputLine

from ._base import GenericSingerReader, GenericSingerWriter, SingerMessageType

logger = logging.getLogger(__name__)


class Message(msgspec.Struct):
"""Singer base message."""

type: str

def to_dict(self): # noqa: ANN202
return {f: getattr(self, f) for f in self.__struct_fields__}

Check warning on line 24 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L24

Added line #L24 was not covered by tests


class RecordMessage(Message):
"""Singer RECORD message."""

type = "RECORD"
stream: str
record: dict[str, t.Any]
version: int | None = None
time_extracted: datetime.datetime | None = None

def __post_init__(self) -> None:
"""Post-init processing.
Raises:
ValueError: If the time_extracted is not timezone-aware.
"""
self.type = SingerMessageType.RECORD

Check warning on line 42 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L42

Added line #L42 was not covered by tests
if self.time_extracted and not self.time_extracted.tzinfo:
msg = (

Check warning on line 44 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L44

Added line #L44 was not covered by tests
"'time_extracted' must be either None or an aware datetime (with a "
"time zone)"
)
raise ValueError(msg)

Check warning on line 48 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L48

Added line #L48 was not covered by tests

if self.time_extracted:
self.time_extracted = self.time_extracted.astimezone(datetime.timezone.utc)

Check warning on line 51 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L51

Added line #L51 was not covered by tests


class SchemaMessage(Message):
"""Singer SCHEMA message."""

type = "SCHEMA"
stream: str
schema: dict[str, t.Any]
key_properties: list[str]
bookmark_properties: list[str] | None = None

def __post_init__(self) -> None:
"""Post-init processing.
Raises:
ValueError: If bookmark_properties is not a string or list of strings.
"""
self.type = SingerMessageType.SCHEMA

Check warning on line 69 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L69

Added line #L69 was not covered by tests

if isinstance(self.bookmark_properties, (str, bytes)):
self.bookmark_properties = [self.bookmark_properties]

Check warning on line 72 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L72

Added line #L72 was not covered by tests
if self.bookmark_properties and not isinstance(self.bookmark_properties, list):
msg = "bookmark_properties must be a string or list of strings"
raise ValueError(msg)

Check warning on line 75 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L74-L75

Added lines #L74 - L75 were not covered by tests


class StateMessage(Message):
"""Singer state message."""

value: dict[str, t.Any]
"""The state value."""

def __post_init__(self) -> None:
"""Post-init processing."""
self.type = SingerMessageType.STATE

Check warning on line 86 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L86

Added line #L86 was not covered by tests


class ActivateVersionMessage(Message):
"""Singer activate version message."""

stream: str
"""The stream name."""

version: int
"""The version to activate."""

def __post_init__(self) -> None:
"""Post-init processing."""
self.type = SingerMessageType.ACTIVATE_VERSION

Check warning on line 100 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L100

Added line #L100 was not covered by tests


def enc_hook(obj: t.Any) -> t.Any: # noqa: ANN401
"""Encoding type helper for non native types.
Args:
obj: the item to be encoded
Returns:
The object converted to the appropriate type, default is str
"""
return obj.isoformat(sep="T") if isinstance(obj, datetime.datetime) else str(obj)


def dec_hook(type: type, obj: t.Any) -> t.Any: # noqa: ARG001, A002, ANN401
"""Decoding type helper for non native types.
Args:
type: the type given
obj: the item to be decoded
Returns:
The object converted to the appropriate type, default is str.
"""
return str(obj)


encoder = msgspec.json.Encoder(enc_hook=enc_hook, decimal_format="number")
decoder = msgspec.json.Decoder(dec_hook=dec_hook, float_hook=decimal.Decimal)


class MsgSpecReader(GenericSingerReader[str]):
"""Base class for all plugins reading Singer messages as strings from stdin."""

default_input = sys.stdin

def deserialize_json(self, line: str) -> dict: # noqa: PLR6301
"""Deserialize a line of json.
Args:
line: A single line of json.
Returns:
A dictionary of the deserialized json.
Raises:
InvalidInputLine: If the line cannot be parsed
"""
try:
return decoder.decode(line)
except msgspec.DecodeError as exc:
logger.exception("Unable to parse:\n%s", line)
msg = f"Unable to parse line as JSON: {line}"
raise InvalidInputLine(msg) from exc

Check warning on line 154 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L149-L154

Added lines #L149 - L154 were not covered by tests


class MsgSpecWriter(GenericSingerWriter[bytes, Message]):
"""Interface for all plugins writing Singer messages to stdout."""

msg_buffer = bytearray(64)

def serialize_message(self, message: Message) -> bytes: # noqa: PLR6301
"""Serialize a dictionary into a line of json.
Args:
message: A Singer message object.
Returns:
A string of serialized json.
"""
return encoder.encode(message)

Check warning on line 171 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L171

Added line #L171 was not covered by tests

def write_message(self, message: Message) -> None:
"""Write a message to stdout.
Args:
message: The message to write.
"""
sys.stdout.buffer.write(self.format_message(message) + b"\n")
sys.stdout.flush()

Check warning on line 180 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L179-L180

Added lines #L179 - L180 were not covered by tests
41 changes: 41 additions & 0 deletions tests/_singerlib/_encoding/test_msgspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations # noqa: INP001

import pytest

from singer_sdk._singerlib.encoding._msgspec import dec_hook, enc_hook


@pytest.mark.parametrize(
"test_type,test_value,expected_value,expected_type",
[
pytest.param(
int,
1,
"1",
str,
id="int-to-str",
),
],
)
def test_dec_hook(test_type, test_value, expected_value, expected_type):
returned = dec_hook(type=test_type, obj=test_value)
returned_type = type(returned)

assert returned == expected_value
assert returned_type == expected_type


@pytest.mark.parametrize(
"test_value,expected_value",
[
pytest.param(
1,
"1",
id="int-to-str",
),
],
)
def test_enc_hook(test_value, expected_value):
returned = enc_hook(obj=test_value)

assert returned == expected_value
19 changes: 18 additions & 1 deletion tests/core/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest

from singer_sdk._singerlib import RecordMessage
from singer_sdk._singerlib.encoding._msgspec import MsgSpecReader, MsgSpecWriter
from singer_sdk._singerlib.exceptions import InvalidInputLine
from singer_sdk.io_base import SingerReader, SingerWriter

Expand Down Expand Up @@ -131,7 +132,7 @@ def test_bench_format_message(benchmark, bench_record_message):
"""Run benchmark for Sink._validator method validate."""
number_of_runs = 1000

writer = SingerWriter()
writer = MsgSpecWriter()

def run_format_message():
for record in itertools.repeat(bench_record_message, number_of_runs):
Expand All @@ -144,6 +145,22 @@ def test_bench_deserialize_json(benchmark, bench_encoded_record):
"""Run benchmark for Sink._validator method validate."""
number_of_runs = 1000

class DummyReader(MsgSpecReader):
def _process_activate_version_message(self, message_dict: dict) -> None:
pass

def _process_batch_message(self, message_dict: dict) -> None:
pass

def _process_record_message(self, message_dict: dict) -> None:
pass

def _process_schema_message(self, message_dict: dict) -> None:
pass

def _process_state_message(self, message_dict: dict) -> None:
pass

reader = DummyReader()

def run_deserialize_json():
Expand Down

0 comments on commit 3931f1d

Please sign in to comment.