diff --git a/dbterd/adapters/targets/mermaid.py b/dbterd/adapters/targets/mermaid.py index 5f3ef51..443e140 100644 --- a/dbterd/adapters/targets/mermaid.py +++ b/dbterd/adapters/targets/mermaid.py @@ -42,7 +42,7 @@ def match_complex_column_type(column_type: str) -> Optional[str]: Returns: Optional[str]: Returns root type if input type is nested complex type, otherwise returns `None` for primitive types """ - pattern = r"(\w+)<(\w+\s+\w+(\s*,\s*\w+\s+\w+)*)>" + pattern = r"(\w+)<.*>" match = re.match(pattern, column_type) if match: return match.group(1) diff --git a/tests/unit/adapters/targets/mermaid/test_mermaid_column_types.py b/tests/unit/adapters/targets/mermaid/test_mermaid_column_types.py new file mode 100644 index 0000000..4fcd089 --- /dev/null +++ b/tests/unit/adapters/targets/mermaid/test_mermaid_column_types.py @@ -0,0 +1,26 @@ +import pytest + +from dbterd.adapters.targets import mermaid + +complex_column_types = [ + ("string", None), + ("struct", "struct"), + ("array>", "array"), + ("array>", "array"), +] +column_types = [ + ("string", "string"), + ("struct", "struct[OMITTED]"), + ("array>", "array[OMITTED]"), + ("array>", "array[OMITTED]"), +] + + +@pytest.mark.parametrize("input,expected", complex_column_types) +def test_match_complex_column_type(input, expected): + assert mermaid.match_complex_column_type(input) == expected + + +@pytest.mark.parametrize("input,expected", column_types) +def test_replace_column_type(input, expected): + assert mermaid.replace_column_type(input) == expected