From 56a96fd717a6f24cf66361d1060995e48eba4d3d Mon Sep 17 00:00:00 2001 From: Stanislav Khlud Date: Fri, 15 Mar 2024 17:26:59 +0700 Subject: [PATCH] Add tests --- .github/workflows/checks.yaml | 2 + .pre-commit-config.yaml | 8 + docker-compose.yaml | 2 +- poetry.lock | 178 +++++- pyproject.toml | 48 +- saritasa_sqlalchemy_tools/__init__.py | 2 +- saritasa_sqlalchemy_tools/auto_schema.py | 22 +- .../repositories/__init__.py | 2 +- .../repositories/core.py | 80 +-- .../repositories/filters.py | 18 +- .../repositories/ordering.py | 8 +- saritasa_sqlalchemy_tools/session.py | 1 + .../testing/factories.py | 43 +- tests/__init__.py | 0 tests/conftest.py | 126 +++++ tests/factories.py | 120 ++++ tests/models.py | 351 ++++++++++++ tests/repositories.py | 75 +++ tests/test_auto_schema.py | 341 ++++++++++++ tests/test_factories.py | 106 ++++ tests/test_ordering.py | 43 ++ tests/test_repositories.py | 522 ++++++++++++++++++ 22 files changed, 2018 insertions(+), 80 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/factories.py create mode 100644 tests/models.py create mode 100644 tests/repositories.py create mode 100644 tests/test_auto_schema.py create mode 100644 tests/test_factories.py create mode 100644 tests/test_ordering.py create mode 100644 tests/test_repositories.py diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 7376074..96c8cf3 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -49,4 +49,6 @@ jobs: run: poetry install --no-interaction --all-extras - name: Run checks run: | + poetry run inv github-actions.set-up-hosts + poetry run inv docker.up poetry run inv pre-commit.run-hooks diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index adb783c..6969c2d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,3 +51,11 @@ repos: pass_filenames: false types: [ file ] stages: [ push ] + + - id: pytest + name: Run pytest + entry: inv pytest.run + language: system + pass_filenames: false + types: [ file ] + stages: [ push ] diff --git a/docker-compose.yaml b/docker-compose.yaml index ad7eda5..2a8b4b9 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,5 +1,5 @@ version: '3.7' -name: "saritasa-sqlachemy-tools" +name: "saritasa-sqlalchemy-tools" services: postgres: diff --git a/poetry.lock b/poetry.lock index 21ae784..5501e94 100644 --- a/poetry.lock +++ b/poetry.lock @@ -29,6 +29,59 @@ six = ">=1.12.0" astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"] test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] +[[package]] +name = "asyncpg" +version = "0.28.0" +description = "An asyncio PostgreSQL driver" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "asyncpg-0.28.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a6d1b954d2b296292ddff4e0060f494bb4270d87fb3655dd23c5c6096d16d83"}, + {file = "asyncpg-0.28.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0740f836985fd2bd73dca42c50c6074d1d61376e134d7ad3ad7566c4f79f8184"}, + {file = "asyncpg-0.28.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e907cf620a819fab1737f2dd90c0f185e2a796f139ac7de6aa3212a8af96c050"}, + {file = "asyncpg-0.28.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b339984d55e8202e0c4b252e9573e26e5afa05617ed02252544f7b3e6de3e9"}, + {file = "asyncpg-0.28.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0c402745185414e4c204a02daca3d22d732b37359db4d2e705172324e2d94e85"}, + {file = "asyncpg-0.28.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c88eef5e096296626e9688f00ab627231f709d0e7e3fb84bb4413dff81d996d7"}, + {file = "asyncpg-0.28.0-cp310-cp310-win32.whl", hash = "sha256:90a7bae882a9e65a9e448fdad3e090c2609bb4637d2a9c90bfdcebbfc334bf89"}, + {file = "asyncpg-0.28.0-cp310-cp310-win_amd64.whl", hash = "sha256:76aacdcd5e2e9999e83c8fbcb748208b60925cc714a578925adcb446d709016c"}, + {file = "asyncpg-0.28.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a0e08fe2c9b3618459caaef35979d45f4e4f8d4f79490c9fa3367251366af207"}, + {file = "asyncpg-0.28.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b24e521f6060ff5d35f761a623b0042c84b9c9b9fb82786aadca95a9cb4a893b"}, + {file = "asyncpg-0.28.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99417210461a41891c4ff301490a8713d1ca99b694fef05dabd7139f9d64bd6c"}, + {file = "asyncpg-0.28.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f029c5adf08c47b10bcdc857001bbef551ae51c57b3110964844a9d79ca0f267"}, + {file = "asyncpg-0.28.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ad1d6abf6c2f5152f46fff06b0e74f25800ce8ec6c80967f0bc789974de3c652"}, + {file = "asyncpg-0.28.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d7fa81ada2807bc50fea1dc741b26a4e99258825ba55913b0ddbf199a10d69d8"}, + {file = "asyncpg-0.28.0-cp311-cp311-win32.whl", hash = "sha256:f33c5685e97821533df3ada9384e7784bd1e7865d2b22f153f2e4bd4a083e102"}, + {file = "asyncpg-0.28.0-cp311-cp311-win_amd64.whl", hash = "sha256:5e7337c98fb493079d686a4a6965e8bcb059b8e1b8ec42106322fc6c1c889bb0"}, + {file = "asyncpg-0.28.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1c56092465e718a9fdcc726cc3d9dcf3a692e4834031c9a9f871d92a75d20d48"}, + {file = "asyncpg-0.28.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4acd6830a7da0eb4426249d71353e8895b350daae2380cb26d11e0d4a01c5472"}, + {file = "asyncpg-0.28.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63861bb4a540fa033a56db3bb58b0c128c56fad5d24e6d0a8c37cb29b17c1c7d"}, + {file = "asyncpg-0.28.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:a93a94ae777c70772073d0512f21c74ac82a8a49be3a1d982e3f259ab5f27307"}, + {file = "asyncpg-0.28.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d14681110e51a9bc9c065c4e7944e8139076a778e56d6f6a306a26e740ed86d2"}, + {file = "asyncpg-0.28.0-cp37-cp37m-win32.whl", hash = "sha256:8aec08e7310f9ab322925ae5c768532e1d78cfb6440f63c078b8392a38aa636a"}, + {file = "asyncpg-0.28.0-cp37-cp37m-win_amd64.whl", hash = "sha256:319f5fa1ab0432bc91fb39b3960b0d591e6b5c7844dafc92c79e3f1bff96abef"}, + {file = "asyncpg-0.28.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b337ededaabc91c26bf577bfcd19b5508d879c0ad009722be5bb0a9dd30b85a0"}, + {file = "asyncpg-0.28.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4d32b680a9b16d2957a0a3cc6b7fa39068baba8e6b728f2e0a148a67644578f4"}, + {file = "asyncpg-0.28.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4f62f04cdf38441a70f279505ef3b4eadf64479b17e707c950515846a2df197"}, + {file = "asyncpg-0.28.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f20cac332c2576c79c2e8e6464791c1f1628416d1115935a34ddd7121bfc6a4"}, + {file = "asyncpg-0.28.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:59f9712ce01e146ff71d95d561fb68bd2d588a35a187116ef05028675462d5ed"}, + {file = "asyncpg-0.28.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9e9f9ff1aa0eddcc3247a180ac9e9b51a62311e988809ac6152e8fb8097756"}, + {file = "asyncpg-0.28.0-cp38-cp38-win32.whl", hash = "sha256:9e721dccd3838fcff66da98709ed884df1e30a95f6ba19f595a3706b4bc757e3"}, + {file = "asyncpg-0.28.0-cp38-cp38-win_amd64.whl", hash = "sha256:8ba7d06a0bea539e0487234511d4adf81dc8762249858ed2a580534e1720db00"}, + {file = "asyncpg-0.28.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d009b08602b8b18edef3a731f2ce6d3f57d8dac2a0a4140367e194eabd3de457"}, + {file = "asyncpg-0.28.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ec46a58d81446d580fb21b376ec6baecab7288ce5a578943e2fc7ab73bf7eb39"}, + {file = "asyncpg-0.28.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b48ceed606cce9e64fd5480a9b0b9a95cea2b798bb95129687abd8599c8b019"}, + {file = "asyncpg-0.28.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8858f713810f4fe67876728680f42e93b7e7d5c7b61cf2118ef9153ec16b9423"}, + {file = "asyncpg-0.28.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5e18438a0730d1c0c1715016eacda6e9a505fc5aa931b37c97d928d44941b4bf"}, + {file = "asyncpg-0.28.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e9c433f6fcdd61c21a715ee9128a3ca48be8ac16fa07be69262f016bb0f4dbd2"}, + {file = "asyncpg-0.28.0-cp39-cp39-win32.whl", hash = "sha256:41e97248d9076bc8e4849da9e33e051be7ba37cd507cbd51dfe4b2d99c70e3dc"}, + {file = "asyncpg-0.28.0-cp39-cp39-win_amd64.whl", hash = "sha256:3ed77f00c6aacfe9d79e9eff9e21729ce92a4b38e80ea99a58ed382f42ebd55b"}, + {file = "asyncpg-0.28.0.tar.gz", hash = "sha256:7252cdc3acb2f52feaa3664280d3bcd78a46bd6c10bfd681acfffefa1120e278"}, +] + +[package.extras] +docs = ["Sphinx (>=5.3.0,<5.4.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] +test = ["flake8 (>=5.0,<6.0)", "uvloop (>=0.15.3)"] + [[package]] name = "cfgv" version = "3.4.0" @@ -220,6 +273,17 @@ files = [ [package.extras] license = ["ukkonen"] +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "invoke" version = "2.2.0" @@ -420,6 +484,17 @@ files = [ [package.dependencies] setuptools = "*" +[[package]] +name = "packaging" +version = "24.0" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, + {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, +] + [[package]] name = "parso" version = "0.8.3" @@ -464,6 +539,21 @@ files = [ docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] +[[package]] +name = "pluggy" +version = "1.4.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, + {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "pre-commit" version = "3.6.2" @@ -646,6 +736,78 @@ files = [ plugins = ["importlib-metadata"] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pytest" +version = "8.1.1" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"}, + {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.4,<2.0" + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-async-sqlalchemy" +version = "0.2.0" +description = "Database testing fixtures using the SQLAlchemy asyncio API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pytest-async-sqlalchemy-0.2.0.tar.gz", hash = "sha256:0dcf80fdff1ea0046834cff2bc100c82d159e45a7ae21545a6ba9119a962b9d7"}, + {file = "pytest_async_sqlalchemy-0.2.0-py3-none-any.whl", hash = "sha256:60d7159f43d21e79d7051841fd2d6e094b7267ddc8d7192daea597afca938b12"}, +] + +[package.dependencies] +pytest = ">=6.0.0" +sqlalchemy = ">=1.4.0" + +[[package]] +name = "pytest-asyncio" +version = "0.21.1" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"}, + {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + +[[package]] +name = "pytest-sugar" +version = "1.0.0" +description = "pytest-sugar is a plugin for pytest that changes the default look and feel of pytest (e.g. progressbar, show tests that fail instantly)." +optional = false +python-versions = "*" +files = [ + {file = "pytest-sugar-1.0.0.tar.gz", hash = "sha256:6422e83258f5b0c04ce7c632176c7732cab5fdb909cb39cca5c9139f81276c0a"}, + {file = "pytest_sugar-1.0.0-py3-none-any.whl", hash = "sha256:70ebcd8fc5795dc457ff8b69d266a4e2e8a74ae0c3edc749381c64b5246c8dfd"}, +] + +[package.dependencies] +packaging = ">=21.3" +pytest = ">=6.2.0" +termcolor = ">=2.1.0" + +[package.extras] +dev = ["black", "flake8", "pre-commit"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -879,6 +1041,20 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "termcolor" +version = "2.4.0" +description = "ANSI color formatting for output in terminal" +optional = false +python-versions = ">=3.8" +files = [ + {file = "termcolor-2.4.0-py3-none-any.whl", hash = "sha256:9297c0df9c99445c2412e832e882a7884038a25617c60cea2ad69488d4040d63"}, + {file = "termcolor-2.4.0.tar.gz", hash = "sha256:aab9e56047c8ac41ed798fa36d892a37aca6b3e9159f3e0c24bc64a9b3ac7b7a"}, +] + +[package.extras] +tests = ["pytest", "pytest-cov"] + [[package]] name = "traitlets" version = "5.14.2" @@ -943,4 +1119,4 @@ factories = ["factory-boy"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "631ce82692e7ea81db4a9c084b8def42262c4b5ee9a43d2459fba55631c70758" +content-hash = "0935f7d122d52a5bb0cbebfa28f3b51f0eb5f8dfb41ff525042f763ffeb13ceb" diff --git a/pyproject.toml b/pyproject.toml index 2a514de..92bb988 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,22 @@ saritasa_invocations = "^1.1.0" # https://mypy.readthedocs.io/en/stable/ mypy = "^1.9.0" +[tool.poetry.group.test.dependencies] +pytest = "^8.1.1" +# pytest-asyncio is a pytest plugin. It facilitates testing of code that +# uses the asyncio library. +# https://pytest-asyncio.readthedocs.io/en/latest/ +pytest-asyncio = "^0.21.1" +# Database testing fixtures using the SQLAlchemy asyncio API +# https://pypi.org/project/pytest-async-sqlalchemy/ +pytest-async-sqlalchemy = "^0.2.0" +# To prettify pytest output +pytest-sugar = "^1.0.0" +# asyncpg is a database interface library designed specifically for PostgreSQL +# and Python/asyncio. +# https://magicstack.github.io/asyncpg/current/ +asyncpg = "^0.28.0" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" @@ -171,7 +187,7 @@ section-order = [ sqlachemy = ["sqlachemy"] [tool.ruff.lint.flake8-pytest-style] -fixture-parentheses = true +fixture-parentheses = false parametrize-names-type = "list" parametrize-values-type = "list" parametrize-values-row-type = "list" @@ -210,3 +226,33 @@ ignore = [ "**/*test_*.py", "invocations/**" ] + +[tool.pytest.ini_options] +# --capture=no +# allow use of ipdb during tests +# --ff +# run last failed tests first +addopts = [ + "--capture=no", + "--ff", + "--cov=saritasa_sqlalchemy_tools", + "--cov-report=html", +] +# skip all files inside following dirs +norecursedirs = [ + "venv", + ".venv", +] +asyncio_mode = "auto" + +[tool.coverage.run] +omit = [ + "saritasa_sqlalchemy_tools/session.py", +] + +# https://docformatter.readthedocs.io/en/latest/configuration.html# +[tool.docformatter] +wrap-descriptions=0 +in-place=true +blank=true +black=true diff --git a/saritasa_sqlalchemy_tools/__init__.py b/saritasa_sqlalchemy_tools/__init__.py index c4a0852..b5a131a 100644 --- a/saritasa_sqlalchemy_tools/__init__.py +++ b/saritasa_sqlalchemy_tools/__init__.py @@ -32,7 +32,7 @@ Filter, LazyLoaded, LazyLoadedSequence, - OrderingClausesT, + OrderingClauses, OrderingEnum, OrderingEnumMeta, SQLWhereFilter, diff --git a/saritasa_sqlalchemy_tools/auto_schema.py b/saritasa_sqlalchemy_tools/auto_schema.py index 0d58257..c482eca 100644 --- a/saritasa_sqlalchemy_tools/auto_schema.py +++ b/saritasa_sqlalchemy_tools/auto_schema.py @@ -7,6 +7,7 @@ import pydantic import pydantic_core import sqlalchemy.dialects.postgresql.ranges +import sqlalchemy.orm from . import models @@ -148,7 +149,7 @@ def get_schema( pydantic.field_validator(field)(validator) ) continue - if isinstance(field, tuple): + if isinstance(field, tuple) and len(field) == 2: field_name, field_type = field generated_fields[field_name] = ( cls._generate_field_with_custom_type( @@ -159,7 +160,7 @@ def get_schema( ) ) for index, validator in enumerate( - extra_fields_validators.get(field, ()), + extra_fields_validators.get(field_name, ()), ): validators[f"{field_name}_validator_{index}"] = ( pydantic.field_validator(field_name)(validator) @@ -197,7 +198,14 @@ def _generate_field( model_attribute, extra_field_config, ) - if model_attribute.type.__class__ not in types_mapping: + if isinstance(model_attribute.property, sqlalchemy.orm.Relationship): + raise UnableProcessTypeError( + "Schema generation is not supported for relationship " + f"fields({field}), please use auto-schema or pydantic class", + ) + if ( + model_attribute.type.__class__ not in types_mapping + ): # pragma: no cover raise UnableProcessTypeError( "Can't generate generate type for" f" {model_attribute.type.__class__}" @@ -443,7 +451,7 @@ def _generate_enum_field( ) -> PydanticFieldConfig: """Generate enum field.""" if model_type.enum_class is None: # type: ignore - raise UnableToExtractEnumClassError( + raise UnableToExtractEnumClassError( # pragma: no cover f"Can't extract enum for {field} in {model}", ) return ( @@ -517,7 +525,11 @@ def _generate_array_field( model_type.item_type, # type: ignore extra_field_config, ) - return list[list_type], pydantic_core.PydanticUndefined # type: ignore + return ( + list[list_type] | None # type: ignore + if model_attribute.nullable + else list[list_type] # type: ignore + ), pydantic_core.PydanticUndefined ModelAutoSchemaT = typing.TypeVar( diff --git a/saritasa_sqlalchemy_tools/repositories/__init__.py b/saritasa_sqlalchemy_tools/repositories/__init__.py index ef19cbc..ca76145 100644 --- a/saritasa_sqlalchemy_tools/repositories/__init__.py +++ b/saritasa_sqlalchemy_tools/repositories/__init__.py @@ -6,7 +6,7 @@ WhereFilters, transform_search_filter, ) -from .ordering import OrderingClausesT, OrderingEnum, OrderingEnumMeta +from .ordering import OrderingClauses, OrderingEnum, OrderingEnumMeta from .types import ( Annotation, AnnotationSequence, diff --git a/saritasa_sqlalchemy_tools/repositories/core.py b/saritasa_sqlalchemy_tools/repositories/core.py index b43f85d..8109132 100644 --- a/saritasa_sqlalchemy_tools/repositories/core.py +++ b/saritasa_sqlalchemy_tools/repositories/core.py @@ -36,22 +36,18 @@ def __init__( self, db_session: sqlalchemy.ext.asyncio.AsyncSession, ) -> None: - self._db_session = db_session + self.db_session = db_session def init_other( self, repository_class: type[BaseRepositoryT], ) -> BaseRepositoryT: """Init other repo from current.""" - return repository_class(db_session=self._db_session) - - async def commit(self) -> None: - """Commit transaction.""" - await self._db_session.commit() + return repository_class(db_session=self.db_session) async def flush(self) -> None: """Perform changes to database.""" - await self._db_session.flush() + await self.db_session.flush() async def refresh( self, @@ -59,7 +55,7 @@ async def refresh( attribute_names: collections.abc.Sequence[str] | None = None, ) -> None: """Refresh instance.""" - await self._db_session.refresh( + await self.db_session.refresh( instance=instance, attribute_names=attribute_names, ) @@ -71,11 +67,11 @@ def expire(self, instance: models.BaseModelT) -> None: fetched from db again. """ - self._db_session.expire(instance) + self.db_session.expire(instance) async def get(self, pk: int | str) -> models.BaseModelT | None: """Return entry from DB by primary key.""" - return await self._db_session.get(self.model, pk) + return await self.db_session.get(self.model, pk) async def save( self, @@ -84,7 +80,7 @@ async def save( attribute_names: collections.abc.Sequence[str] | None = None, ) -> models.BaseModelT: """Save model instance into db.""" - self._db_session.add(instance=instance) + self.db_session.add(instance=instance) await self.flush() if refresh: await self.refresh(instance, attribute_names) @@ -92,7 +88,7 @@ async def save( async def delete(self, instance: models.BaseModelT) -> None: """Delete model instance into db.""" - await self._db_session.delete(instance=instance) + await self.db_session.delete(instance=instance) await self.flush() async def delete_batch( @@ -101,7 +97,7 @@ async def delete_batch( **filters_by: typing.Any, ) -> None: """Delete batch of objects from db.""" - await self._db_session.execute( + await self.db_session.execute( statement=( sqlalchemy.sql.delete(self.model) .where(*self.process_where_filters(*where)) @@ -141,7 +137,7 @@ async def insert_batch( exclude_fields: collections.abc.Sequence[str] = (), ) -> list[models.BaseModelT]: """Create batch of objects in db.""" - if not objects: + if not objects: # pragma: no cover return [] objects_as_dict = self.objects_as_dict( @@ -150,7 +146,7 @@ async def insert_batch( exclude_fields or self.default_exclude_bulk_create_fields ), ) - created_objects = await self._db_session.scalars( + created_objects = await self.db_session.scalars( sqlalchemy.sql.insert(self.model) .returning(self.model) .values(objects_as_dict), @@ -164,7 +160,7 @@ async def update_batch( exclude_fields: collections.abc.Sequence[str] = (), ) -> None: """Update batch of objects in db.""" - if not objects: + if not objects: # pragma: no cover return objects_as_dict = self.objects_as_dict( @@ -173,7 +169,7 @@ async def update_batch( exclude_fields or self.default_exclude_bulk_update_fields ), ) - await self._db_session.execute( + await self.db_session.execute( sqlalchemy.sql.update(self.model), objects_as_dict, ) @@ -208,7 +204,7 @@ def get_annotated_statement( def get_filter_statement( self, statement: models.SelectStatement[models.BaseModelT] | None = None, - *where_filters: filters.WhereFilter, + *where: filters.WhereFilter, **filters_by: typing.Any, ) -> models.SelectStatement[models.BaseModelT]: """Get statement with filtering.""" @@ -217,17 +213,17 @@ def get_filter_statement( else: select_statement = self.select_statement return select_statement.where( - *self.process_where_filters(*where_filters), + *self.process_where_filters(*where), ).filter_by(**filters_by) @classmethod def process_where_filters( cls, - *where_filters: filters.WhereFilter, + *where: filters.WhereFilter, ) -> collections.abc.Sequence[filters.SQLWhereFilter]: """Process where filters.""" processed_where_filters: list[filters.SQLWhereFilter] = [] - for where_filter in where_filters: + for where_filter in where: if isinstance(where_filter, filters.Filter): processed_where_filters.append( where_filter.transform_filter(cls.model), # type: ignore @@ -239,22 +235,30 @@ def process_where_filters( def get_order_statement( self, statement: models.SelectStatement[models.BaseModelT] | None = None, - *clauses: sqlalchemy.ColumnExpressionArgument[str] | str, + *clauses: ordering.OrderingClause, ) -> models.SelectStatement[models.BaseModelT]: """Get statement with ordering.""" if statement is not None: select_statement = statement else: select_statement = self.select_statement - ordering_clauses = [ - ( - clause.db_clause - if isinstance(clause, ordering.OrderingEnum) - else clause - ) - for clause in clauses - ] - return select_statement.order_by(*ordering_clauses) + return select_statement.order_by( + *self.process_ordering_clauses(*clauses), + ) + + @classmethod + def process_ordering_clauses( + cls, + *clauses: ordering.OrderingClause, + ) -> collections.abc.Sequence[ordering.SQLOrderingClause]: + """Process ordering clauses.""" + processed_ordering_clauses: list[ordering.SQLOrderingClause] = [] + for clause in clauses: + if isinstance(clause, ordering.OrderingEnum): + processed_ordering_clauses.append(clause.sql_clause) + else: + processed_ordering_clauses.append(clause) + return processed_ordering_clauses def get_pagination_statement( self, @@ -323,7 +327,7 @@ def get_fetch_statement( joined_load: types.LazyLoadedSequence = (), select_in_load: types.LazyLoadedSequence = (), annotations: types.AnnotationSequence = (), - clauses: ordering.OrderingClausesT = (), + ordering_clauses: ordering.OrderingClauses = (), where: filters.WhereFilters = (), **filters_by: typing.Any, ) -> models.SelectStatement[models.BaseModelT]: @@ -342,7 +346,7 @@ def get_fetch_statement( ) statement = self.get_order_statement( statement, - *clauses, + *ordering_clauses, ) statement = self.get_filter_statement( statement, @@ -364,12 +368,12 @@ async def fetch( joined_load: types.LazyLoadedSequence = (), select_in_load: types.LazyLoadedSequence = (), annotations: types.AnnotationSequence = (), - clauses: ordering.OrderingClausesT = (), + ordering_clauses: ordering.OrderingClauses = (), where: filters.WhereFilters = (), **filters_by: typing.Any, ) -> sqlalchemy.ScalarResult[models.BaseModelT]: """Fetch entries.""" - return await self._db_session.scalars( + return await self.db_session.scalars( statement=self.get_fetch_statement( statement=statement, offset=offset, @@ -377,7 +381,7 @@ async def fetch( joined_load=joined_load, select_in_load=select_in_load, annotations=annotations, - clauses=clauses, + ordering_clauses=ordering_clauses, where=where, **filters_by, ), @@ -390,7 +394,7 @@ async def count( ) -> int: """Get count of entries.""" return ( - await self._db_session.scalar( + await self.db_session.scalar( sqlalchemy.select(sqlalchemy.func.count()) .select_from(self.model) .where(*self.process_where_filters(*where)) @@ -405,7 +409,7 @@ async def exists( ) -> bool: """Check existence of entries.""" return ( - await self._db_session.scalar( + await self.db_session.scalar( sqlalchemy.select( sqlalchemy.sql.exists( self.select_statement.where( diff --git a/saritasa_sqlalchemy_tools/repositories/filters.py b/saritasa_sqlalchemy_tools/repositories/filters.py index 63f77bc..e99ea4f 100644 --- a/saritasa_sqlalchemy_tools/repositories/filters.py +++ b/saritasa_sqlalchemy_tools/repositories/filters.py @@ -38,7 +38,7 @@ class Filter: """Define filter value.""" - api_filter: str + field: str value: FilterType def transform_filter( @@ -46,7 +46,7 @@ def transform_filter( model: type[models.BaseModelT], ) -> SQLWhereFilter: """Transform filter valid for sqlalchemy.""" - field_name, filter_arg = self.api_filter.split("__") + field_name, filter_arg = self.field.split("__") if field_name in model.m2m_filters: return self.transform_m2m_filter( field_name=field_name, @@ -114,20 +114,6 @@ def transform_simple_filter( filter_operator = getattr(field, filter_args_mapping[filter_arg])( value, ) - if ( - filter_arg - in ( - "gt", - "gte", - "lt", - "lte", - ) - and field.nullable - ): - filter_operator = sqlalchemy.or_( - filter_operator, - field.is_(None), - ) return filter_operator diff --git a/saritasa_sqlalchemy_tools/repositories/ordering.py b/saritasa_sqlalchemy_tools/repositories/ordering.py index fd61cea..d221e3e 100644 --- a/saritasa_sqlalchemy_tools/repositories/ordering.py +++ b/saritasa_sqlalchemy_tools/repositories/ordering.py @@ -25,13 +25,13 @@ class OrderingEnum(enum.StrEnum, metaclass=OrderingEnumMeta): """Representation of ordering fields.""" @property - def db_clause(self) -> str | sqlalchemy.ColumnExpressionArgument[str]: + def sql_clause(self) -> "SQLOrderingClause": """Convert ordering value to sqlalchemy ordering clause.""" if self.startswith("-"): return sqlalchemy.desc(self[1:]) return self -OrderingClausesT: typing.TypeAlias = collections.abc.Sequence[ - str | sqlalchemy.ColumnExpressionArgument[str] | OrderingEnum -] +SQLOrderingClause = str | sqlalchemy.ColumnExpressionArgument[str] +OrderingClause = SQLOrderingClause | OrderingEnum +OrderingClauses: typing.TypeAlias = collections.abc.Sequence[OrderingClause] diff --git a/saritasa_sqlalchemy_tools/session.py b/saritasa_sqlalchemy_tools/session.py index 3e047d8..05d9730 100644 --- a/saritasa_sqlalchemy_tools/session.py +++ b/saritasa_sqlalchemy_tools/session.py @@ -1,3 +1,4 @@ +# pragma: no cover import collections.abc import contextlib import typing diff --git a/saritasa_sqlalchemy_tools/testing/factories.py b/saritasa_sqlalchemy_tools/testing/factories.py index 020c283..860b20a 100644 --- a/saritasa_sqlalchemy_tools/testing/factories.py +++ b/saritasa_sqlalchemy_tools/testing/factories.py @@ -61,7 +61,9 @@ async def create_async( ) ).first() if not instance_from_db: - raise ValueError("Created instance wasn't found in database") + raise ValueError( # pragma: no cover + "Created instance wasn't found in database", + ) return instance_from_db @classmethod @@ -87,23 +89,40 @@ async def _async_run_sub_factories( cls, session: session.Session, passed_fields: collections.abc.Sequence[str], - ) -> dict[str, models.BaseModel]: + ) -> dict[str, models.BaseModel | list[models.BaseModel]]: """Generate objects from sub factories.""" - sub_factories_map: dict[str, str] = getattr( + sub_factories_map: dict[str, str | tuple[str, int]] = getattr( cls._meta, "sub_factories", {}, ) - generated_instances: dict[str, models.BaseModel] = {} + generated_instances: dict[ + str, + models.BaseModel | list[models.BaseModel], + ] = {} for field, sub_factory_path in sub_factories_map.items(): if field in passed_fields or f"{field}_id" in passed_fields: continue - *factory_module, sub_factory_name = sub_factory_path.split(".") - sub_factory: typing.Self = getattr( - importlib.import_module(".".join(factory_module)), - sub_factory_name, - ) - generated_instances[field] = await sub_factory.create_async( - session=session, - ) + if isinstance(sub_factory_path, str): + *factory_module, sub_factory_name = sub_factory_path.split(".") + sub_factory = getattr( + importlib.import_module(".".join(factory_module)), + sub_factory_name, + ) + generated_instances[field] = await sub_factory.create_async( + session=session, + ) + else: + sub_factory_path, size = sub_factory_path + *factory_module, sub_factory_name = sub_factory_path.split(".") + sub_factory = getattr( + importlib.import_module(".".join(factory_module)), + sub_factory_name, + ) + generated_instances[ + field + ] = await sub_factory.create_batch_async( + session=session, + size=size, + ) return generated_instances diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..680e453 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,126 @@ +import asyncio +import collections.abc +import typing + +import pytest +import sqlalchemy + +import saritasa_sqlalchemy_tools + +from . import factories, models, repositories + + +@pytest.fixture(scope="session") +def event_loop() -> ( + collections.abc.Generator[ + asyncio.AbstractEventLoop, + typing.Any, + None, + ] +): + """Override `event_loop` fixture to change scope to `session`. + + Needed for pytest-async-sqlalchemy. + + """ + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="session") +def database_url(request: pytest.FixtureRequest) -> str: + """Override database url. + + Grab configs from settings and add support for pytest-xdist + + """ + worker_input = getattr( + request.config, + "workerinput", + { + "workerid": "", + }, + ) + return sqlalchemy.engine.URL( + drivername="postgresql+asyncpg", + username="saritasa-sqlalchemy-tools-user", + password="manager", + host="postgres", + port=5432, + database="_".join( + filter( + None, + ( + "saritasa-sqlalchemy-tools-dev", + "test", + worker_input["workerid"], + ), + ), + ), + query={}, # type: ignore + ).render_as_string(hide_password=False) + + +@pytest.fixture(scope="session") +def init_database() -> collections.abc.Callable[..., None]: + """Return callable object that will be called to init database. + + Overridden fixture from `pytest-async-sqlalchemy package`. + https://github.com/igortg/pytest-async-sqlalchemy + + """ + return saritasa_sqlalchemy_tools.BaseModel.metadata.create_all + + +@pytest.fixture +async def test_model( + db_session: saritasa_sqlalchemy_tools.Session, +) -> models.TestModel: + """Generate test_model instance.""" + return await factories.TestModelFactory.create_async(session=db_session) + + +@pytest.fixture +async def related_model( + db_session: saritasa_sqlalchemy_tools.Session, +) -> models.RelatedModel: + """Generate test_model instance.""" + return await factories.RelatedModelFactory.create_async(session=db_session) + + +@pytest.fixture +async def test_models( + db_session: saritasa_sqlalchemy_tools.Session, +) -> collections.abc.Sequence[models.TestModel]: + """Generate test_model instances.""" + return await factories.TestModelFactory.create_batch_async( + session=db_session, + size=5, + ) + + +@pytest.fixture +async def repository( + db_session: saritasa_sqlalchemy_tools.Session, +) -> repositories.TestModelRepository: + """Get repository.""" + return repositories.TestModelRepository(db_session=db_session) + + +@pytest.fixture +async def soft_delete_test_model( + db_session: saritasa_sqlalchemy_tools.Session, +) -> models.SoftDeleteTestModel: + """Generate soft_delete_test_model instance.""" + return await factories.SoftDeleteTestModelFactory.create_async( + session=db_session, + ) + + +@pytest.fixture +async def soft_delete_repository( + db_session: saritasa_sqlalchemy_tools.Session, +) -> repositories.SoftDeleteTestModelRepository: + """Get soft delete repository.""" + return repositories.SoftDeleteTestModelRepository(db_session=db_session) diff --git a/tests/factories.py b/tests/factories.py new file mode 100644 index 0000000..d27baf1 --- /dev/null +++ b/tests/factories.py @@ -0,0 +1,120 @@ +import typing + +import factory +import factory.fuzzy + +import saritasa_sqlalchemy_tools + +from . import models, repositories + + +class TestModelFactory( + saritasa_sqlalchemy_tools.AsyncSQLAlchemyModelFactory[models.TestModel], +): + """Factory to generate TestModel.""" + + text = factory.Faker( + "pystr", + min_chars=1, + max_chars=30, + ) + text_enum = factory.fuzzy.FuzzyChoice( + models.TestModel.TextEnum, + ) + number = factory.Faker("pyint") + small_number = factory.Faker("pyint") + decimal_number = factory.Faker( + "pydecimal", + positive=True, + left_digits=5, + right_digits=0, + ) + true_false = factory.Faker("pybool") + text_list = factory.List( + [ + factory.Faker( + "pystr", + min_chars=1, + max_chars=30, + ) + for _ in range(3) + ], + ) + date_time = factory.Faker("date_time") + date = factory.Faker("date_between") + timedelta = factory.Faker("time_delta") + + class Meta: + model = models.TestModel + repository = repositories.TestModelRepository + sub_factories: typing.ClassVar = { + "related_model": "tests.factories.RelatedModelFactory", + "related_models": ("tests.factories.RelatedModelFactory", 5), + } + + +class SoftDeleteTestModelFactory( + saritasa_sqlalchemy_tools.AsyncSQLAlchemyModelFactory[ + models.SoftDeleteTestModel + ], +): + """Factory to generate SoftDeleteTestModel.""" + + text = factory.Faker( + "pystr", + min_chars=1, + max_chars=30, + ) + text_enum = factory.fuzzy.FuzzyChoice( + models.TestModel.TextEnum, + ) + number = factory.Faker("pyint") + small_number = factory.Faker("pyint") + decimal_number = factory.Faker( + "pydecimal", + positive=True, + left_digits=5, + right_digits=0, + ) + true_false = factory.Faker("pybool") + text_list = factory.List( + [ + factory.Faker( + "pystr", + min_chars=1, + max_chars=30, + ) + for _ in range(3) + ], + ) + date_time = factory.Faker("date_time") + date = factory.Faker("date_between") + timedelta = factory.Faker("time_delta") + + class Meta: + model = models.SoftDeleteTestModel + repository = repositories.SoftDeleteTestModelRepository + + +class RelatedModelFactory( + saritasa_sqlalchemy_tools.AsyncSQLAlchemyModelFactory[models.RelatedModel], +): + """Factory to generate RelatedModel.""" + + class Meta: + model = models.RelatedModel + repository = repositories.RelatedModelRepository + + +class M2MModelFactory( + saritasa_sqlalchemy_tools.AsyncSQLAlchemyModelFactory[models.M2MModel], +): + """Factory to generate M2MModel.""" + + class Meta: + model = models.M2MModel + repository = repositories.M2MModelRepository + sub_factories: typing.ClassVar = { + "related_model": "tests.factories.RelatedModelFactory", + "test_model": "tests.factories.TestModelFactory", + } diff --git a/tests/models.py b/tests/models.py new file mode 100644 index 0000000..878c1f7 --- /dev/null +++ b/tests/models.py @@ -0,0 +1,351 @@ +import datetime +import decimal +import enum +import typing + +import sqlalchemy.orm + +import saritasa_sqlalchemy_tools + + +class RelatedModel( + saritasa_sqlalchemy_tools.TimeStampedBaseIDModel, +): + """Test models for checking relationships.""" + + __tablename__ = "related_model" + + test_models = sqlalchemy.orm.relationship( + "TestModel", + foreign_keys="TestModel.related_model_id", + back_populates="related_model", + ) + + test_models_nullable = sqlalchemy.orm.relationship( + "TestModel", + foreign_keys="TestModel.related_model_id_nullable", + back_populates="related_model_nullable", + ) + + test_model_id: sqlalchemy.orm.Mapped[int | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.ForeignKey( + "test_model.id", + ondelete="CASCADE", + ), + nullable=True, + ) + ) + + test_model = sqlalchemy.orm.relationship( + "TestModel", + foreign_keys=[test_model_id], + back_populates="related_models", + ) + + m2m_test_models = sqlalchemy.orm.relationship( + "TestModel", + secondary="m2m_model", + uselist=True, + viewonly=True, + ) + + m2m_associations = sqlalchemy.orm.relationship( + "M2MModel", + back_populates="related_model", + foreign_keys="M2MModel.related_model_id", + cascade="all, delete", + uselist=True, + ) + + +class FieldsMixin: + """Mixin which adds fields to models.""" + + text: sqlalchemy.orm.Mapped[str] = sqlalchemy.orm.mapped_column( + sqlalchemy.String(30), + nullable=False, + ) + + text_nullable: sqlalchemy.orm.Mapped[str | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.String(30), + nullable=True, + ) + ) + + class TextEnum(enum.StrEnum): + value_1 = "value1" + value_2 = "value2" + value_3 = "value3" + + text_enum: sqlalchemy.orm.Mapped[TextEnum] = sqlalchemy.orm.mapped_column( + sqlalchemy.Enum(TextEnum), + nullable=False, + ) + + text_enum_nullable: sqlalchemy.orm.Mapped[TextEnum | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.Enum(TextEnum), + nullable=True, + ) + ) + + number: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column( + sqlalchemy.Integer(), + nullable=False, + ) + + number_nullable: sqlalchemy.orm.Mapped[int | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.Integer(), + nullable=True, + ) + ) + + small_number: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column( + sqlalchemy.SmallInteger(), + nullable=False, + ) + + small_number_nullable: sqlalchemy.orm.Mapped[int | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.SmallInteger(), + nullable=True, + ) + ) + + decimal_number: sqlalchemy.orm.Mapped[decimal.Decimal] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.Numeric(), + nullable=False, + ) + ) + + decimal_number_nullable: sqlalchemy.orm.Mapped[decimal.Decimal | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.Numeric(), + nullable=True, + ) + ) + + true_false: sqlalchemy.orm.Mapped[bool] = sqlalchemy.orm.mapped_column( + sqlalchemy.Boolean(), + nullable=False, + ) + + true_false_nullable: sqlalchemy.orm.Mapped[bool | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.Boolean(), + nullable=True, + ) + ) + + text_list: sqlalchemy.orm.Mapped[list[str]] = sqlalchemy.orm.mapped_column( + sqlalchemy.ARRAY(sqlalchemy.String), + nullable=False, + ) + + text_list_nullable: sqlalchemy.orm.Mapped[list[str] | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.ARRAY(sqlalchemy.String), + nullable=True, + ) + ) + + date_time: sqlalchemy.orm.Mapped[datetime.datetime] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.DateTime(), + nullable=False, + ) + ) + + date_time_nullable: sqlalchemy.orm.Mapped[datetime.datetime | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.DateTime(), + nullable=True, + ) + ) + + date: sqlalchemy.orm.Mapped[datetime.date] = sqlalchemy.orm.mapped_column( + sqlalchemy.Date(), + nullable=False, + ) + + date_nullable: sqlalchemy.orm.Mapped[datetime.date | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.Date(), + nullable=True, + ) + ) + + timedelta: sqlalchemy.orm.Mapped[datetime.timedelta] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.Interval(), + nullable=False, + ) + ) + + timedelta_nullable: sqlalchemy.orm.Mapped[datetime.timedelta | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.Interval(), + nullable=True, + ) + ) + + @property + def custom_property(self) -> str: + """Implement property.""" + return "" + + @property + def custom_property_nullable(self) -> str | None: + """Implement property.""" + return "" + + +class TestModel( + FieldsMixin, + saritasa_sqlalchemy_tools.TimeStampedBaseIDModel, +): + """Test model for testing.""" + + __tablename__ = "test_model" + + m2m_filters: typing.ClassVar = { + "m2m_related_model_id": saritasa_sqlalchemy_tools.M2MFilterConfig( + relation_field="m2m_associations", + filter_field="related_model_id", + match_field="test_model_id", + ), + } + + related_model_id: sqlalchemy.orm.Mapped[int] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.ForeignKey( + "related_model.id", + ondelete="CASCADE", + ), + nullable=False, + ) + ) + + related_model = sqlalchemy.orm.relationship( + "RelatedModel", + foreign_keys=[related_model_id], + back_populates="test_models", + ) + + related_model_id_nullable: sqlalchemy.orm.Mapped[int | None] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.ForeignKey( + "related_model.id", + ondelete="CASCADE", + ), + nullable=True, + ) + ) + + related_model_nullable = sqlalchemy.orm.relationship( + "RelatedModel", + foreign_keys=[related_model_id_nullable], + back_populates="test_models_nullable", + ) + + related_models = sqlalchemy.orm.relationship( + "RelatedModel", + foreign_keys="RelatedModel.test_model_id", + back_populates="test_model", + ) + + related_models_count: sqlalchemy.orm.Mapped[int | None] = ( + sqlalchemy.orm.column_property( + sqlalchemy.select(sqlalchemy.func.count(RelatedModel.id)) + .correlate_except(RelatedModel) + .scalar_subquery(), + deferred=True, + ) + ) + + related_models_count_query: sqlalchemy.orm.Mapped[int | None] = ( + sqlalchemy.orm.query_expression() + ) + + m2m_related_models = sqlalchemy.orm.relationship( + "RelatedModel", + secondary="m2m_model", + uselist=True, + viewonly=True, + ) + + m2m_associations = sqlalchemy.orm.relationship( + "M2MModel", + back_populates="test_model", + foreign_keys="M2MModel.test_model_id", + cascade="all, delete", + uselist=True, + ) + + def __repr__(self) -> str: + """Get str representation.""" + return f"TestModel<{self.id}>" + + @property + def custom_property_related_model(self) -> typing.Any: + """Implement property.""" + return self.related_model + + @property + def custom_property_related_model_nullable(self) -> typing.Any | None: + """Implement property.""" + return self.related_model_nullable + + @property + def custom_property_related_models(self) -> list[typing.Any]: + """Implement property.""" + return self.related_models + + +class SoftDeleteTestModel( + FieldsMixin, + saritasa_sqlalchemy_tools.SoftDeleteBaseIDModel, +): + """Test model for testing(soft-delete case).""" + + __tablename__ = "soft_delete_test_model" + + +class M2MModel(saritasa_sqlalchemy_tools.TimeStampedBaseIDModel): + """Test model for testing m2m features.""" + + __tablename__ = "m2m_model" + + test_model_id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column( + sqlalchemy.ForeignKey( + "test_model.id", + ondelete="CASCADE", + ), + nullable=False, + ) + + test_model = sqlalchemy.orm.relationship( + "TestModel", + foreign_keys=[test_model_id], + back_populates="m2m_associations", + ) + + related_model_id: sqlalchemy.orm.Mapped[int] = ( + sqlalchemy.orm.mapped_column( + sqlalchemy.ForeignKey( + "related_model.id", + ondelete="CASCADE", + ), + nullable=False, + ) + ) + + related_model = sqlalchemy.orm.relationship( + "RelatedModel", + foreign_keys=[related_model_id], + back_populates="m2m_associations", + ) diff --git a/tests/repositories.py b/tests/repositories.py new file mode 100644 index 0000000..45ea211 --- /dev/null +++ b/tests/repositories.py @@ -0,0 +1,75 @@ +import typing + +import saritasa_sqlalchemy_tools + +from . import models + + +class RelatedModelRepository( + saritasa_sqlalchemy_tools.BaseRepository[models.RelatedModel], +): + """Repository for `RelatedModel` model.""" + + model: typing.TypeAlias = models.RelatedModel + default_exclude_bulk_create_fields = ( + "created", + "modified", + "id", + ) + default_exclude_bulk_update_fields = ( + "created", + "modified", + ) + + +class TestModelRepository( + saritasa_sqlalchemy_tools.BaseRepository[models.TestModel], +): + """Repository for `TestModel` model.""" + + model: typing.TypeAlias = models.TestModel + default_exclude_bulk_create_fields = ( + "created", + "modified", + "id", + ) + default_exclude_bulk_update_fields = ( + "created", + "modified", + ) + + +class SoftDeleteTestModelRepository( + saritasa_sqlalchemy_tools.BaseSoftDeleteRepository[ + models.SoftDeleteTestModel + ], +): + """Repository for `SoftDeleteTestModel` model.""" + + model: typing.TypeAlias = models.SoftDeleteTestModel + default_exclude_bulk_create_fields = ( + "created", + "modified", + "id", + ) + default_exclude_bulk_update_fields = ( + "created", + "modified", + ) + + +class M2MModelRepository( + saritasa_sqlalchemy_tools.BaseRepository[models.M2MModel], +): + """Repository for `M2MModel` model.""" + + model: typing.TypeAlias = models.M2MModel + default_exclude_bulk_create_fields = ( + "created", + "modified", + "id", + ) + default_exclude_bulk_update_fields = ( + "created", + "modified", + ) diff --git a/tests/test_auto_schema.py b/tests/test_auto_schema.py new file mode 100644 index 0000000..7308fe8 --- /dev/null +++ b/tests/test_auto_schema.py @@ -0,0 +1,341 @@ +import datetime +import decimal +import re +import typing + +import pydantic +import pytest + +import saritasa_sqlalchemy_tools + +from . import models, repositories + + +async def test_auto_schema_generation( + test_model: models.TestModel, +) -> None: + """Test auto schema generation picks correct types for schema.""" + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + model_config = pydantic.ConfigDict( + from_attributes=True, + validate_assignment=True, + ) + fields = ( + "id", + "created", + "modified", + "text", + "text_nullable", + "text_enum", + "text_enum_nullable", + "number", + "number_nullable", + "small_number", + "small_number_nullable", + "decimal_number", + "decimal_number_nullable", + "true_false", + "true_false_nullable", + "text_list", + "text_list_nullable", + "date_time", + "date_time_nullable", + "date", + "date_nullable", + "timedelta", + "timedelta_nullable", + "related_model_id", + "related_model_id_nullable", + "custom_property", + "custom_property_nullable", + ) + + schema = AutoSchema.get_schema() + model = schema.model_validate(test_model) + for field in AutoSchema.Meta.fields: + assert getattr(model, field) == getattr(test_model, field) + if "nullable" not in field and "property" not in field: + with pytest.raises(pydantic.ValidationError): + setattr(model, field, None) + + +async def test_auto_schema_type_override_generation( + test_model: models.TestModel, +) -> None: + """Test that in auto schema generation ypu can override type.""" + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = ( + ("text", str | None), + ("text_enum", models.TestModel.TextEnum | None), + ("number", int | None), + ("small_number", int | None), + ("decimal_number", decimal.Decimal | None), + ("true_false", bool | None), + ("text_list", list[str] | None), + ("date_time", datetime.datetime | None), + ("date", datetime.date | None), + ("timedelta", datetime.timedelta | None), + ("custom_property", str | None), + ("related_model_id", int | None), + ) + + schema = AutoSchema.get_schema() + model = schema.model_validate(test_model) + for field, _ in AutoSchema.Meta.fields: + if "property" not in field: + setattr(model, field, None) + + +async def test_auto_schema_type_invalid_field_config( + test_model: models.TestModel, +) -> None: + """Test than on invalid field config type there is an error.""" + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = (("id", int, 1),) + + with pytest.raises( + saritasa_sqlalchemy_tools.auto_schema.UnableProcessTypeError, + match=re.escape( + "Can't process the following field ('id', , 1)", + ), + ): + AutoSchema.get_schema() + + +async def test_auto_schema_related_field_with_no_schema( + test_model: models.TestModel, +) -> None: + """Test that in generation raise error on relationships without type.""" + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = ( + "related_model", + "related_models", + ) + + with pytest.raises( + saritasa_sqlalchemy_tools.auto_schema.UnableProcessTypeError, + match=re.escape( + "Schema generation is not supported for relationship " + "fields(related_model), please use auto-schema or pydantic class", + ), + ): + AutoSchema.get_schema() + + +async def test_auto_schema_related_field_with_schema( + test_model: models.TestModel, + repository: repositories.TestModelRepository, +) -> None: + """Test that generation works correctly with related types auto.""" + + class RelatedAutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.RelatedModel + fields = ( + "id", + "created", + "modified", + ) + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = ( + ("related_model", RelatedAutoSchema), + ("related_model_nullable", RelatedAutoSchema), + ("related_models", RelatedAutoSchema), + ("custom_property_related_model", RelatedAutoSchema), + ("custom_property_related_model_nullable", RelatedAutoSchema), + ("custom_property_related_models", RelatedAutoSchema), + ) + + schema = AutoSchema.get_schema() + instance = ( + await repository.fetch( + id=test_model.id, + select_in_load=( + models.TestModel.related_model, + models.TestModel.related_model_nullable, + models.TestModel.related_models, + ), + ) + ).first() + model = schema.model_validate(instance) + isinstance(model.related_models, list) + for field in RelatedAutoSchema.Meta.fields: + assert getattr( + model.related_model, + field, + ) == getattr( + test_model.related_model, + field, + ) + for related_model, test_related_model in zip( + model.related_models, + test_model.related_models, + strict=False, + ): + assert getattr( + related_model, + field, + ) == getattr( + test_related_model, + field, + ) + + +async def test_auto_schema_use_both_config_and_model() -> None: + """Test schema generation fails when both config and model are used.""" + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + model_config = pydantic.ConfigDict( + from_attributes=True, + validate_assignment=True, + ) + base_model = pydantic.BaseModel + fields = ( + "id", + "created", + "modified", + ) + + with pytest.raises( + ValueError, + match=re.escape( + "Only config or base model could be passed to create_model", + ), + ): + AutoSchema.get_schema() + + +async def test_auto_schema_invalid_field_config() -> None: + """Test schema generation fails when both config and model are used.""" + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + model_config = pydantic.ConfigDict( + from_attributes=True, + validate_assignment=True, + ) + base_model = pydantic.BaseModel + fields = ( + "id", + "created", + "modified", + ) + + with pytest.raises( + ValueError, + match=re.escape( + "Only config or base model could be passed to create_model", + ), + ): + AutoSchema.get_schema() + + +def custom_validator( + cls, # noqa: ANN001 + value: typing.Any, + info: pydantic.ValidationInfo, +) -> None: + """Raise value error.""" + raise ValueError("This is custom validator") + + +def test_custom_field_validators( + test_model: models.TestModel, +) -> None: + """Test field validators for schema generation.""" + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = ("id",) + extra_fields_validators: typing.ClassVar = { + "id": (custom_validator,), + } + + schema = AutoSchema.get_schema() + with pytest.raises( + pydantic.ValidationError, + match=re.escape("This is custom validator"), + ): + schema.model_validate(test_model) + + +def test_custom_field_validators_custom_type( + test_model: models.TestModel, +) -> None: + """Test field validators for schema generation(custom type case).""" + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = (("id", int),) + extra_fields_validators: typing.ClassVar = { + "id": (custom_validator,), + } + + schema = AutoSchema.get_schema() + with pytest.raises( + pydantic.ValidationError, + match=re.escape("This is custom validator"), + ): + schema.model_validate(test_model) + + +def test_custom_field_validators_property( + test_model: models.TestModel, +) -> None: + """Test field validators for property for schema generation.""" + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = ("custom_property_related_model",) + extra_fields_validators: typing.ClassVar = { + "custom_property_related_model": (custom_validator,), + } + + schema = AutoSchema.get_schema() + with pytest.raises( + pydantic.ValidationError, + match=re.escape("This is custom validator"), + ): + schema.model_validate(test_model) + + +def test_custom_field_validators_property_custom_type( + test_model: models.TestModel, +) -> None: + """Test validators for property for schema generation(custom type case).""" + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = (("custom_property_related_model", typing.Any),) + extra_fields_validators: typing.ClassVar = { + "custom_property_related_model": (custom_validator,), + } + + schema = AutoSchema.get_schema() + with pytest.raises( + pydantic.ValidationError, + match=re.escape("This is custom validator"), + ): + schema.model_validate(test_model) diff --git a/tests/test_factories.py b/tests/test_factories.py new file mode 100644 index 0000000..7f165ea --- /dev/null +++ b/tests/test_factories.py @@ -0,0 +1,106 @@ +import pytest + +import saritasa_sqlalchemy_tools + +from . import factories, models, repositories + + +async def test_factory_missing_repository( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that error is raise if repository class missing.""" + + class TestModelFactory( + saritasa_sqlalchemy_tools.AsyncSQLAlchemyModelFactory[ + models.TestModel + ], + ): + """Factory to generate TestModel.""" + + class Meta: + model = models.TestModel + + with pytest.raises( + ValueError, + match="Repository class in not set in Meta class", + ): + await TestModelFactory.create_async(db_session) + + +async def test_factory_generation( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that model generation works as expected.""" + instance = await factories.TestModelFactory.create_async(db_session) + assert await repository.exists(id=instance.id) + + +async def test_factory_generation_sub_factories( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that model generation works with sub factories as expected.""" + instance = await factories.TestModelFactory.create_async(db_session) + instance_created = ( + await repository.fetch( + id=instance.id, + select_in_load=( + models.TestModel.related_model, + models.TestModel.related_models, + ), + ) + ).first() + assert instance_created + assert instance_created.related_model_id + assert instance_created.related_model + assert len(instance_created.related_models) == 5 + + +async def test_factory_generation_skip_sub_factories( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that sub_factory will be skipped if value is passed.""" + await factories.TestModelFactory.create_async( + db_session, + related_model=await factories.RelatedModelFactory.create_async( + db_session, + ), + related_models=await factories.RelatedModelFactory.create_batch_async( + db_session, + size=3, + ), + ) + assert await repositories.RelatedModelRepository(db_session).count() == 4 + + +async def test_factory_generation_skip_sub_factories_id_passed( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that sub_factory will be skipped if fk id is passed.""" + await factories.TestModelFactory.create_async( + db_session, + related_model_id=( + await factories.RelatedModelFactory.create_async( + db_session, + ) + ).id, + related_models=[], + ) + assert await repositories.RelatedModelRepository(db_session).count() == 1 + + +async def test_factory_generation_batch( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that model batch generation works as expected.""" + instances = await factories.TestModelFactory.create_batch_async( + db_session, + size=5, + ) + for instance in instances: + assert await repository.exists(id=instance.id) diff --git a/tests/test_ordering.py b/tests/test_ordering.py new file mode 100644 index 0000000..910a6db --- /dev/null +++ b/tests/test_ordering.py @@ -0,0 +1,43 @@ +import pytest +import sqlalchemy + +import saritasa_sqlalchemy_tools + + +@pytest.fixture +def generated_enum() -> type[saritasa_sqlalchemy_tools.OrderingEnum]: + """Generate enum for testing.""" + return saritasa_sqlalchemy_tools.OrderingEnum( # type: ignore + "GeneratedEnum", + [ + "field", + ], + ) + + +def test_ordering_enum_generated_values( + generated_enum: type[saritasa_sqlalchemy_tools.OrderingEnum], +) -> None: + """Test ordering enum generation. + + Check that reverse option are present. + + """ + assert hasattr(generated_enum, "field") + assert generated_enum.field == "field" + assert hasattr(generated_enum, "field_desc") + assert generated_enum.field_desc == "-field" + + +def test_ordering_enum_generated_sql_clause( + generated_enum: type[saritasa_sqlalchemy_tools.OrderingEnum], +) -> None: + """Test that ordering enum generates correct sql clause.""" + assert generated_enum.field.sql_clause == "field" + expected_db_clause: sqlalchemy.UnaryExpression[str] = sqlalchemy.desc( + "field", + ) + actual_db_clause = generated_enum.field_desc.sql_clause + assert actual_db_clause.__class__ == expected_db_clause.__class__ + assert actual_db_clause.modifier == expected_db_clause.modifier + assert actual_db_clause.operator == expected_db_clause.operator diff --git a/tests/test_repositories.py b/tests/test_repositories.py new file mode 100644 index 0000000..6923a9e --- /dev/null +++ b/tests/test_repositories.py @@ -0,0 +1,522 @@ +import pytest +import sqlalchemy + +import saritasa_sqlalchemy_tools + +from . import factories, models, repositories + + +async def test_init_other( + repository: repositories.TestModelRepository, +) -> None: + """Test init other.""" + new_repo = repository.init_other( + repositories.SoftDeleteTestModelRepository, + ) + assert isinstance(new_repo, repositories.SoftDeleteTestModelRepository) + assert new_repo.db_session is repository.db_session + + +async def test_soft_delete( + soft_delete_test_model: models.SoftDeleteTestModel, + soft_delete_repository: repositories.SoftDeleteTestModelRepository, +) -> None: + """Test soft deletion.""" + await soft_delete_repository.delete(soft_delete_test_model) + deleted = await soft_delete_repository.get(soft_delete_test_model.id) + assert deleted + assert deleted.deleted + + +async def test_force_soft_delete( + soft_delete_test_model: models.SoftDeleteTestModel, + soft_delete_repository: repositories.SoftDeleteTestModelRepository, +) -> None: + """Test force soft delete.""" + await soft_delete_repository.force_delete(soft_delete_test_model) + assert not await soft_delete_repository.get(soft_delete_test_model.id) + + +async def test_insert_batch( + related_model: models.RelatedModel, + repository: repositories.TestModelRepository, +) -> None: + """Test batch insert.""" + instances = factories.TestModelFactory.build_batch( + related_model_id=related_model.id, + size=5, + ) + created_instances = await repository.insert_batch(instances) + for created_instance in created_instances: + assert created_instance.id + assert created_instance.created + assert created_instance.modified + assert created_instance.related_model_id == related_model.id + + +async def test_update_batch( + test_models: list[models.TestModel], + related_model: models.RelatedModel, + repository: repositories.TestModelRepository, +) -> None: + """Test batch update.""" + for test_model in test_models: + test_model.related_model_id = related_model.id + await repository.update_batch(test_models) + for test_model in test_models: + pk = test_model.id + repository.expire(test_model) + test_model = await repository.get(pk) # type: ignore + assert test_model.related_model_id == related_model.id + + +async def test_delete_batch( + test_models: list[models.TestModel], + repository: repositories.TestModelRepository, +) -> None: + """Test batch delete.""" + ids = [test_model.id for test_model in test_models] + await repository.delete_batch(where=[models.TestModel.id.in_(ids)]) + assert not await repository.count(where=[models.TestModel.id.in_(ids)]) + + +async def test_save( + related_model: models.RelatedModel, + repository: repositories.TestModelRepository, +) -> None: + """Test that repository properly saves instances.""" + new_test_model = factories.TestModelFactory.build( + related_model_id=related_model.id, + ) + new_test_model = await repository.save(new_test_model, refresh=True) + assert (await repository.fetch(id=new_test_model.id)).first() + + +async def test_expire( + test_model: models.TestModel, + repository: repositories.TestModelRepository, +) -> None: + """Test that repository properly expires instances.""" + repository.expire(test_model) + with pytest.raises(sqlalchemy.exc.MissingGreenlet): + _ = test_model.id + + +@pytest.mark.parametrize( + "reuse_select_statement", + [ + True, + False, + ], +) +async def test_ordering( + test_models: list[models.TestModel], + repository: repositories.TestModelRepository, + reuse_select_statement: bool, +) -> None: + """Test ordering.""" + ordered_models = sorted( + test_models, + key=lambda instance: instance.text.lower(), + ) + if reuse_select_statement: + select_statement = repository.get_order_statement( + None, + *["text"], + ) + models_from_db = ( + await repository.fetch(statement=select_statement) + ).all() + else: + models_from_db = ( + await repository.fetch(ordering_clauses=["text"]) + ).all() + for actual, expected in zip(models_from_db, ordered_models, strict=False): + assert actual.id == expected.id + + +async def test_ordering_with_enum( + test_models: list[models.TestModel], + repository: repositories.TestModelRepository, +) -> None: + """Test ordering with enum.""" + ordered_models = sorted( + test_models, + key=lambda instance: instance.text.lower(), + reverse=True, + ) + + class OrderingEnum(saritasa_sqlalchemy_tools.OrderingEnum): + text = "text" + + models_from_db = ( + await repository.fetch(ordering_clauses=[OrderingEnum.text_desc]) + ).all() + for actual, expected in zip(models_from_db, ordered_models, strict=False): + assert actual.id == expected.id + + +@pytest.mark.parametrize( + "reuse_select_statement", + [ + True, + False, + ], +) +async def test_pagination( + test_models: list[models.TestModel], + repository: repositories.TestModelRepository, + reuse_select_statement: bool, +) -> None: + """Test pagination.""" + ordered_models = sorted( + test_models, + key=lambda instance: instance.id, + ) + args = { + "limit": 2, + "offset": 1, + } + if reuse_select_statement: + select_statement = repository.get_pagination_statement(**args) # type: ignore + models_from_db = ( + await repository.fetch( + statement=select_statement, + ordering_clauses=["id"], + ) + ).all() + else: + models_from_db = ( + await repository.fetch(ordering_clauses=["id"], **args) # type: ignore + ).all() + + assert models_from_db[0].id == ordered_models[1].id + assert models_from_db[1].id == ordered_models[2].id + + +@pytest.mark.parametrize( + "prefetch_type", + [ + "select_in_load", + "joined_load", + ], +) +@pytest.mark.parametrize( + "reuse_select_statement", + [ + True, + False, + ], +) +async def test_prefetch( + test_model: models.TestModel, + repository: repositories.TestModelRepository, + prefetch_type: str, + reuse_select_statement: bool, +) -> None: + """Test prefetching.""" + await factories.RelatedModelFactory.create_batch_async( + session=repository.db_session, + test_model_id=test_model.id, + size=5, + ) + await factories.TestModelFactory.create_batch_async( + session=repository.db_session, + related_model_id=test_model.related_model_id, + size=3, + ) + targets = ( + models.TestModel.related_model, + models.TestModel.related_model_nullable, + models.TestModel.related_models, + ( + models.TestModel.related_model, + models.RelatedModel.test_models, + ), + ) + args = { + "id": test_model.id, + prefetch_type: targets, + } + if reuse_select_statement: + select_statement = getattr( + repository, + f"get_{prefetch_type}_statement", + )(None, *targets) + instance = ( + await repository.fetch( + statement=select_statement, + id=test_model.id, + ) + ).first() + else: + instance = (await repository.fetch(**args)).first() # type: ignore + assert instance + assert instance.related_model + assert not instance.related_model_nullable + assert len(instance.related_models) == 5 + # Original plus created. + assert len(instance.related_model.test_models) == 4 + + +@pytest.mark.parametrize( + "reuse_select_statement", + [ + True, + False, + ], +) +async def test_filter_in( + test_models: list[models.TestModel], + repository: repositories.TestModelRepository, + reuse_select_statement: bool, +) -> None: + """Test filter `in`.""" + filtered_models = list( + filter( + lambda instance: instance.text + in [test_models[0].text, test_models[3].text], + test_models, + ), + ) + args = { + "where": [ + saritasa_sqlalchemy_tools.Filter( + field="text__in", + value=[test_models[0].text, test_models[3].text], + ), + ], + } + if reuse_select_statement: + select_statement = repository.get_filter_statement( + None, + *args["where"], + ) + instances = ( + await repository.fetch( + statement=select_statement, + ordering_clauses=["id"], + ) + ).all() + else: + instances = ( + await repository.fetch( + **args, # type: ignore + ordering_clauses=["id"], + ) + ).all() + assert instances[0].id == filtered_models[0].id + assert instances[1].id == filtered_models[1].id + + +@pytest.mark.parametrize( + "reuse_select_statement", + [ + True, + False, + ], +) +async def test_filter_gte( + test_models: list[models.TestModel], + repository: repositories.TestModelRepository, + reuse_select_statement: bool, +) -> None: + """Test filter `gte`.""" + max_num = max(test_model.number for test_model in test_models) + + args = { + "where": [ + saritasa_sqlalchemy_tools.Filter( + field="number__gte", + value=max_num, + ), + ], + } + if reuse_select_statement: + select_statement = repository.get_filter_statement( + None, + *args["where"], + ) + instances = ( + await repository.fetch( + statement=select_statement, + ordering_clauses=["id"], + ) + ).all() + else: + instances = ( + await repository.fetch( + **args, # type: ignore + ordering_clauses=["id"], + ) + ).all() + assert len(instances) == 1 + assert instances[0].number == max_num + + +@pytest.mark.parametrize( + "reuse_select_statement", + [ + True, + False, + ], +) +async def test_filter_m2m( + test_model: models.TestModel, + repository: repositories.TestModelRepository, + reuse_select_statement: bool, +) -> None: + """Test filter `m2m`.""" + await factories.M2MModelFactory.create_batch_async( + session=repository.db_session, + size=5, + test_model_id=test_model.id, + related_model_id=test_model.related_model_id, + ) + await factories.M2MModelFactory.create_batch_async( + session=repository.db_session, + size=5, + ) + + args = { + "where": [ + saritasa_sqlalchemy_tools.Filter( + field="m2m_related_model_id__in", + value=[test_model.related_model_id], + ), + ], + } + if reuse_select_statement: + select_statement = repository.get_filter_statement( + None, + *args["where"], + ) + instances = ( + ( + await repository.fetch( + statement=select_statement, + ordering_clauses=["id"], + ) + ) + .unique() + .all() + ) + else: + instances = ( + ( + await repository.fetch( + **args, # type: ignore + ordering_clauses=["id"], + ) + ) + .unique() + .all() + ) + assert len(instances) == 1 + assert instances[0].id == test_model.id + + +async def test_search_filter( + test_models: list[models.TestModel], + repository: repositories.TestModelRepository, +) -> None: + """Test search filter.""" + search_text = test_models[0].text + instances = ( + ( + await repository.fetch( + where=[ + saritasa_sqlalchemy_tools.transform_search_filter( + models.TestModel, + search_fields=[ + "text", + "id", + ], + value=search_text, + ), + ], + ordering_clauses=["id"], + ) + ) + .unique() + .all() + ) + assert len(instances) == 1 + assert instances[0].id == test_models[0].id + + +@pytest.mark.parametrize( + "reuse_select_statement", + [ + True, + False, + ], +) +async def test_annotation( + test_model: models.TestModel, + repository: repositories.TestModelRepository, + reuse_select_statement: bool, +) -> None: + """Test annotations.""" + await factories.RelatedModelFactory.create_batch_async( + session=repository.db_session, + size=5, + test_model_id=test_model.id, + ) + annotations = (models.TestModel.related_models_count,) + if reuse_select_statement: + select_statement = repository.get_annotated_statement( + None, + *annotations, + ) + instance = ( + await repository.fetch( + statement=select_statement, + id=test_model.id, + ) + ).first() + else: + instance = ( + await repository.fetch( + id=test_model.id, + annotations=annotations, + ) + ).first() + assert instance + assert ( + instance.related_models_count + == await repository.init_other( + repositories.RelatedModelRepository, + ).count() + ) + + +async def test_annotation_query( + test_model: models.TestModel, + repository: repositories.TestModelRepository, +) -> None: + """Test annotations query.""" + instance = ( + await repository.fetch( + id=test_model.id, + annotations=( + ( + models.TestModel.related_models_count_query, + sqlalchemy.select( + sqlalchemy.func.count(models.RelatedModel.id), + ) + .where( + models.RelatedModel.test_model_id + == models.TestModel.id, + ) + .correlate_except(models.RelatedModel) + .scalar_subquery(), + ), + ), + select_in_load=(models.TestModel.related_models,), + ) + ).first() + assert instance + assert instance.related_models_count_query == len( + test_model.related_models, + )