Skip to content

Commit

Permalink
feat: (WIP) Let developers more easily override SQL column type to JS…
Browse files Browse the repository at this point in the history
…ON schema mapping
  • Loading branch information
edgarrmondragon committed Aug 21, 2024
1 parent b619b0d commit e1f05cd
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 3 deletions.
74 changes: 72 additions & 2 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import functools
import logging
import typing as t
import warnings
Expand Down Expand Up @@ -103,6 +104,61 @@ def prepare_part(self, part: str) -> str: # noqa: PLR6301
return part


class SQLToJSONSchemaMap:
    """SQLAlchemy to JSON Schema type mapping helper.

    This class provides a mapping from SQLAlchemy types to JSON Schema types.

    Dispatch happens on the class of the SQLAlchemy type *instance* through
    :func:`functools.singledispatchmethod`: the most specific registered
    handler for the instance's class is used, and any type without a
    registered handler falls back to the generic string schema.

    To customize the mapping, subclass this helper and register additional
    handlers on ``to_jsonschema``.
    """

    @functools.singledispatchmethod
    def to_jsonschema(self, sa_type: sa.types.TypeEngine) -> dict:  # noqa: ARG002, PLR6301
        """Return a JSON Schema representation of the provided type.

        This is the fallback used when no more specific handler is registered.

        Args:
            sa_type: The SQLAlchemy type instance to convert.

        Returns:
            The JSON Schema dictionary for a generic string.
        """
        return th.StringType.type_dict  # type: ignore[no-any-return]

    @to_jsonschema.register
    def datetime_to_jsonschema_datetime(self, sa_type: sa.types.DateTime) -> dict:  # noqa: ARG002, PLR6301
        """Return a JSON Schema representation of a generic datetime type."""
        return th.DateTimeType.type_dict  # type: ignore[no-any-return]

    @to_jsonschema.register
    def date_to_jsonschema_date(self, sa_type: sa.types.Date) -> dict:  # noqa: ARG002, PLR6301
        """Return a JSON Schema representation of a date type."""
        return th.DateType.type_dict  # type: ignore[no-any-return]

    @to_jsonschema.register
    def time_to_jsonschema_time(self, sa_type: sa.types.Time) -> dict:  # noqa: ARG002, PLR6301
        """Return a JSON Schema representation of a time type."""
        return th.TimeType.type_dict  # type: ignore[no-any-return]

    @to_jsonschema.register
    def integer_to_jsonschema_integer(self, sa_type: sa.types.Integer) -> dict:  # noqa: ARG002, PLR6301
        """Return a JSON Schema representation of an integer type."""
        return th.IntegerType.type_dict  # type: ignore[no-any-return]

    @to_jsonschema.register
    def float_to_jsonschema_float(self, sa_type: sa.types.Numeric) -> dict:  # noqa: ARG002, PLR6301
        """Return a JSON Schema representation of a generic number type."""
        return th.NumberType.type_dict  # type: ignore[no-any-return]

    @to_jsonschema.register
    def string_to_jsonschema_string(self, sa_type: sa.types.String) -> dict:  # noqa: ARG002, PLR6301
        """Return a JSON Schema representation of a generic string type."""
        # TODO: Enable support for maxLength.
        # if sa_type.length:
        #     return StringType(max_length=sa_type.length).type_dict  # noqa: ERA001
        return th.StringType.type_dict  # type: ignore[no-any-return]

    @to_jsonschema.register
    def boolean_to_jsonschema_boolean(self, sa_type: sa.types.Boolean) -> dict:  # noqa: ARG002, PLR6301
        """Return a JSON Schema representation of a boolean type."""
        return th.BooleanType.type_dict  # type: ignore[no-any-return]

    @to_jsonschema.register
    def variant_to_jsonschema_time(self, sa_type: sa.types.Variant) -> dict:  # noqa: ARG002, PLR6301
        """Return a JSON Schema representation of a variant type.

        A variant is not a time-of-day value, so it is described with the same
        permissive string schema as the generic fallback instead of the
        ``time`` format, which would wrongly constrain the column's values.

        The method name is kept unchanged for backward compatibility.
        """
        return th.StringType.type_dict  # type: ignore[no-any-return]


