Skip to content

Commit

Permalink
feat: Standard configurable load methods (#1893)
Browse files Browse the repository at this point in the history
* initial implementation of standard load methods, sql connector implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* enum to values

* fix enum comparisons

* adds test for sqlite overwrite load method

* Address issues

* drop table instead of truncating

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edgar Ramírez Mondragón <edgar@meltano.com>
  • Loading branch information
3 people authored Sep 13, 2023
1 parent 50c4725 commit 10b61d2
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class {{ cookiecutter.destination_name }}Connector(SQLConnector):
allow_column_rename: bool = True # Whether RENAME COLUMN is supported.
allow_column_alter: bool = False # Whether altering column types is supported.
allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported.
allow_overwrite: bool = False # Whether overwrite load method is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.

def get_sqlalchemy_url(self, config: dict) -> str:
Expand Down
1 change: 1 addition & 0 deletions samples/sample_target_sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SQLiteConnector(SQLConnector):
allow_temp_tables = False
allow_column_alter = False
allow_merge_upsert = True
allow_overwrite: bool = True

def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str:
"""Generates a SQLAlchemy URL for SQLite."""
Expand Down
12 changes: 12 additions & 0 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from singer_sdk import typing as th
from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema
from singer_sdk.exceptions import ConfigValidationError
from singer_sdk.helpers.capabilities import TargetLoadMethods

if t.TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
Expand All @@ -40,6 +41,7 @@ class SQLConnector:
allow_column_rename: bool = True # Whether RENAME COLUMN is supported.
allow_column_alter: bool = False # Whether altering column types is supported.
allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported.
allow_overwrite: bool = False # Whether overwrite load method is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.
_cached_engine: Engine | None = None

Expand Down Expand Up @@ -775,6 +777,16 @@ def prepare_table(
as_temp_table=as_temp_table,
)
return
if self.config["load_method"] == TargetLoadMethods.OVERWRITE:
self.get_table(full_table_name=full_table_name).drop(self._engine)
self.create_empty_table(
full_table_name=full_table_name,
schema=schema,
primary_keys=primary_keys,
partition_keys=partition_keys,
as_temp_table=as_temp_table,
)
return

for property_name, property_def in schema["properties"].items():
self.prepare_column(
Expand Down
34 changes: 34 additions & 0 deletions singer_sdk/helpers/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,40 @@
).to_dict()


class TargetLoadMethods(str, Enum):
    """Load methods a target can use when writing records to its destination.

    Members subclass ``str``, so each one compares equal to its plain string
    value (for example ``"overwrite"``) as it would appear in a config dict.
    """

    # Always write all input records, whether or not a matching record
    # already exists in the destination.
    APPEND_ONLY = "append-only"

    # Update existing records and insert new records.
    UPSERT = "upsert"

    # Delete all existing records and insert all input records.
    OVERWRITE = "overwrite"


# Config-schema fragment (dict form of a ``PropertiesList``) declaring the
# ``load_method`` setting. Values are restricted to the ``TargetLoadMethods``
# members and default to ``append-only``.
TARGET_LOAD_METHOD_CONFIG = PropertiesList(
    Property(
        "load_method",
        StringType(),
        description=(
            "The method to use when loading data into the destination. "
            "`append-only` will always write all input records whether that records "
            "already exists or not. `upsert` will update existing records and insert "
            "new records. `overwrite` will delete all existing records and insert all "
            "input records."
        ),
        allowed_values=[
            TargetLoadMethods.APPEND_ONLY,
            TargetLoadMethods.UPSERT,
            TargetLoadMethods.OVERWRITE,
        ],
        default=TargetLoadMethods.APPEND_ONLY,
    ),
).to_dict()


class DeprecatedEnum(Enum):
"""Base class for capabilities enumeration."""

Expand Down
2 changes: 2 additions & 0 deletions singer_sdk/target_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from singer_sdk.helpers.capabilities import (
ADD_RECORD_METADATA_CONFIG,
BATCH_CONFIG,
TARGET_LOAD_METHOD_CONFIG,
TARGET_SCHEMA_CONFIG,
CapabilitiesEnum,
PluginCapabilities,
Expand Down Expand Up @@ -597,6 +598,7 @@ def _merge_missing(source_jsonschema: dict, target_jsonschema: dict) -> None:
target_jsonschema["properties"][k] = v

_merge_missing(ADD_RECORD_METADATA_CONFIG, config_jsonschema)
_merge_missing(TARGET_LOAD_METHOD_CONFIG, config_jsonschema)

capabilities = cls.capabilities

Expand Down
45 changes: 45 additions & 0 deletions tests/samples/test_target_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,48 @@ def test_hostile_to_sqlite(
"hname_starts_with_number",
"name_with_emoji_",
}


def test_overwrite_load_method(
    sqlite_target_test_config: dict,
) -> None:
    """Check that ``load_method="overwrite"`` replaces a table's prior contents.

    The same stream is synced twice through two separate target instances,
    each time with a single record. After the second sync only the second
    record must remain in the table.

    Args:
        sqlite_target_test_config: Target config fixture; it is mutated here
            to set ``load_method`` and must contain ``path_to_db``.
    """
    sqlite_target_test_config["load_method"] = "overwrite"
    target = SQLiteTarget(config=sqlite_target_test_config)
    # Random suffix keeps the table name unique within the test database.
    test_tbl = f"zzz_tmp_{str(uuid4()).split('-')[-1]}"
    schema_msg = {
        "type": "SCHEMA",
        "stream": test_tbl,
        "schema": {
            "type": "object",
            "properties": {"col_a": th.StringType().to_dict()},
        },
    }

    tap_output_a = "\n".join(
        json.dumps(msg)
        for msg in [
            schema_msg,
            {"type": "RECORD", "stream": test_tbl, "record": {"col_a": "123"}},
        ]
    )
    tap_output_b = "\n".join(
        json.dumps(msg)
        for msg in [
            schema_msg,
            {"type": "RECORD", "stream": test_tbl, "record": {"col_a": "456"}},
        ]
    )

    # Independent connection used only to inspect what the target wrote.
    db = sqlite3.connect(sqlite_target_test_config["path_to_db"])
    try:
        cursor = db.cursor()

        # First sync: the table should hold exactly the first record.
        target_sync_test(target, input=StringIO(tap_output_a), finalize=True)
        cursor.execute(f"SELECT col_a FROM {test_tbl} ;")  # noqa: S608
        records = [res[0] for res in cursor.fetchall()]
        assert records == ["123"]

        # Second sync with a fresh target: the earlier row must be gone.
        target = SQLiteTarget(config=sqlite_target_test_config)
        target_sync_test(target, input=StringIO(tap_output_b), finalize=True)
        cursor.execute(f"SELECT col_a FROM {test_tbl} ;")  # noqa: S608
        records = [res[0] for res in cursor.fetchall()]
        assert records == ["456"]
    finally:
        # sqlite3 connections are not closed automatically; release the
        # handle even when an assertion above fails.
        db.close()

0 comments on commit 10b61d2

Please sign in to comment.