diff --git a/singer_sdk/target_base.py b/singer_sdk/target_base.py index d62bbbfd8..cf1038831 100644 --- a/singer_sdk/target_base.py +++ b/singer_sdk/target_base.py @@ -28,8 +28,9 @@ if t.TYPE_CHECKING: from pathlib import PurePath + from singer_sdk.connectors import SQLConnector from singer_sdk.mapper import PluginMapper - from singer_sdk.sinks import Sink + from singer_sdk.sinks import Sink, SQLSink _MAX_PARALLELISM = 8 @@ -48,7 +49,7 @@ class Target(PluginBase, SingerReader, metaclass=abc.ABCMeta): # Default class to use for creating new sink objects. # Required if `Target.get_sink_class()` is not defined. - default_sink_class: type[Sink] | None = None + default_sink_class: type[Sink] def __init__( self, @@ -574,6 +575,23 @@ def get_singer_command(cls: type[Target]) -> click.Command: class SQLTarget(Target): """Target implementation for SQL destinations.""" + _target_connector: SQLConnector | None = None + + default_sink_class: type[SQLSink] + + @property + def target_connector(self) -> SQLConnector: + """The connector object. + + Returns: + The connector object. + """ + if self._target_connector is None: + self._target_connector = self.default_sink_class.connector_class( + dict(self.config), + ) + return self._target_connector + @classproperty def capabilities(self) -> list[CapabilitiesEnum]: """Get target capabilities. @@ -617,3 +635,114 @@ def _merge_missing(source_jsonschema: dict, target_jsonschema: dict) -> None: super().append_builtin_config(config_jsonschema) pass + + @final + def add_sqlsink( + self, + stream_name: str, + schema: dict, + key_properties: list[str] | None = None, + ) -> Sink: + """Create a sink and register it. + + This method is internal to the SDK and should not need to be overridden. + + Args: + stream_name: Name of the stream. + schema: Schema of the stream. + key_properties: Primary key of the stream. + + Returns: + A new sink for the stream. + """ + self.logger.info("Initializing '%s' target sink...", self.name) + sink_class = self.get_sink_class(stream_name=stream_name) + sink = sink_class( + target=self, + stream_name=stream_name, + schema=schema, + key_properties=key_properties, + connector=self.target_connector, + ) + sink.setup() + self._sinks_active[stream_name] = sink + + return sink + + def get_sink_class(self, stream_name: str) -> type[SQLSink]: + """Get sink for a stream. + + Developers can override this method to return a custom Sink type depending + on the value of `stream_name`. Optional when `default_sink_class` is set. + + Args: + stream_name: Name of the stream. + + Raises: + ValueError: If no :class:`singer_sdk.sinks.Sink` class is defined. + + Returns: + The sink class to be used with the stream. + """ + if self.default_sink_class: + return self.default_sink_class + + msg = ( + f"No sink class defined for '{stream_name}' and no default sink class " + "available." + ) + raise ValueError(msg) + + def get_sink( + self, + stream_name: str, + *, + record: dict | None = None, + schema: dict | None = None, + key_properties: list[str] | None = None, + ) -> Sink: + """Return a sink for the given stream name. + + A new sink will be created if `schema` is provided and if either `schema` or + `key_properties` has changed. If so, the old sink becomes archived and held + until the next drain_all() operation. + + Developers only need to override this method if they want to provide a different + sink depending on the values within the `record` object. Otherwise, please see + `default_sink_class` property and/or the `get_sink_class()` method. + + Raises :class:`singer_sdk.exceptions.RecordsWithoutSchemaException` if sink does + not exist and schema is not sent. + + Args: + stream_name: Name of the stream. + record: Record being processed. + schema: Stream schema. + key_properties: Primary key of the stream. + + Returns: + The sink used for this target. + """ + _ = record # Custom implementations may use record in sink selection. + if schema is None: + self._assert_sink_exists(stream_name) + return self._sinks_active[stream_name] + + existing_sink = self._sinks_active.get(stream_name, None) + if not existing_sink: + return self.add_sqlsink(stream_name, schema, key_properties) + + if ( + existing_sink.schema != schema + or existing_sink.key_properties != key_properties + ): + self.logger.info( + "Schema or key properties for '%s' stream have changed. " + "Initializing a new '%s' sink...", + stream_name, + stream_name, + ) + self._sinks_to_clear.append(self._sinks_active.pop(stream_name)) + return self.add_sqlsink(stream_name, schema, key_properties) + + return existing_sink diff --git a/tests/conftest.py b/tests/conftest.py index 25319c015..cf64e28dd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,9 +10,10 @@ import pytest from sqlalchemy import __version__ as sqlalchemy_version +from singer_sdk import SQLConnector from singer_sdk import typing as th -from singer_sdk.sinks import BatchSink -from singer_sdk.target_base import Target +from singer_sdk.sinks import BatchSink, SQLSink +from singer_sdk.target_base import SQLTarget, Target if t.TYPE_CHECKING: from _pytest.config import Config @@ -116,3 +117,63 @@ def _write_state_message(self, state: dict): """Emit the stream's latest state.""" super()._write_state_message(state) self.state_messages_written.append(state) + + +class SQLConnectorMock(SQLConnector): + """A Mock SQLConnector class.""" + + +class SQLSinkMock(SQLSink): + """A mock Sink class.""" + + name = "sql-sink-mock" + connector_class = SQLConnectorMock + + def __init__( + self, + target: SQLTargetMock, + stream_name: str, + schema: dict, + key_properties: list[str] | None, + connector: SQLConnector | None = None, + ): + """Create the Mock batch-based sink.""" + self._connector: SQLConnector + self._connector = connector or self.connector_class(dict(target.config)) + super().__init__(target, stream_name, schema, key_properties, connector) + self.target = target + + def process_record(self, record: dict, context: dict) -> None: + """Tracks the count of processed records.""" + self.target.num_records_processed += 1 + super().process_record(record, context) + + def process_batch(self, context: dict) -> None: + """Write to mock trackers.""" + self.target.records_written.extend(context["records"]) + self.target.num_batches_processed += 1 + + @property + def key_properties(self) -> list[str]: + return [key.upper() for key in super().key_properties] + + +class SQLTargetMock(SQLTarget): + """A mock Target class.""" + + name = "sql-target-mock" + config_jsonschema = th.PropertiesList().to_dict() + default_sink_class = SQLSinkMock + + def __init__(self, *args, **kwargs): + """Create the Mock target sync.""" + super().__init__(*args, **kwargs) + self.state_messages_written: list[dict] = [] + self.records_written: list[dict] = [] + self.num_records_processed: int = 0 + self.num_batches_processed: int = 0 + + def _write_state_message(self, state: dict): + """Emit the stream's latest state.""" + super()._write_state_message(state) + self.state_messages_written.append(state) diff --git a/tests/core/test_target_base.py b/tests/core/test_target_base.py index 1fd6b9a93..ee00d35eb 100644 --- a/tests/core/test_target_base.py +++ b/tests/core/test_target_base.py @@ -4,8 +4,11 @@ import pytest -from singer_sdk.exceptions import MissingKeyPropertiesError -from tests.conftest import BatchSinkMock, TargetMock +from singer_sdk.exceptions import ( + MissingKeyPropertiesError, + RecordsWithoutSchemaException, +) +from tests.conftest import BatchSinkMock, SQLSinkMock, SQLTargetMock, TargetMock def test_get_sink(): @@ -53,3 +56,68 @@ def test_validate_record(): # Test invalid record with pytest.raises(MissingKeyPropertiesError): sink._singer_validate_message({"name": "test"}) + + +def test_sql_get_sink(): + input_schema_1 = { + "properties": { + "id": { + "type": ["string", "null"], + }, + "col_ts": { + "format": "date-time", + "type": ["string", "null"], + }, + }, + } + input_schema_2 = copy.deepcopy(input_schema_1) + key_properties = [] + target = SQLTargetMock(config={"sqlalchemy_url": "sqlite:///"}) + sink = SQLSinkMock( + target=target, + stream_name="foo", + schema=input_schema_1, + key_properties=key_properties, + connector=target.target_connector, + ) + target._sinks_active["foo"] = sink + sink_returned = target.get_sink( + "foo", + schema=input_schema_2, + key_properties=key_properties, + ) + assert sink_returned is sink + + +def test_add_sqlsink_and_get_sink(): + input_schema_1 = { + "properties": { + "id": { + "type": ["string", "null"], + }, + "col_ts": { + "format": "date-time", + "type": ["string", "null"], + }, + }, + } + input_schema_2 = copy.deepcopy(input_schema_1) + key_properties = [] + target = SQLTargetMock(config={"sqlalchemy_url": "sqlite:///"}) + sink = target.add_sqlsink( + "foo", + schema=input_schema_2, + key_properties=key_properties, + ) + + sink_returned = target.get_sink( + "foo", + ) + + assert sink_returned is sink + + # Test invalid call + with pytest.raises(RecordsWithoutSchemaException): + target.get_sink( + "bar", + )