class SQLConnector: # noqa: PLR0904
"""Base class for SQLAlchemy-based connectors.
Expand Down Expand Up @@ -156,6 +212,17 @@ def logger(self) -> logging.Logger:
"""
return logging.getLogger("sqlconnector")

@functools.cached_property
def type_mapping(self) -> SQLToJSONSchemaMap:
    """Build the SQL-to-JSON-Schema type mapper for this connector.

    The mapper is created on first access and then cached on the instance.
    Override this property in a subclass to provide a custom mapping for
    your SQL dialect.

    Returns:
        The type mapper object.
    """
    mapper = SQLToJSONSchemaMap()
    return mapper

@contextmanager
def _connect(self) -> t.Iterator[sa.engine.Connection]:
with self._engine.connect().execution_options(stream_results=True) as conn:
Expand Down Expand Up @@ -260,8 +327,8 @@ def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str: # noqa: PLR6301

return t.cast(str, config["sqlalchemy_url"])

@staticmethod
def to_jsonschema_type(
self,
sql_type: (
str # noqa: ANN401
| sa.types.TypeEngine
Expand All @@ -287,7 +354,10 @@ def to_jsonschema_type(
Returns:
The JSON Schema representation of the provided type.
"""
if isinstance(sql_type, (str, sa.types.TypeEngine)):
if isinstance(sql_type, sa.types.TypeEngine):
return self.type_mapping.to_jsonschema(sql_type)

if isinstance(sql_type, str):
return th.to_jsonschema_type(sql_type)

if isinstance(sql_type, type):
Expand Down
49 changes: 48 additions & 1 deletion tests/core/test_connector_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from samples.sample_duckdb import DuckDBConnector
from singer_sdk.connectors import SQLConnector
from singer_sdk.connectors.sql import FullyQualifiedName
from singer_sdk.connectors.sql import FullyQualifiedName, SQLToJSONSchemaMap
from singer_sdk.exceptions import ConfigValidationError

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -392,3 +392,50 @@ def prepare_part(self, part: str) -> str:
def test_fully_qualified_name_empty_error():
with pytest.raises(ValueError, match="Could not generate fully qualified name"):
FullyQualifiedName()


@pytest.mark.parametrize(
    "sql_type, expected_jsonschema_type",
    [
        # String-like types.
        pytest.param(sa.types.VARCHAR(), {"type": ["string"]}, id="varchar"),
        # maxLength is not emitted yet, so this case is expected to fail.
        pytest.param(
            sa.types.VARCHAR(length=127),
            {"type": ["string"], "maxLength": 127},
            id="varchar-length",
            marks=pytest.mark.xfail,
        ),
        pytest.param(sa.types.TEXT(), {"type": ["string"]}, id="text"),
        # Integer and boolean types.
        pytest.param(sa.types.INTEGER(), {"type": ["integer"]}, id="integer"),
        pytest.param(sa.types.BOOLEAN(), {"type": ["boolean"]}, id="boolean"),
        # Numeric types.
        pytest.param(sa.types.DECIMAL(), {"type": ["number"]}, id="decimal"),
        pytest.param(sa.types.FLOAT(), {"type": ["number"]}, id="float"),
        pytest.param(sa.types.REAL(), {"type": ["number"]}, id="real"),
        pytest.param(sa.types.NUMERIC(), {"type": ["number"]}, id="numeric"),
        # Temporal types are strings with a format.
        pytest.param(sa.types.DATE(), {"type": ["string"], "format": "date"}, id="date"),
        pytest.param(
            sa.types.DATETIME(),
            {"type": ["string"], "format": "date-time"},
            id="datetime",
        ),
        pytest.param(
            sa.types.TIMESTAMP(),
            {"type": ["string"], "format": "date-time"},
            id="timestamp",
        ),
        pytest.param(sa.types.TIME(), {"type": ["string"], "format": "time"}, id="time"),
    ],
)
def test_sql_to_json_schema_map(
    sql_type: sa.types.TypeEngine,
    expected_jsonschema_type: dict,
):
    """Check that each SQLAlchemy type instance maps to the expected schema."""
    mapper = SQLToJSONSchemaMap()
    actual_jsonschema_type = mapper.to_jsonschema(sql_type)
    assert actual_jsonschema_type == expected_jsonschema_type

0 comments on commit e1f05cd

Please sign in to comment.