feat: Developers can now more easily override the mapping from JSON schema to SQL column type
edgarrmondragon committed Oct 25, 2024
1 parent fb9ac30 commit 8cce194
Showing 5 changed files with 269 additions and 6 deletions.
213 changes: 210 additions & 3 deletions singer_sdk/connectors/sql.py
@@ -25,6 +25,11 @@
else:
    from warnings import deprecated

if sys.version_info < (3, 10):
    from typing_extensions import TypeAlias
else:
    from typing import TypeAlias  # noqa: ICN003

if t.TYPE_CHECKING:
    from sqlalchemy.engine import Engine
    from sqlalchemy.engine.reflection import Inspector
@@ -192,6 +197,199 @@ def boolean_to_jsonschema(self, column_type: sa.types.Boolean) -> dict:  # noqa:
        return th.BooleanType.type_dict  # type: ignore[no-any-return]


JSONtoSQLHandler: TypeAlias = t.Union[
    t.Type[sa.types.TypeEngine],
    t.Callable[[dict], sa.types.TypeEngine],
]


class JSONSchemaToSQL:
    """A configurable mapper for converting JSON Schema types to SQLAlchemy types."""

    def __init__(self) -> None:
        """Initialize the mapper with default type mappings."""
        # Default type mappings
        self._type_mapping: dict[str, JSONtoSQLHandler] = {
            "string": self._handle_string_type,
            "integer": sa.types.INTEGER,
            "number": sa.types.DECIMAL,
            "boolean": sa.types.BOOLEAN,
            "object": sa.types.VARCHAR,
            "array": sa.types.VARCHAR,
        }

        # Format handlers for string types
        self._format_handlers: dict[str, JSONtoSQLHandler] = {
            # Default date-like formats
            "date-time": sa.types.DATETIME,
            "time": sa.types.TIME,
            "date": sa.types.DATE,
            # Common string formats with sensible defaults
            "uuid": sa.types.UUID,
            "email": lambda _: sa.types.VARCHAR(254),  # RFC 5321
            "uri": lambda _: sa.types.VARCHAR(2083),  # Common browser limit
            "hostname": lambda _: sa.types.VARCHAR(253),  # RFC 1035
            "ipv4": lambda _: sa.types.VARCHAR(15),
            "ipv6": lambda _: sa.types.VARCHAR(45),
        }

    def _invoke_handler(  # noqa: PLR6301
        self,
        handler: JSONtoSQLHandler,
        schema: dict,
    ) -> sa.types.TypeEngine:
        """Invoke a handler, handling both type classes and callables.

        Args:
            handler: The handler to invoke.
            schema: The schema to pass to callable handlers.

        Returns:
            The resulting SQLAlchemy type.
        """
        if isinstance(handler, type):
            return handler()
        return handler(schema)

    def register_type_handler(self, json_type: str, handler: JSONtoSQLHandler) -> None:
        """Register a custom type handler.

        Args:
            json_type: The JSON Schema type to handle.
            handler: Either a SQLAlchemy type class or a callable that takes a schema
                dict and returns a SQLAlchemy type instance.
        """
        self._type_mapping[json_type] = handler

    def register_format_handler(
        self,
        format_name: str,
        handler: JSONtoSQLHandler,
    ) -> None:
        """Register a custom format handler.

        Args:
            format_name: The format string (e.g., "date-time", "email", "custom-format").
            handler: Either a SQLAlchemy type class or a callable that takes a schema
                dict and returns a SQLAlchemy type instance.
        """  # noqa: E501
        self._format_handlers[format_name] = handler

    def _get_type_from_schema(self, schema: dict) -> sa.types.TypeEngine | None:
        """Try to get a SQL type from a single schema object.

        Args:
            schema: The JSON Schema object.

        Returns:
            SQL type if one can be determined, None otherwise.
        """
        # Check if this is a string with format first
        if schema.get("type") == "string" and "format" in schema:
            format_type = self._handle_format(schema)
            if format_type is not None:
                return format_type

        # Then check regular types
        if "type" in schema:
            schema_type = schema["type"]
            if isinstance(schema_type, (list, tuple)):
                # For type arrays, try each type
                for t in schema_type:
                    if handler := self._type_mapping.get(t):
                        return self._invoke_handler(handler, schema)
            elif schema_type in self._type_mapping:
                handler = self._type_mapping[schema_type]
                return self._invoke_handler(handler, schema)

        return None

    def _type_check(self, schema: dict, type_check: tuple[str, ...]) -> bool:
        """Check if the schema supports any of the specified types.

        Args:
            schema: The JSON Schema object.
            type_check: Tuple of type strings to check for.

        Returns:
            bool: True if the schema supports any of the specified types.
        """
        if "type" in schema:
            schema_type = schema["type"]
            if isinstance(schema_type, (list, tuple)):
                return any(t in type_check for t in schema_type)
            return schema_type in type_check

        return any(self._type_check(t, type_check) for t in schema.get("anyOf", ()))

    def _handle_format(self, schema: dict) -> sa.types.TypeEngine | None:
        """Handle format-specific type conversion.

        Args:
            schema: The JSON Schema object.

        Returns:
            The format-specific SQL type if applicable, None otherwise.
        """
        for type_dict in schema.get("anyOf", ()):
            if format_type := self._handle_format(type_dict):
                return format_type

        if "format" not in schema:
            return None

        format_type = schema["format"]
        handler = self._format_handlers.get(format_type)

        if handler is None:
            return None

        return self._invoke_handler(handler, schema)

    def _handle_string_type(self, schema: dict) -> sa.types.TypeEngine:
        """Handle string type conversion with special cases for formats.

        Args:
            schema: The JSON Schema object.

        Returns:
            Appropriate SQLAlchemy type.
        """
        # Check for format-specific handling first
        if format_type := self._handle_format(schema):
            return format_type

        # Default string handling
        maxlength = schema.get("maxLength")
        return sa.types.VARCHAR(maxlength)

    def to_sql_type(self, schema: dict) -> sa.types.TypeEngine:
        """Convert a JSON Schema type definition to a SQLAlchemy type.

        Args:
            schema: The JSON Schema object.

        Returns:
            The corresponding SQLAlchemy type.
        """
        if sql_type := self._get_type_from_schema(schema):
            return sql_type

        # Handle anyOf
        if "anyOf" in schema:
            for subschema in schema["anyOf"]:
                # Skip null types in anyOf
                if subschema.get("type") == "null":
                    continue

                sql_type = self._get_type_from_schema(subschema)
                if sql_type is not None:
                    return sql_type

        # Fallback
        return sa.types.VARCHAR()


class SQLConnector: # noqa: PLR0904
"""Base class for SQLAlchemy-based connectors.
@@ -255,6 +453,16 @@ def sql_to_jsonschema(self) -> SQLToJSONSchema:
        """
        return SQLToJSONSchema()

    @functools.cached_property
    def jsonschema_to_sql(self) -> JSONSchemaToSQL:
        """The JSON-to-SQL type mapper object for this SQL connector.

        Override this property to provide a custom mapping for your SQL dialect.

        .. versionadded:: 0.42.0
        """
        return JSONSchemaToSQL()

    @contextmanager
    def _connect(self) -> t.Iterator[sa.engine.Connection]:
        with self._engine.connect().execution_options(stream_results=True) as conn:
@@ -418,8 +626,7 @@ def to_jsonschema_type(
        msg = f"Unexpected type received: '{type(sql_type).__name__}'"
        raise ValueError(msg)

    @staticmethod
    def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine:
    def to_sql_type(self, jsonschema_type: dict) -> sa.types.TypeEngine:
        """Return a JSON Schema representation of the provided type.

        By default will call `typing.to_sql_type()`.
@@ -435,7 +642,7 @@ def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine:
        Returns:
            The SQLAlchemy type representation of the data type.
        """
        return th.to_sql_type(jsonschema_type)
        return self.jsonschema_to_sql.to_sql_type(jsonschema_type)

    @staticmethod
    def get_fully_qualified_name(
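Taken together, the changes to `sql.py` give developers two levels of control over type mapping: per-instance, via `register_type_handler` and `register_format_handler`, or per-connector, by overriding the new `jsonschema_to_sql` property. A minimal sketch of both (the `MyDialectConnector` subclass and the specific SQLAlchemy types chosen here are illustrative assumptions, not part of this commit):

```python
import functools

import sqlalchemy as sa

from singer_sdk.connectors import SQLConnector
from singer_sdk.connectors.sql import JSONSchemaToSQL

# Instance-level: register a handler for a JSON Schema type or format.
to_sql = JSONSchemaToSQL()
to_sql.register_type_handler("integer", sa.types.BIGINT)
to_sql.register_format_handler(
    "uri",
    lambda schema: sa.types.VARCHAR(schema.get("maxLength", 2083)),
)

to_sql.to_sql_type({"type": ["integer", "null"]})  # -> BIGINT()


# Connector-level: override the new cached property so every column created
# by this connector uses the customized mapper.
class MyDialectConnector(SQLConnector):  # hypothetical connector
    @functools.cached_property
    def jsonschema_to_sql(self) -> JSONSchemaToSQL:
        mapper = JSONSchemaToSQL()
        mapper.register_type_handler("object", sa.types.JSON)
        return mapper
```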
4 changes: 4 additions & 0 deletions singer_sdk/typing.py
@@ -1205,6 +1205,10 @@ def _jsonschema_type_check(jsonschema_type: dict, type_check: tuple[str]) -> bool:
    )


@deprecated(
    "Use `JSONSchemaToSQL` instead.",
    category=DeprecationWarning,
)
def to_sql_type(  # noqa: PLR0911, C901
    jsonschema_type: dict,
) -> sa.types.TypeEngine:
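Since `typing.to_sql_type` now emits a `DeprecationWarning`, existing call sites can switch to the mapper directly. A rough before/after sketch (the example schema is illustrative):

```python
import sqlalchemy as sa

from singer_sdk import typing as th
from singer_sdk.connectors.sql import JSONSchemaToSQL

schema = {"type": ["string", "null"], "maxLength": 50}

# Before: deprecated module-level helper (now emits DeprecationWarning).
old_type = th.to_sql_type(schema)

# After: use the configurable mapper instead.
new_type = JSONSchemaToSQL().to_sql_type(schema)
assert isinstance(new_type, sa.types.VARCHAR)
```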
54 changes: 52 additions & 2 deletions tests/core/test_connector_sql.py
@@ -11,7 +11,11 @@

from samples.sample_duckdb import DuckDBConnector
from singer_sdk.connectors import SQLConnector
from singer_sdk.connectors.sql import FullyQualifiedName, SQLToJSONSchema
from singer_sdk.connectors.sql import (
    FullyQualifiedName,
    JSONSchemaToSQL,
    SQLToJSONSchema,
)
from singer_sdk.exceptions import ConfigValidationError

if t.TYPE_CHECKING:
@@ -445,7 +449,7 @@ def test_sql_to_json_schema_map(
    assert m.to_jsonschema(sql_type) == expected_jsonschema_type


def test_custom_type():
def test_custom_type_to_jsonschema():
    class MyMap(SQLToJSONSchema):
        @SQLToJSONSchema.to_jsonschema.register
        def custom_number_to_jsonschema(self, column_type: sa.types.NUMERIC) -> dict:
@@ -470,3 +474,49 @@ def my_type_to_jsonschema(self, column_type) -> dict:  # noqa: ARG002
        "multipleOf": 0.01,
    }
    assert m.to_jsonschema(sa.types.BOOLEAN()) == {"type": ["boolean"]}


@pytest.mark.parametrize(
    "jsonschema_type,expected",
    [
        ({"type": ["string", "null"]}, sa.types.VARCHAR),
        ({"type": ["integer", "null"]}, sa.types.INTEGER),
        ({"type": ["number", "null"]}, sa.types.DECIMAL),
        ({"type": ["boolean", "null"]}, sa.types.BOOLEAN),
        ({"type": "object", "properties": {}}, sa.types.VARCHAR),
        ({"type": "array"}, sa.types.VARCHAR),
        ({"format": "date", "type": ["string", "null"]}, sa.types.DATE),
        ({"format": "time", "type": ["string", "null"]}, sa.types.TIME),
        ({"format": "uuid", "type": ["string", "null"]}, sa.types.UUID),
        (
            {"format": "date-time", "type": ["string", "null"]},
            sa.types.DATETIME,
        ),
        (
            {"anyOf": [{"type": "string", "format": "date-time"}, {"type": "null"}]},
            sa.types.DATETIME,
        ),
        ({"anyOf": [{"type": "integer"}, {"type": "null"}]}, sa.types.INTEGER),
        (
            {"type": ["array", "object", "boolean", "null"]},
            sa.types.VARCHAR,
        ),
    ],
)
def test_to_sql_type(jsonschema_type, expected):
    to_sql = JSONSchemaToSQL()
    assert isinstance(to_sql.to_sql_type(jsonschema_type), expected)


def test_register_jsonschema_type_handler():
    to_sql = JSONSchemaToSQL()
    to_sql.register_type_handler("my-type", sa.types.LargeBinary)
    result = to_sql.to_sql_type({"type": "my-type"})
    assert isinstance(result, sa.types.LargeBinary)


def test_register_jsonschema_format_handler():
    to_sql = JSONSchemaToSQL()
    to_sql.register_format_handler("my-format", sa.types.LargeBinary)
    result = to_sql.to_sql_type({"type": "string", "format": "my-format"})
    assert isinstance(result, sa.types.LargeBinary)
3 changes: 2 additions & 1 deletion tests/core/test_sql_typing.py
@@ -49,7 +49,8 @@ def test_convert_jsonschema_type_to_sql_type(
    jsonschema_type: dict,
    sql_type: sa.types.TypeEngine,
):
    result = th.to_sql_type(jsonschema_type)
    with pytest.warns(DeprecationWarning):
        result = th.to_sql_type(jsonschema_type)
    assert isinstance(result, sql_type.__class__)
    assert str(result) == str(sql_type)
1 change: 1 addition & 0 deletions tests/core/test_typing.py
@@ -325,6 +325,7 @@ def test_conform_primitives():
    assert _conform_primitive_property(1, {"type": ["boolean"]}) is True


@pytest.mark.filterwarnings("ignore:Use `JSONSchemaToSQL` instead.:DeprecationWarning")
@pytest.mark.parametrize(
    "jsonschema_type,expected",
    [
