From e1f05cd0ccaae993d9d13c45c9b087ae5d94a23c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?= Date: Wed, 21 Aug 2024 14:59:05 -0600 Subject: [PATCH] feat: (WIP) Let developers more easily override SQL column type to JSON schema mapping --- singer_sdk/connectors/sql.py | 74 +++++++++++++++++++++++++++++++- tests/core/test_connector_sql.py | 49 ++++++++++++++++++++- 2 files changed, 120 insertions(+), 3 deletions(-) diff --git a/singer_sdk/connectors/sql.py b/singer_sdk/connectors/sql.py index b6a74a976..7a826537b 100644 --- a/singer_sdk/connectors/sql.py +++ b/singer_sdk/connectors/sql.py @@ -2,6 +2,7 @@ from __future__ import annotations +import functools import logging import typing as t import warnings @@ -103,6 +104,61 @@ def prepare_part(self, part: str) -> str: # noqa: PLR6301 return part +class SQLToJSONSchemaMap: + """SQLAlchemy to JSON Schema type mapping helper. + + This class provides a mapping from SQLAlchemy types to JSON Schema types. + """ + + @functools.singledispatchmethod + def to_jsonschema(self, sa_type: sa.types.TypeEngine) -> dict: # noqa: ARG002, PLR6301 + """Return a JSON Schema representation of the provided type.""" + return th.StringType.type_dict # type: ignore[no-any-return] + + @to_jsonschema.register + def datetime_to_jsonschema_datetime(self, sa_type: sa.types.DateTime) -> dict: # noqa: ARG002, PLR6301 + """Return a JSON Schema representation of a generic datetime type.""" + return th.DateTimeType.type_dict # type: ignore[no-any-return] + + @to_jsonschema.register + def date_to_jsonschema_date(self, sa_type: sa.types.Date) -> dict: # noqa: ARG002, PLR6301 + """Return a JSON Schema representation of a date type.""" + return th.DateType.type_dict # type: ignore[no-any-return] + + @to_jsonschema.register + def time_to_jsonschema_time(self, sa_type: sa.types.Time) -> dict: # noqa: ARG002, PLR6301 + """Return a JSON Schema representation of a time type.""" + return th.TimeType.type_dict # type: ignore[no-any-return] + + @to_jsonschema.register + def integer_to_jsonschema_integer(self, sa_type: sa.types.Integer) -> dict: # noqa: ARG002, PLR6301 + """Return a JSON Schema representation of a an integer type.""" + return th.IntegerType.type_dict # type: ignore[no-any-return] + + @to_jsonschema.register + def float_to_jsonschema_float(self, sa_type: sa.types.Numeric) -> dict: # noqa: ARG002, PLR6301 + """Return a JSON Schema representation of a generic number type.""" + return th.NumberType.type_dict # type: ignore[no-any-return] + + @to_jsonschema.register + def string_to_jsonschema_string(self, sa_type: sa.types.String) -> dict: # noqa: ARG002, PLR6301 + """Return a JSON Schema representation of a generic string type.""" + # TODO: Enable support for maxLength. + # if sa_type.length: + # return StringType(max_length=sa_type.length).type_dict # noqa: ERA001 + return th.StringType.type_dict # type: ignore[no-any-return] + + @to_jsonschema.register + def boolean_to_jsonschema_boolean(self, sa_type: sa.types.Boolean) -> dict: # noqa: ARG002, PLR6301 + """Return a JSON Schema representation of a boolean type.""" + return th.BooleanType.type_dict # type: ignore[no-any-return] + + @to_jsonschema.register + def variant_to_jsonschema_time(self, sa_type: sa.types.Variant) -> dict: # noqa: ARG002, PLR6301 + """Return a JSON Schema representation of a variant type.""" + return th.TimeType.type_dict # type: ignore[no-any-return] + + class SQLConnector: # noqa: PLR0904 """Base class for SQLAlchemy-based connectors. @@ -156,6 +212,17 @@ def logger(self) -> logging.Logger: """ return logging.getLogger("sqlconnector") + @functools.cached_property + def type_mapping(self) -> SQLToJSONSchemaMap: + """Return the type mapper object. + + Override this method to provide a custom mapping for your SQL dialect. + + Returns: + The type mapper object. + """ + return SQLToJSONSchemaMap() + @contextmanager def _connect(self) -> t.Iterator[sa.engine.Connection]: with self._engine.connect().execution_options(stream_results=True) as conn: @@ -260,8 +327,8 @@ def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str: # noqa: PLR6301 return t.cast(str, config["sqlalchemy_url"]) - @staticmethod def to_jsonschema_type( + self, sql_type: ( str # noqa: ANN401 | sa.types.TypeEngine @@ -287,7 +354,10 @@ def to_jsonschema_type( Returns: The JSON Schema representation of the provided type. """ - if isinstance(sql_type, (str, sa.types.TypeEngine)): + if isinstance(sql_type, sa.types.TypeEngine): + return self.type_mapping.to_jsonschema(sql_type) + + if isinstance(sql_type, str): return th.to_jsonschema_type(sql_type) if isinstance(sql_type, type): diff --git a/tests/core/test_connector_sql.py b/tests/core/test_connector_sql.py index c8390f33d..66fd265e1 100644 --- a/tests/core/test_connector_sql.py +++ b/tests/core/test_connector_sql.py @@ -11,7 +11,7 @@ from samples.sample_duckdb import DuckDBConnector from singer_sdk.connectors import SQLConnector -from singer_sdk.connectors.sql import FullyQualifiedName +from singer_sdk.connectors.sql import FullyQualifiedName, SQLToJSONSchemaMap from singer_sdk.exceptions import ConfigValidationError if t.TYPE_CHECKING: @@ -392,3 +392,50 @@ def prepare_part(self, part: str) -> str: def test_fully_qualified_name_empty_error(): with pytest.raises(ValueError, match="Could not generate fully qualified name"): FullyQualifiedName() + + +@pytest.mark.parametrize( + "sql_type, expected_jsonschema_type", + [ + pytest.param(sa.types.VARCHAR(), {"type": ["string"]}, id="varchar"), + pytest.param( + sa.types.VARCHAR(length=127), + {"type": ["string"], "maxLength": 127}, + marks=pytest.mark.xfail, + id="varchar-length", + ), + pytest.param(sa.types.TEXT(), {"type": ["string"]}, id="text"), + pytest.param(sa.types.INTEGER(), {"type": ["integer"]}, id="integer"), + pytest.param(sa.types.BOOLEAN(), {"type": ["boolean"]}, id="boolean"), + pytest.param(sa.types.DECIMAL(), {"type": ["number"]}, id="decimal"), + pytest.param(sa.types.FLOAT(), {"type": ["number"]}, id="float"), + pytest.param(sa.types.REAL(), {"type": ["number"]}, id="real"), + pytest.param(sa.types.NUMERIC(), {"type": ["number"]}, id="numeric"), + pytest.param( + sa.types.DATE(), + {"type": ["string"], "format": "date"}, + id="date", + ), + pytest.param( + sa.types.DATETIME(), + {"type": ["string"], "format": "date-time"}, + id="datetime", + ), + pytest.param( + sa.types.TIMESTAMP(), + {"type": ["string"], "format": "date-time"}, + id="timestamp", + ), + pytest.param( + sa.types.TIME(), + {"type": ["string"], "format": "time"}, + id="time", + ), + ], +) +def test_sql_to_json_schema_map( + sql_type: sa.types.TypeEngine, + expected_jsonschema_type: dict, +): + m = SQLToJSONSchemaMap() + assert m.to_jsonschema(sql_type) == expected_jsonschema_type