From 13b1bd498732cd444fc039e768ff9cc540a8d8d2 Mon Sep 17 00:00:00 2001
From: Stanislav Khlud
Date: Fri, 15 Mar 2024 17:26:59 +0700
Subject: [PATCH] Add tests

---
 .github/workflows/checks.yaml            |   2 +
 .pre-commit-config.yaml                  |   8 +
 docker-compose.yaml                      |   2 +-
 poetry.lock                              | 266 ++++++++-
 pyproject.toml                           |  51 +-
 saritasa_sqlalchemy_tools/__init__.py    |   2 +-
 saritasa_sqlalchemy_tools/auto_schema.py |  22 +-
 .../repositories/__init__.py             |   2 +-
 .../repositories/core.py                 |  80 +--
 .../repositories/filters.py              |  18 +-
 .../repositories/ordering.py             |   8 +-
 saritasa_sqlalchemy_tools/session.py     |   1 +
 .../testing/factories.py                 |  45 +-
 tests/__init__.py                        |   0
 tests/conftest.py                        | 126 +++++
 tests/factories.py                       | 120 ++++
 tests/models.py                          | 351 ++++++++++++
 tests/repositories.py                    |  75 +++
 tests/test_auto_schema.py                | 356 ++++++++++++
 tests/test_factories.py                  | 120 ++++
 tests/test_ordering.py                   |  43 ++
 tests/test_repositories.py               | 524 ++++++++++++++++++
 22 files changed, 2138 insertions(+), 84 deletions(-)
 create mode 100644 tests/__init__.py
 create mode 100644 tests/conftest.py
 create mode 100644 tests/factories.py
 create mode 100644 tests/models.py
 create mode 100644 tests/repositories.py
 create mode 100644 tests/test_auto_schema.py
 create mode 100644 tests/test_factories.py
 create mode 100644 tests/test_ordering.py
 create mode 100644 tests/test_repositories.py

diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml
index 9de7573..c054342 100644
--- a/.github/workflows/checks.yaml
+++ b/.github/workflows/checks.yaml
@@ -49,4 +49,6 @@ jobs:
         run: poetry install --no-interaction --all-extras
       - name: Run checks
         run: |
+          poetry run inv github-actions.set-up-hosts
+          poetry run inv docker.up
           poetry run inv pre-commit.run-hooks
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index adb783c..6969c2d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -51,3 +51,11 @@ repos:
         pass_filenames: false
         types: [ file ]
         stages: [ push ]
+
+      - id: pytest
+        name: Run pytest
+        entry: inv pytest.run
+        language: system
+        pass_filenames: false
+        types: [ file ]
+        stages: [ push ]
diff --git a/docker-compose.yaml b/docker-compose.yaml
index ad7eda5..2a8b4b9 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -1,5 +1,5 @@
 version: '3.7'
-name: "saritasa-sqlachemy-tools"
+name: "saritasa-sqlalchemy-tools"
 
 services:
   postgres:
diff --git a/poetry.lock b/poetry.lock
index 21ae784..435186a 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -29,6 +29,59 @@ six = ">=1.12.0"
 astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"]
 test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"]
 
+[[package]]
+name = "asyncpg"
+version = "0.28.0"
+description = "An asyncio PostgreSQL driver"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+    {file = "asyncpg-0.28.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a6d1b954d2b296292ddff4e0060f494bb4270d87fb3655dd23c5c6096d16d83"},
+    {file = "asyncpg-0.28.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0740f836985fd2bd73dca42c50c6074d1d61376e134d7ad3ad7566c4f79f8184"},
+    {file = "asyncpg-0.28.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e907cf620a819fab1737f2dd90c0f185e2a796f139ac7de6aa3212a8af96c050"},
+    {file = "asyncpg-0.28.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b339984d55e8202e0c4b252e9573e26e5afa05617ed02252544f7b3e6de3e9"},
+    {file = "asyncpg-0.28.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0c402745185414e4c204a02daca3d22d732b37359db4d2e705172324e2d94e85"},
+    {file = "asyncpg-0.28.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c88eef5e096296626e9688f00ab627231f709d0e7e3fb84bb4413dff81d996d7"},
+    {file = "asyncpg-0.28.0-cp310-cp310-win32.whl", hash = "sha256:90a7bae882a9e65a9e448fdad3e090c2609bb4637d2a9c90bfdcebbfc334bf89"},
+    {file = "asyncpg-0.28.0-cp310-cp310-win_amd64.whl", hash = "sha256:76aacdcd5e2e9999e83c8fbcb748208b60925cc714a578925adcb446d709016c"},
+    {file = "asyncpg-0.28.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a0e08fe2c9b3618459caaef35979d45f4e4f8d4f79490c9fa3367251366af207"},
+    {file = "asyncpg-0.28.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b24e521f6060ff5d35f761a623b0042c84b9c9b9fb82786aadca95a9cb4a893b"},
+    {file = "asyncpg-0.28.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99417210461a41891c4ff301490a8713d1ca99b694fef05dabd7139f9d64bd6c"},
+    {file = "asyncpg-0.28.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f029c5adf08c47b10bcdc857001bbef551ae51c57b3110964844a9d79ca0f267"},
+    {file = "asyncpg-0.28.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ad1d6abf6c2f5152f46fff06b0e74f25800ce8ec6c80967f0bc789974de3c652"},
+    {file = "asyncpg-0.28.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d7fa81ada2807bc50fea1dc741b26a4e99258825ba55913b0ddbf199a10d69d8"},
+    {file = "asyncpg-0.28.0-cp311-cp311-win32.whl", hash = "sha256:f33c5685e97821533df3ada9384e7784bd1e7865d2b22f153f2e4bd4a083e102"},
+    {file = "asyncpg-0.28.0-cp311-cp311-win_amd64.whl", hash = "sha256:5e7337c98fb493079d686a4a6965e8bcb059b8e1b8ec42106322fc6c1c889bb0"},
+    {file = "asyncpg-0.28.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1c56092465e718a9fdcc726cc3d9dcf3a692e4834031c9a9f871d92a75d20d48"},
+    {file = "asyncpg-0.28.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4acd6830a7da0eb4426249d71353e8895b350daae2380cb26d11e0d4a01c5472"},
+    {file = "asyncpg-0.28.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63861bb4a540fa033a56db3bb58b0c128c56fad5d24e6d0a8c37cb29b17c1c7d"},
+    {file = "asyncpg-0.28.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:a93a94ae777c70772073d0512f21c74ac82a8a49be3a1d982e3f259ab5f27307"},
+    {file = "asyncpg-0.28.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d14681110e51a9bc9c065c4e7944e8139076a778e56d6f6a306a26e740ed86d2"},
+    {file = "asyncpg-0.28.0-cp37-cp37m-win32.whl", hash = "sha256:8aec08e7310f9ab322925ae5c768532e1d78cfb6440f63c078b8392a38aa636a"},
+    {file = "asyncpg-0.28.0-cp37-cp37m-win_amd64.whl", hash = "sha256:319f5fa1ab0432bc91fb39b3960b0d591e6b5c7844dafc92c79e3f1bff96abef"},
+    {file = "asyncpg-0.28.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b337ededaabc91c26bf577bfcd19b5508d879c0ad009722be5bb0a9dd30b85a0"},
+    {file = "asyncpg-0.28.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4d32b680a9b16d2957a0a3cc6b7fa39068baba8e6b728f2e0a148a67644578f4"},
+    {file = "asyncpg-0.28.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4f62f04cdf38441a70f279505ef3b4eadf64479b17e707c950515846a2df197"},
+    {file = "asyncpg-0.28.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f20cac332c2576c79c2e8e6464791c1f1628416d1115935a34ddd7121bfc6a4"},
+    {file = "asyncpg-0.28.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:59f9712ce01e146ff71d95d561fb68bd2d588a35a187116ef05028675462d5ed"},
+    {file = "asyncpg-0.28.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9e9f9ff1aa0eddcc3247a180ac9e9b51a62311e988809ac6152e8fb8097756"},
+    {file = "asyncpg-0.28.0-cp38-cp38-win32.whl", hash = "sha256:9e721dccd3838fcff66da98709ed884df1e30a95f6ba19f595a3706b4bc757e3"},
+    {file = "asyncpg-0.28.0-cp38-cp38-win_amd64.whl", hash = "sha256:8ba7d06a0bea539e0487234511d4adf81dc8762249858ed2a580534e1720db00"},
+    {file = "asyncpg-0.28.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d009b08602b8b18edef3a731f2ce6d3f57d8dac2a0a4140367e194eabd3de457"},
+    {file = "asyncpg-0.28.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ec46a58d81446d580fb21b376ec6baecab7288ce5a578943e2fc7ab73bf7eb39"},
+    {file = "asyncpg-0.28.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b48ceed606cce9e64fd5480a9b0b9a95cea2b798bb95129687abd8599c8b019"},
+    {file = "asyncpg-0.28.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8858f713810f4fe67876728680f42e93b7e7d5c7b61cf2118ef9153ec16b9423"},
+    {file = "asyncpg-0.28.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5e18438a0730d1c0c1715016eacda6e9a505fc5aa931b37c97d928d44941b4bf"},
+    {file = "asyncpg-0.28.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e9c433f6fcdd61c21a715ee9128a3ca48be8ac16fa07be69262f016bb0f4dbd2"},
+    {file = "asyncpg-0.28.0-cp39-cp39-win32.whl", hash = "sha256:41e97248d9076bc8e4849da9e33e051be7ba37cd507cbd51dfe4b2d99c70e3dc"},
+    {file = "asyncpg-0.28.0-cp39-cp39-win_amd64.whl", hash = "sha256:3ed77f00c6aacfe9d79e9eff9e21729ce92a4b38e80ea99a58ed382f42ebd55b"},
+    {file = "asyncpg-0.28.0.tar.gz", hash = "sha256:7252cdc3acb2f52feaa3664280d3bcd78a46bd6c10bfd681acfffefa1120e278"},
+]
+
+[package.extras]
+docs = ["Sphinx (>=5.3.0,<5.4.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
+test = ["flake8 (>=5.0,<6.0)", "uvloop (>=0.15.3)"]
+
 [[package]]
 name = "cfgv"
 version = "3.4.0"
@@ -51,6 +104,70 @@ files = [
     {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
 ]
 
+[[package]]
+name = "coverage"
+version = "7.4.4"
+description = "Code coverage measurement for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"},
+    {file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"},
+    {file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"},
+    {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"},
+    {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"},
+    {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"},
+    {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"},
+    {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"},
+    {file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"},
+    {file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"},
+    {file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"},
+    {file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"},
+    {file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"},
+    {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"},
+    {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"},
+    {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"},
+    {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"},
+    {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"},
+    {file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"},
+    {file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"},
+    {file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"},
+    {file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"},
+    {file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"},
+    {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"},
+    {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"},
+    {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"},
+    {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"},
+    {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"},
+    {file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"},
+    {file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"},
+    {file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"},
+    {file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"},
+    {file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"},
+    {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"},
+    {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"},
+    {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"},
+    {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"},
+    {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"},
+    {file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"},
+    {file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"},
+    {file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"},
+    {file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"},
+    {file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"},
+    {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"},
+    {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"},
+    {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"},
+    {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"},
+    {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"},
+    {file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"},
+    {file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"},
+    {file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"},
+    {file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"},
+]
+
+[package.extras]
+toml = ["tomli"]
+
 [[package]]
 name = "decorator"
 version = "5.1.1"
@@ -107,13 +224,13 @@ doc = ["Sphinx", "sphinx-rtd-theme", "sphinxcontrib-spelling"]
 
 [[package]]
 name = "faker"
-version = "24.2.0"
+version = "24.3.0"
 description = "Faker is a Python package that generates fake data for you."
 optional = true
 python-versions = ">=3.8"
 files = [
-    {file = "Faker-24.2.0-py3-none-any.whl", hash = "sha256:dce4754921f9fa7e2003c26834093361b8f45072e0f46f172d6ca1234774ecd4"},
-    {file = "Faker-24.2.0.tar.gz", hash = "sha256:87d5e7730426e7b36817921679c4eaf3d810cedb8c81194f47adc3df2122ca18"},
+    {file = "Faker-24.3.0-py3-none-any.whl", hash = "sha256:9978025e765ba79f8bf6154c9630a9c2b7f9c9b0f175d4ad5e04b19a82a8d8d6"},
+    {file = "Faker-24.3.0.tar.gz", hash = "sha256:5fb5aa9749d09971e04a41281ae3ceda9414f683d4810a694f8a8eebb8f9edec"},
 ]
 
 [package.dependencies]
@@ -220,6 +337,17 @@ files = [
 [package.extras]
 license = ["ukkonen"]
 
+[[package]]
+name = "iniconfig"
+version = "2.0.0"
+description = "brain-dead simple config-ini parsing"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
+    {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
+]
+
 [[package]]
 name = "invoke"
@@ -420,6 +548,17 @@ files = [
 [package.dependencies]
 setuptools = "*"
 
+[[package]]
+name = "packaging"
+version = "24.0"
+description = "Core utilities for Python packages"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
+    {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
+]
+
 [[package]]
 name = "parso"
 version = "0.8.3"
@@ -464,6 +603,21 @@ files = [
 docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
 test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"]
 
+[[package]]
+name = "pluggy"
+version = "1.4.0"
+description = "plugin and hook calling mechanisms for python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"},
+    {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"},
+]
+
+[package.extras]
+dev = ["pre-commit", "tox"]
+testing = ["pytest", "pytest-benchmark"]
+
 [[package]]
 name = "pre-commit"
 version = "3.6.2"
@@ -646,6 +800,96 @@ files = [
 plugins = ["importlib-metadata"]
 windows-terminal = ["colorama (>=0.4.6)"]
 
+[[package]]
+name = "pytest"
+version = "8.1.1"
+description = "pytest: simple powerful testing with Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"},
+    {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+iniconfig = "*"
+packaging = "*"
+pluggy = ">=1.4,<2.0"
+
+[package.extras]
+testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+
+[[package]]
+name = "pytest-async-sqlalchemy"
+version = "0.2.0"
+description = "Database testing fixtures using the SQLAlchemy asyncio API"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "pytest-async-sqlalchemy-0.2.0.tar.gz", hash = "sha256:0dcf80fdff1ea0046834cff2bc100c82d159e45a7ae21545a6ba9119a962b9d7"},
+    {file = "pytest_async_sqlalchemy-0.2.0-py3-none-any.whl", hash = "sha256:60d7159f43d21e79d7051841fd2d6e094b7267ddc8d7192daea597afca938b12"},
+]
+
+[package.dependencies]
+pytest = ">=6.0.0"
+sqlalchemy = ">=1.4.0"
+
+[[package]]
+name = "pytest-asyncio"
+version = "0.21.1"
+description = "Pytest support for asyncio"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"},
+    {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"},
+]
+
+[package.dependencies]
+pytest = ">=7.0.0"
+
+[package.extras]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
+testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
+
+[[package]]
+name = "pytest-cov"
+version = "4.1.0"
+description = "Pytest plugin for measuring coverage."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"},
+    {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"},
+]
+
+[package.dependencies]
+coverage = {version = ">=5.2.1", extras = ["toml"]}
+pytest = ">=4.6"
+
+[package.extras]
+testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
+
+[[package]]
+name = "pytest-sugar"
+version = "1.0.0"
+description = "pytest-sugar is a plugin for pytest that changes the default look and feel of pytest (e.g. progressbar, show tests that fail instantly)."
+optional = false
+python-versions = "*"
+files = [
+    {file = "pytest-sugar-1.0.0.tar.gz", hash = "sha256:6422e83258f5b0c04ce7c632176c7732cab5fdb909cb39cca5c9139f81276c0a"},
+    {file = "pytest_sugar-1.0.0-py3-none-any.whl", hash = "sha256:70ebcd8fc5795dc457ff8b69d266a4e2e8a74ae0c3edc749381c64b5246c8dfd"},
+]
+
+[package.dependencies]
+packaging = ">=21.3"
+pytest = ">=6.2.0"
+termcolor = ">=2.1.0"
+
+[package.extras]
+dev = ["black", "flake8", "pre-commit"]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"
@@ -879,6 +1123,20 @@ pure-eval = "*"
 
 [package.extras]
 tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
 
+[[package]]
+name = "termcolor"
+version = "2.4.0"
+description = "ANSI color formatting for output in terminal"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "termcolor-2.4.0-py3-none-any.whl", hash = "sha256:9297c0df9c99445c2412e832e882a7884038a25617c60cea2ad69488d4040d63"},
+    {file = "termcolor-2.4.0.tar.gz", hash = "sha256:aab9e56047c8ac41ed798fa36d892a37aca6b3e9159f3e0c24bc64a9b3ac7b7a"},
+]
+
+[package.extras]
+tests = ["pytest", "pytest-cov"]
+
 [[package]]
 name = "traitlets"
 version = "5.14.2"
@@ -943,4 +1201,4 @@ factories = ["factory-boy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "631ce82692e7ea81db4a9c084b8def42262c4b5ee9a43d2459fba55631c70758"
+content-hash = "a6de404e0099fbd55465a84bbd6472893f80344fadb18803b671bebbaddba449"
diff --git a/pyproject.toml b/pyproject.toml
index 2a514de..7d67c5a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,6 +64,25 @@ saritasa_invocations = "^1.1.0"
 # https://mypy.readthedocs.io/en/stable/
 mypy = "^1.9.0"
 
+[tool.poetry.group.test.dependencies]
+pytest = "^8.1.1"
+# pytest-asyncio is a pytest plugin. It facilitates testing of code that
+# uses the asyncio library.
+# https://pytest-asyncio.readthedocs.io/en/latest/
+pytest-asyncio = "^0.21.1"
+# Database testing fixtures using the SQLAlchemy asyncio API
+# https://pypi.org/project/pytest-async-sqlalchemy/
+pytest-async-sqlalchemy = "^0.2.0"
+# To prettify pytest output
+pytest-sugar = "^1.0.0"
+# Coverage plugin for pytest.
+# https://github.com/pytest-dev/pytest-cov
+pytest-cov = "^4.1.0"
+# asyncpg is a database interface library designed specifically for PostgreSQL
+# and Python/asyncio.
+# https://magicstack.github.io/asyncpg/current/
+asyncpg = "^0.28.0"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
@@ -171,7 +190,7 @@ section-order = [
 sqlachemy = ["sqlachemy"]
 
 [tool.ruff.lint.flake8-pytest-style]
-fixture-parentheses = true
+fixture-parentheses = false
 parametrize-names-type = "list"
 parametrize-values-type = "list"
 parametrize-values-row-type = "list"
@@ -210,3 +229,33 @@ ignore = [
   "**/*test_*.py",
   "invocations/**"
 ]
+
+[tool.pytest.ini_options]
+# --capture=no
+#   allow use of ipdb during tests
+# --ff
+#   run last failed tests first
+addopts = [
+    "--capture=no",
+    "--ff",
+    "--cov=saritasa_sqlalchemy_tools",
+    "--cov-report=html",
+]
+# skip all files inside following dirs
+norecursedirs = [
+    "venv",
+    ".venv",
+]
+asyncio_mode = "auto"
+
+[tool.coverage.run]
+omit = [
+    "saritasa_sqlalchemy_tools/session.py",
+]
+
+# https://docformatter.readthedocs.io/en/latest/configuration.html#
+[tool.docformatter]
+wrap-descriptions=0
+in-place=true
+blank=true
+black=true
diff --git a/saritasa_sqlalchemy_tools/__init__.py b/saritasa_sqlalchemy_tools/__init__.py
index c4a0852..b5a131a 100644
--- a/saritasa_sqlalchemy_tools/__init__.py
+++ b/saritasa_sqlalchemy_tools/__init__.py
@@ -32,7 +32,7 @@
     Filter,
     LazyLoaded,
     LazyLoadedSequence,
-    OrderingClausesT,
+    OrderingClauses,
     OrderingEnum,
     OrderingEnumMeta,
     SQLWhereFilter,
diff --git a/saritasa_sqlalchemy_tools/auto_schema.py b/saritasa_sqlalchemy_tools/auto_schema.py
index 0d58257..c482eca 100644
--- a/saritasa_sqlalchemy_tools/auto_schema.py
+++ b/saritasa_sqlalchemy_tools/auto_schema.py
@@ -7,6 +7,7 @@
 import pydantic
 import pydantic_core
 import sqlalchemy.dialects.postgresql.ranges
+import sqlalchemy.orm
 
 from . import models
 
@@ -148,7 +149,7 @@ def get_schema(
                     pydantic.field_validator(field)(validator)
                 )
                 continue
-            if isinstance(field, tuple):
+            if isinstance(field, tuple) and len(field) == 2:
                 field_name, field_type = field
                 generated_fields[field_name] = (
                     cls._generate_field_with_custom_type(
@@ -159,7 +160,7 @@ def get_schema(
                     )
                 )
                 for index, validator in enumerate(
-                    extra_fields_validators.get(field, ()),
+                    extra_fields_validators.get(field_name, ()),
                 ):
                     validators[f"{field_name}_validator_{index}"] = (
                         pydantic.field_validator(field_name)(validator)
@@ -197,7 +198,14 @@ def _generate_field(
             model_attribute,
             extra_field_config,
         )
-        if model_attribute.type.__class__ not in types_mapping:
+        if isinstance(model_attribute.property, sqlalchemy.orm.Relationship):
+            raise UnableProcessTypeError(
+                "Schema generation is not supported for relationship "
+                f"fields({field}), please use auto-schema or pydantic class",
+            )
+        if (
+            model_attribute.type.__class__ not in types_mapping
+        ):  # pragma: no cover
             raise UnableProcessTypeError(
                 "Can't generate generate type for"
                 f" {model_attribute.type.__class__}"
@@ -443,7 +451,7 @@ def _generate_enum_field(
     ) -> PydanticFieldConfig:
         """Generate enum field."""
         if model_type.enum_class is None:  # type: ignore
-            raise UnableToExtractEnumClassError(
+            raise UnableToExtractEnumClassError(  # pragma: no cover
                 f"Can't extract enum for {field} in {model}",
             )
         return (
@@ -517,7 +525,11 @@ def _generate_array_field(
             model_type.item_type,  # type: ignore
             extra_field_config,
         )
-        return list[list_type], pydantic_core.PydanticUndefined  # type: ignore
+        return (
+            list[list_type] | None  # type: ignore
+            if model_attribute.nullable
+            else list[list_type]  # type: ignore
+        ), pydantic_core.PydanticUndefined
 
 
 ModelAutoSchemaT = typing.TypeVar(
diff --git a/saritasa_sqlalchemy_tools/repositories/__init__.py b/saritasa_sqlalchemy_tools/repositories/__init__.py
index ef19cbc..ca76145 100644
--- a/saritasa_sqlalchemy_tools/repositories/__init__.py
+++ b/saritasa_sqlalchemy_tools/repositories/__init__.py
@@ -6,7 +6,7 @@
     WhereFilters,
     transform_search_filter,
 )
-from .ordering import OrderingClausesT, OrderingEnum, OrderingEnumMeta
+from .ordering import OrderingClauses, OrderingEnum, OrderingEnumMeta
 from .types import (
     Annotation,
     AnnotationSequence,
diff --git a/saritasa_sqlalchemy_tools/repositories/core.py b/saritasa_sqlalchemy_tools/repositories/core.py
index b43f85d..8109132 100644
--- a/saritasa_sqlalchemy_tools/repositories/core.py
+++ b/saritasa_sqlalchemy_tools/repositories/core.py
@@ -36,22 +36,18 @@ def __init__(
         self,
         db_session: sqlalchemy.ext.asyncio.AsyncSession,
     ) -> None:
-        self._db_session = db_session
+        self.db_session = db_session
 
     def init_other(
         self,
         repository_class: type[BaseRepositoryT],
     ) -> BaseRepositoryT:
         """Init other repo from current."""
-        return repository_class(db_session=self._db_session)
-
-    async def commit(self) -> None:
-        """Commit transaction."""
-        await self._db_session.commit()
+        return repository_class(db_session=self.db_session)
 
     async def flush(self) -> None:
         """Perform changes to database."""
-        await self._db_session.flush()
+        await self.db_session.flush()
 
     async def refresh(
         self,
@@ -59,7 +55,7 @@ async def refresh(
         attribute_names: collections.abc.Sequence[str] | None = None,
     ) -> None:
         """Refresh instance."""
-        await self._db_session.refresh(
+        await self.db_session.refresh(
             instance=instance,
             attribute_names=attribute_names,
         )
@@ -71,11 +67,11 @@ def expire(self, instance: models.BaseModelT) -> None:
         fetched from db again.
 
         """
-        self._db_session.expire(instance)
+        self.db_session.expire(instance)
 
     async def get(self, pk: int | str) -> models.BaseModelT | None:
         """Return entry from DB by primary key."""
-        return await self._db_session.get(self.model, pk)
+        return await self.db_session.get(self.model, pk)
 
     async def save(
         self,
@@ -84,7 +80,7 @@ async def save(
         attribute_names: collections.abc.Sequence[str] | None = None,
     ) -> models.BaseModelT:
         """Save model instance into db."""
-        self._db_session.add(instance=instance)
+        self.db_session.add(instance=instance)
         await self.flush()
         if refresh:
             await self.refresh(instance, attribute_names)
@@ -92,7 +88,7 @@ async def save(
 
     async def delete(self, instance: models.BaseModelT) -> None:
         """Delete model instance into db."""
-        await self._db_session.delete(instance=instance)
+        await self.db_session.delete(instance=instance)
         await self.flush()
 
     async def delete_batch(
@@ -101,7 +97,7 @@ async def delete_batch(
         **filters_by: typing.Any,
     ) -> None:
         """Delete batch of objects from db."""
-        await self._db_session.execute(
+        await self.db_session.execute(
             statement=(
                 sqlalchemy.sql.delete(self.model)
                 .where(*self.process_where_filters(*where))
@@ -141,7 +137,7 @@ async def insert_batch(
         exclude_fields: collections.abc.Sequence[str] = (),
     ) -> list[models.BaseModelT]:
         """Create batch of objects in db."""
-        if not objects:
+        if not objects:  # pragma: no cover
             return []
 
         objects_as_dict = self.objects_as_dict(
@@ -150,7 +146,7 @@ async def insert_batch(
                 exclude_fields or self.default_exclude_bulk_create_fields
             ),
         )
-        created_objects = await self._db_session.scalars(
+        created_objects = await self.db_session.scalars(
             sqlalchemy.sql.insert(self.model)
             .returning(self.model)
             .values(objects_as_dict),
@@ -164,7 +160,7 @@ async def update_batch(
         exclude_fields: collections.abc.Sequence[str] = (),
     ) -> None:
         """Update batch of objects in db."""
-        if not objects:
+        if not objects:  # pragma: no cover
             return
 
         objects_as_dict = self.objects_as_dict(
@@ -173,7 +169,7 @@ async def update_batch(
                 exclude_fields or self.default_exclude_bulk_update_fields
             ),
         )
-        await self._db_session.execute(
+        await self.db_session.execute(
             sqlalchemy.sql.update(self.model),
             objects_as_dict,
         )
@@ -208,7 +204,7 @@ def get_annotated_statement(
     def get_filter_statement(
         self,
         statement: models.SelectStatement[models.BaseModelT] | None = None,
-        *where_filters: filters.WhereFilter,
+        *where: filters.WhereFilter,
         **filters_by: typing.Any,
     ) -> models.SelectStatement[models.BaseModelT]:
         """Get statement with filtering."""
@@ -217,17 +213,17 @@ def get_filter_statement(
         else:
             select_statement = self.select_statement
         return select_statement.where(
-            *self.process_where_filters(*where_filters),
+            *self.process_where_filters(*where),
         ).filter_by(**filters_by)
 
     @classmethod
     def process_where_filters(
         cls,
-        *where_filters: filters.WhereFilter,
+        *where: filters.WhereFilter,
     ) -> collections.abc.Sequence[filters.SQLWhereFilter]:
         """Process where filters."""
         processed_where_filters: list[filters.SQLWhereFilter] = []
-        for where_filter in where_filters:
+        for where_filter in where:
             if isinstance(where_filter, filters.Filter):
                 processed_where_filters.append(
                     where_filter.transform_filter(cls.model),  # type: ignore
@@ -239,22 +235,30 @@ def process_where_filters(
     def get_order_statement(
         self,
         statement: models.SelectStatement[models.BaseModelT] | None = None,
-        *clauses: sqlalchemy.ColumnExpressionArgument[str] | str,
+        *clauses: ordering.OrderingClause,
     ) -> models.SelectStatement[models.BaseModelT]:
         """Get statement with ordering."""
         if statement is not None:
             select_statement = statement
         else:
             select_statement = self.select_statement
-        ordering_clauses = [
-            (
-                clause.db_clause
-                if isinstance(clause, ordering.OrderingEnum)
-                else clause
-            )
-            for clause in clauses
-        ]
-        return select_statement.order_by(*ordering_clauses)
+        return select_statement.order_by(
+            *self.process_ordering_clauses(*clauses),
+        )
+
+    @classmethod
+    def process_ordering_clauses(
+        cls,
+        *clauses: ordering.OrderingClause,
+    ) -> collections.abc.Sequence[ordering.SQLOrderingClause]:
+        """Process ordering clauses."""
+        processed_ordering_clauses: list[ordering.SQLOrderingClause] = []
+        for clause in clauses:
+            if isinstance(clause, ordering.OrderingEnum):
+                processed_ordering_clauses.append(clause.sql_clause)
+            else:
+                processed_ordering_clauses.append(clause)
+        return processed_ordering_clauses
 
     def get_pagination_statement(
         self,
@@ -323,7 +327,7 @@ def get_fetch_statement(
         joined_load: types.LazyLoadedSequence = (),
         select_in_load: types.LazyLoadedSequence = (),
         annotations: types.AnnotationSequence = (),
-        clauses: ordering.OrderingClausesT = (),
+        ordering_clauses: ordering.OrderingClauses = (),
         where: filters.WhereFilters = (),
         **filters_by: typing.Any,
     ) -> models.SelectStatement[models.BaseModelT]:
@@ -342,7 +346,7 @@ def get_fetch_statement(
         )
         statement = self.get_order_statement(
             statement,
-            *clauses,
+            *ordering_clauses,
         )
         statement = self.get_filter_statement(
             statement,
@@ -364,12 +368,12 @@ async def fetch(
         joined_load: types.LazyLoadedSequence = (),
         select_in_load: types.LazyLoadedSequence = (),
         annotations: types.AnnotationSequence = (),
-        clauses: ordering.OrderingClausesT = (),
+        ordering_clauses: ordering.OrderingClauses = (),
         where: filters.WhereFilters = (),
         **filters_by: typing.Any,
     ) -> sqlalchemy.ScalarResult[models.BaseModelT]:
         """Fetch entries."""
-        return await self._db_session.scalars(
+        return await self.db_session.scalars(
             statement=self.get_fetch_statement(
                 statement=statement,
                 offset=offset,
@@ -377,7 +381,7 @@ async def fetch(
                 joined_load=joined_load,
                 select_in_load=select_in_load,
                 annotations=annotations,
-                clauses=clauses,
+                ordering_clauses=ordering_clauses,
                 where=where,
                 **filters_by,
             ),
@@ -390,7 +394,7 @@ async def count(
     ) -> int:
         """Get count of entries."""
         return (
-            await self._db_session.scalar(
+            await self.db_session.scalar(
                 sqlalchemy.select(sqlalchemy.func.count())
                 .select_from(self.model)
                 .where(*self.process_where_filters(*where))
@@ -405,7 +409,7 @@ async def exists(
     ) -> bool:
         """Check existence of entries."""
         return (
-            await self._db_session.scalar(
+            await self.db_session.scalar(
                 sqlalchemy.select(
                     sqlalchemy.sql.exists(
                         self.select_statement.where(
diff --git a/saritasa_sqlalchemy_tools/repositories/filters.py b/saritasa_sqlalchemy_tools/repositories/filters.py
index 63f77bc..e99ea4f 100644
--- a/saritasa_sqlalchemy_tools/repositories/filters.py
+++ b/saritasa_sqlalchemy_tools/repositories/filters.py
@@ -38,7 +38,7 @@ class Filter:
     """Define filter value."""
 
-    api_filter: str
+    field: str
     value: FilterType
 
     def transform_filter(
@@ -46,7 +46,7 @@ def transform_filter(
         model: type[models.BaseModelT],
     ) -> SQLWhereFilter:
         """Transform filter valid for sqlalchemy."""
-        field_name, filter_arg = self.api_filter.split("__")
+        field_name, filter_arg = self.field.split("__")
         if field_name in model.m2m_filters:
             return self.transform_m2m_filter(
                 field_name=field_name,
@@ -114,20 +114,6 @@ def transform_simple_filter(
         filter_operator = getattr(field, filter_args_mapping[filter_arg])(
             value,
         )
-        if (
-            filter_arg
-            in (
-                "gt",
-                "gte",
-                "lt",
-                "lte",
-            )
-            and field.nullable
-        ):
-            filter_operator = sqlalchemy.or_(
-                filter_operator,
-                field.is_(None),
-            )
         return filter_operator
diff --git a/saritasa_sqlalchemy_tools/repositories/ordering.py b/saritasa_sqlalchemy_tools/repositories/ordering.py
index fd61cea..d221e3e 100644
--- a/saritasa_sqlalchemy_tools/repositories/ordering.py
+++ b/saritasa_sqlalchemy_tools/repositories/ordering.py
@@ -25,13 +25,13 @@ class OrderingEnum(enum.StrEnum, metaclass=OrderingEnumMeta):
     """Representation of ordering fields."""
 
     @property
-    def db_clause(self) -> str | sqlalchemy.ColumnExpressionArgument[str]:
+    def sql_clause(self) -> "SQLOrderingClause":
         """Convert ordering value to sqlalchemy ordering clause."""
         if self.startswith("-"):
             return sqlalchemy.desc(self[1:])
         return self
 
 
-OrderingClausesT: typing.TypeAlias = collections.abc.Sequence[
-    str | sqlalchemy.ColumnExpressionArgument[str] | OrderingEnum
-]
+SQLOrderingClause = str | sqlalchemy.ColumnExpressionArgument[str]
+OrderingClause = SQLOrderingClause | OrderingEnum
+OrderingClauses: typing.TypeAlias = collections.abc.Sequence[OrderingClause]
diff --git a/saritasa_sqlalchemy_tools/session.py b/saritasa_sqlalchemy_tools/session.py
index 3e047d8..05d9730 100644
--- a/saritasa_sqlalchemy_tools/session.py
+++ b/saritasa_sqlalchemy_tools/session.py
@@ -1,3 +1,4 @@
+# pragma: no cover
 import collections.abc
 import contextlib
 import typing
diff --git a/saritasa_sqlalchemy_tools/testing/factories.py b/saritasa_sqlalchemy_tools/testing/factories.py
index 020c283..7113a65 100644
--- a/saritasa_sqlalchemy_tools/testing/factories.py
+++ b/saritasa_sqlalchemy_tools/testing/factories.py
@@ -48,7 +48,7 @@ async def create_async(
             None,
         )
         if not repository_class:
-            raise ValueError("Repository class in not set in Meta class")
+            raise ValueError("Repository class is not set in Meta class")
 
         repository = repository_class(db_session=session)
         pk_attr: str = instance.pk_field
@@ -61,7 +61,9 @@ async def create_async(
             )
         ).first()
         if not instance_from_db:
-            raise ValueError("Created instance wasn't found in database")
+            raise ValueError(  # pragma: no cover
+                "Created instance wasn't found in database",
+            )
         return instance_from_db
 
     @classmethod
@@ -87,23 +89,40 @@ async def _async_run_sub_factories(
         cls,
         session: session.Session,
         passed_fields: collections.abc.Sequence[str],
-    ) -> dict[str, models.BaseModel]:
+    ) -> dict[str, models.BaseModel | list[models.BaseModel]]:
         """Generate objects from sub factories."""
-        sub_factories_map: dict[str, str] = getattr(
+        sub_factories_map: dict[str, str | tuple[str, int]] = getattr(
             cls._meta,
             "sub_factories",
             {},
         )
-        generated_instances: dict[str, models.BaseModel] = {}
+        generated_instances: dict[
+            str,
+            models.BaseModel | list[models.BaseModel],
+        ] = {}
         for field, sub_factory_path in sub_factories_map.items():
             if field in passed_fields or f"{field}_id" in passed_fields:
                 continue
-            *factory_module, sub_factory_name = sub_factory_path.split(".")
-            sub_factory: typing.Self = getattr(
-                importlib.import_module(".".join(factory_module)),
-                sub_factory_name,
-            )
-            generated_instances[field] = await sub_factory.create_async(
-                session=session,
-            )
+            if isinstance(sub_factory_path, str):
+                *factory_module, sub_factory_name = sub_factory_path.split(".")
+                sub_factory = getattr(
+                    importlib.import_module(".".join(factory_module)),
+                    sub_factory_name,
+                )
+                generated_instances[field] = await sub_factory.create_async(
+                    session=session,
+                )
+            else:
+                sub_factory_path, size = sub_factory_path
+                *factory_module, sub_factory_name = sub_factory_path.split(".")
+                sub_factory = getattr(
+                    importlib.import_module(".".join(factory_module)),
+                    sub_factory_name,
+                )
+                generated_instances[
+                    field
+                ] = await sub_factory.create_batch_async(
+                    session=session,
+                    size=size,
+                )
         return generated_instances
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..77ecbab
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,126 @@
+import asyncio
+import collections.abc
+import typing
+
+import pytest
+import sqlalchemy
+
+import saritasa_sqlalchemy_tools
+
+from . import factories, models, repositories
+
+
+@pytest.fixture(scope="session")
+def event_loop() -> (
+    collections.abc.Generator[
+        asyncio.AbstractEventLoop,
+        typing.Any,
+        None,
+    ]
+):
+    """Override `event_loop` fixture to change scope to `session`.
+
+    Needed for pytest-async-sqlalchemy.
+
+    """
+    loop = asyncio.new_event_loop()
+    yield loop
+    loop.close()
+
+
+@pytest.fixture(scope="session")
+def database_url(request: pytest.FixtureRequest) -> str:
+    """Override database url.
+
+    Grab configs from settings and add support for pytest-xdist.
+
+    """
+    worker_input = getattr(
+        request.config,
+        "workerinput",
+        {
+            "workerid": "",
+        },
+    )
+    return sqlalchemy.engine.URL(
+        drivername="postgresql+asyncpg",
+        username="saritasa-sqlalchemy-tools-user",
+        password="manager",
+        host="postgres",
+        port=5432,
+        database="_".join(
+            filter(
+                None,
+                (
+                    "saritasa-sqlalchemy-tools-dev",
+                    "test",
+                    worker_input["workerid"],
+                ),
+            ),
+        ),
+        query={},  # type: ignore
+    ).render_as_string(hide_password=False)
+
+
+@pytest.fixture(scope="session")
+def init_database() -> collections.abc.Callable[..., None]:
+    """Return callable object that will be called to init database.
+
+    Overridden fixture from `pytest-async-sqlalchemy` package.
+    https://github.com/igortg/pytest-async-sqlalchemy
+
+    """
+    return saritasa_sqlalchemy_tools.BaseModel.metadata.create_all
+
+
+@pytest.fixture
+async def test_model(
+    db_session: saritasa_sqlalchemy_tools.Session,
+) -> models.TestModel:
+    """Generate test_model instance."""
+    return await factories.TestModelFactory.create_async(session=db_session)
+
+
+@pytest.fixture
+async def related_model(
+    db_session: saritasa_sqlalchemy_tools.Session,
+) -> models.RelatedModel:
+    """Generate related_model instance."""
+    return await factories.RelatedModelFactory.create_async(session=db_session)
+
+
+@pytest.fixture
+async def test_model_list(
+    db_session: saritasa_sqlalchemy_tools.Session,
+) -> collections.abc.Sequence[models.TestModel]:
+    """Generate test_model instances."""
+    return await factories.TestModelFactory.create_batch_async(
+        session=db_session,
+        size=5,
+    )
+
+
+@pytest.fixture
+async def repository(
+    db_session: saritasa_sqlalchemy_tools.Session,
+) -> repositories.TestModelRepository:
+    """Get repository for `TestModel`."""
+    return repositories.TestModelRepository(db_session=db_session)
+
+
+@pytest.fixture
+async def soft_delete_test_model(
+    db_session: saritasa_sqlalchemy_tools.Session,
+) -> models.SoftDeleteTestModel:
+    """Generate `SoftDeleteTestModel` instance."""
+    return await factories.SoftDeleteTestModelFactory.create_async(
+        session=db_session,
+    )
+
+
+@pytest.fixture
+async def soft_delete_repository(
+    db_session: saritasa_sqlalchemy_tools.Session,
+) -> repositories.SoftDeleteTestModelRepository:
+    """Get soft delete repository."""
+    return repositories.SoftDeleteTestModelRepository(db_session=db_session)
diff --git a/tests/factories.py b/tests/factories.py
new file mode 100644
index 0000000..ff881b1
--- /dev/null
+++ b/tests/factories.py
@@ -0,0 +1,120 @@
+import typing
+
+import factory
+import factory.fuzzy
+
+import saritasa_sqlalchemy_tools
+
+from . import models, repositories
+
+
+class TestModelFactory(
+    saritasa_sqlalchemy_tools.AsyncSQLAlchemyModelFactory[models.TestModel],
+):
+    """Factory to generate TestModel."""
+
+    text = factory.Faker(
+        "pystr",
+        min_chars=1,
+        max_chars=30,
+    )
+    text_enum = factory.fuzzy.FuzzyChoice(
+        models.TestModel.TextEnum,
+    )
+    number = factory.Faker("pyint")
+    small_number = factory.Faker("pyint")
+    decimal_number = factory.Faker(
+        "pydecimal",
+        positive=True,
+        left_digits=5,
+        right_digits=0,
+    )
+    boolean = factory.Faker("pybool")
+    text_list = factory.List(
+        [
+            factory.Faker(
+                "pystr",
+                min_chars=1,
+                max_chars=30,
+            )
+            for _ in range(3)
+        ],
+    )
+    date_time = factory.Faker("date_time")
+    date = factory.Faker("date_between")
+    timedelta = factory.Faker("time_delta")
+
+    class Meta:
+        model = models.TestModel
+        repository = repositories.TestModelRepository
+        sub_factories: typing.ClassVar = {
+            "related_model": "tests.factories.RelatedModelFactory",
+            "related_models": ("tests.factories.RelatedModelFactory", 5),
+        }
+
+
+class SoftDeleteTestModelFactory(
+    saritasa_sqlalchemy_tools.AsyncSQLAlchemyModelFactory[
+        models.SoftDeleteTestModel
+    ],
+):
+    """Factory to generate SoftDeleteTestModel."""
+
+    text = factory.Faker(
+        "pystr",
+        min_chars=200,
+        max_chars=250,
+    )
+    text_enum = factory.fuzzy.FuzzyChoice(
+        models.TestModel.TextEnum,
+    )
+    number = factory.Faker("pyint")
+    small_number = factory.Faker("pyint")
+    decimal_number = factory.Faker(
+        "pydecimal",
+        positive=True,
+        left_digits=5,
+        right_digits=0,
+    )
+    boolean = factory.Faker("pybool")
+    text_list = factory.List(
+        [
+            factory.Faker(
+                "pystr",
+                min_chars=1,
+                max_chars=30,
+            )
+            for _ in range(3)
+        ],
+    )
+    date_time = factory.Faker("date_time")
+    date = factory.Faker("date_between")
+    timedelta = factory.Faker("time_delta")
+
+    class Meta:
+        model = models.SoftDeleteTestModel
+        repository = repositories.SoftDeleteTestModelRepository
+
+
+class RelatedModelFactory(
+    saritasa_sqlalchemy_tools.AsyncSQLAlchemyModelFactory[models.RelatedModel],
+):
+    """Factory to generate RelatedModel."""
+
+    class Meta:
+        model = models.RelatedModel
+        repository = repositories.RelatedModelRepository
+
+
+class M2MModelFactory(
+    saritasa_sqlalchemy_tools.AsyncSQLAlchemyModelFactory[models.M2MModel],
+):
+    """Factory to generate M2MModel."""
+
+    class Meta:
+        model = models.M2MModel
+        repository = repositories.M2MModelRepository
+        sub_factories: typing.ClassVar = {
+            "related_model": "tests.factories.RelatedModelFactory",
+            "test_model": "tests.factories.TestModelFactory",
+        }
diff --git a/tests/models.py b/tests/models.py
new file mode 100644
index 0000000..5d79377
--- /dev/null
+++ b/tests/models.py
@@ -0,0 +1,351 @@
+import datetime
+import decimal
+import enum
+import typing
+
+import sqlalchemy.orm
+
+import saritasa_sqlalchemy_tools
+
+
+class RelatedModel(
+    saritasa_sqlalchemy_tools.TimeStampedBaseIDModel,
+):
+    """Test models for checking relationships."""
+
+    __tablename__ = "related_model"
+
+    test_model_list = sqlalchemy.orm.relationship(
+        "TestModel",
+        foreign_keys="TestModel.related_model_id",
+        back_populates="related_model",
+    )
+
+    test_model_list_nullable = sqlalchemy.orm.relationship(
+        "TestModel",
+        foreign_keys="TestModel.related_model_id_nullable",
+        back_populates="related_model_nullable",
+    )
+
+    test_model_id: sqlalchemy.orm.Mapped[int | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.ForeignKey(
+                "test_model.id",
+                ondelete="CASCADE",
+            ),
+            nullable=True,
+        )
+    )
+
+    test_model = sqlalchemy.orm.relationship(
+        "TestModel",
+        foreign_keys=[test_model_id],
+        back_populates="related_models",
+    )
+
+    m2m_test_model_list = sqlalchemy.orm.relationship(
+        "TestModel",
+        secondary="m2m_model",
+        uselist=True,
+        viewonly=True,
+    )
+
+    m2m_associations = sqlalchemy.orm.relationship(
+        "M2MModel",
+        back_populates="related_model",
+        foreign_keys="M2MModel.related_model_id",
+        cascade="all, delete",
+        uselist=True,
+    )
+
+
+class FieldsMixin:
+    """Mixin which adds fields to models."""
+
+    text: sqlalchemy.orm.Mapped[str] = sqlalchemy.orm.mapped_column(
+        sqlalchemy.String(250),
+        nullable=False,
+    )
+
+    text_nullable: sqlalchemy.orm.Mapped[str | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.String(30),
+            nullable=True,
+        )
+    )
+
+    class TextEnum(enum.StrEnum):
+        value_1 = "value1"
+        value_2 = "value2"
+        value_3 = "value3"
+
+    text_enum: sqlalchemy.orm.Mapped[TextEnum] = sqlalchemy.orm.mapped_column(
+        sqlalchemy.Enum(TextEnum),
+        nullable=False,
+    )
+
+    text_enum_nullable: sqlalchemy.orm.Mapped[TextEnum | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.Enum(TextEnum),
+            nullable=True,
+        )
+    )
+
+    number: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column(
+        sqlalchemy.Integer(),
+        nullable=False,
+    )
+
+    number_nullable: sqlalchemy.orm.Mapped[int | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.Integer(),
+            nullable=True,
+        )
+    )
+
+    small_number: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column(
+        sqlalchemy.SmallInteger(),
+        nullable=False,
+    )
+
+    small_number_nullable: sqlalchemy.orm.Mapped[int | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.SmallInteger(),
+            nullable=True,
+        )
+    )
+
+    decimal_number: sqlalchemy.orm.Mapped[decimal.Decimal] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.Numeric(),
+            nullable=False,
+        )
+    )
+
+    decimal_number_nullable: sqlalchemy.orm.Mapped[decimal.Decimal | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.Numeric(),
+            nullable=True,
+        )
+    )
+
+    boolean: sqlalchemy.orm.Mapped[bool] = sqlalchemy.orm.mapped_column(
+        sqlalchemy.Boolean(),
+        nullable=False,
+    )
+
+    boolean_nullable: sqlalchemy.orm.Mapped[bool | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.Boolean(),
+            nullable=True,
+        )
+    )
+
+    text_list: sqlalchemy.orm.Mapped[list[str]] = sqlalchemy.orm.mapped_column(
+        sqlalchemy.ARRAY(sqlalchemy.String),
+        nullable=False,
+    )
+
+    text_list_nullable: sqlalchemy.orm.Mapped[list[str] | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.ARRAY(sqlalchemy.String),
+            nullable=True,
+        )
+    )
+
+    date_time: sqlalchemy.orm.Mapped[datetime.datetime] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.DateTime(),
+            nullable=False,
+        )
+    )
+
+    date_time_nullable: sqlalchemy.orm.Mapped[datetime.datetime | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.DateTime(),
+            nullable=True,
+        )
+    )
+
+    date: sqlalchemy.orm.Mapped[datetime.date] = sqlalchemy.orm.mapped_column(
+        sqlalchemy.Date(),
+        nullable=False,
+    )
+
+    date_nullable: sqlalchemy.orm.Mapped[datetime.date | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.Date(),
+            nullable=True,
+        )
+    )
+
+    timedelta: sqlalchemy.orm.Mapped[datetime.timedelta] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.Interval(),
+            nullable=False,
+        )
+    )
+
+    timedelta_nullable: sqlalchemy.orm.Mapped[datetime.timedelta | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.Interval(),
+            nullable=True,
+        )
+    )
+
+    @property
+    def custom_property(self) -> str:
+        """Implement property."""
+        return ""
+
+    @property
+    def custom_property_nullable(self) -> str | None:
+        """Implement property."""
+        return ""
+
+
+class TestModel(
+    FieldsMixin,
+    saritasa_sqlalchemy_tools.TimeStampedBaseIDModel,
+):
+    """Test model for testing."""
+
+    __tablename__ = "test_model"
+
+    m2m_filters: typing.ClassVar = {
+        "m2m_related_model_id": saritasa_sqlalchemy_tools.M2MFilterConfig(
+            relation_field="m2m_associations",
+            filter_field="related_model_id",
+            match_field="test_model_id",
+        ),
+    }
+
+    related_model_id: sqlalchemy.orm.Mapped[int] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.ForeignKey(
+                "related_model.id",
+                ondelete="CASCADE",
+            ),
+            nullable=False,
+        )
+    )
+
+    related_model = sqlalchemy.orm.relationship(
+        "RelatedModel",
+        foreign_keys=[related_model_id],
+        back_populates="test_model_list",
+    )
+
+    related_model_id_nullable: sqlalchemy.orm.Mapped[int | None] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.ForeignKey(
+                "related_model.id",
+                ondelete="CASCADE",
+            ),
+            nullable=True,
+        )
+    )
+
+    related_model_nullable = sqlalchemy.orm.relationship(
+        "RelatedModel",
+        foreign_keys=[related_model_id_nullable],
+        back_populates="test_model_list_nullable",
+    )
+
+    related_models = sqlalchemy.orm.relationship(
+        "RelatedModel",
+        foreign_keys="RelatedModel.test_model_id",
+        back_populates="test_model",
+    )
+
+    related_models_count: sqlalchemy.orm.Mapped[int | None] = (
+        sqlalchemy.orm.column_property(
+            sqlalchemy.select(sqlalchemy.func.count(RelatedModel.id))
+            .correlate_except(RelatedModel)
+            .scalar_subquery(),
+            deferred=True,
+        )
+    )
+
+    related_models_count_query: sqlalchemy.orm.Mapped[int | None] = (
+        sqlalchemy.orm.query_expression()
+    )
+
+    m2m_related_models = sqlalchemy.orm.relationship(
+        "RelatedModel",
+        secondary="m2m_model",
+        uselist=True,
+        viewonly=True,
+    )
+
+    m2m_associations = sqlalchemy.orm.relationship(
+        "M2MModel",
+        back_populates="test_model",
+        foreign_keys="M2MModel.test_model_id",
+        cascade="all, delete",
+        uselist=True,
+    )
+
+    def __repr__(self) -> str:
+        """Get str representation."""
+        return f"TestModel<{self.id}>"
+
+    @property
+    def custom_property_related_model(self) -> typing.Any:
+        """Implement property."""
+        return self.related_model
+
+    @property
+    def custom_property_related_model_nullable(self) -> typing.Any | None:
+        """Implement property."""
+        return self.related_model_nullable
+
+    @property
+    def custom_property_related_models(self) -> list[typing.Any]:
+        """Implement property."""
+        return self.related_models
+
+
+class SoftDeleteTestModel(
+    FieldsMixin,
+    saritasa_sqlalchemy_tools.SoftDeleteBaseIDModel,
+):
+    """Test model for testing (soft-delete case)."""
+
+    __tablename__ = "soft_delete_test_model"
+
+
+class M2MModel(saritasa_sqlalchemy_tools.TimeStampedBaseIDModel):
+    """Test model for testing m2m features."""
+
+    __tablename__ = "m2m_model"
+
+    test_model_id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column(
+        sqlalchemy.ForeignKey(
+            "test_model.id",
+            ondelete="CASCADE",
+        ),
+        nullable=False,
+    )
+
+    test_model = sqlalchemy.orm.relationship(
+        "TestModel",
+        foreign_keys=[test_model_id],
+        back_populates="m2m_associations",
+    )
+
+    related_model_id: sqlalchemy.orm.Mapped[int] = (
+        sqlalchemy.orm.mapped_column(
+            sqlalchemy.ForeignKey(
+                "related_model.id",
+                ondelete="CASCADE",
+            ),
+            nullable=False,
+        )
+    )
+
+    related_model = sqlalchemy.orm.relationship(
+        "RelatedModel",
+        foreign_keys=[related_model_id],
+        back_populates="m2m_associations",
+    )
diff --git a/tests/repositories.py b/tests/repositories.py
new file mode 100644
index 0000000..45ea211
--- /dev/null
+++ b/tests/repositories.py
@@ -0,0 +1,75 @@
+import typing
+
+import saritasa_sqlalchemy_tools
+
+from . import models
+
+
+class RelatedModelRepository(
+    saritasa_sqlalchemy_tools.BaseRepository[models.RelatedModel],
+):
+    """Repository for `RelatedModel` model."""
+
+    model: typing.TypeAlias = models.RelatedModel
+    default_exclude_bulk_create_fields = (
+        "created",
+        "modified",
+        "id",
+    )
+    default_exclude_bulk_update_fields = (
+        "created",
+        "modified",
+    )
+
+
+class TestModelRepository(
+    saritasa_sqlalchemy_tools.BaseRepository[models.TestModel],
+):
+    """Repository for `TestModel` model."""
+
+    model: typing.TypeAlias = models.TestModel
+    default_exclude_bulk_create_fields = (
+        "created",
+        "modified",
+        "id",
+    )
+    default_exclude_bulk_update_fields = (
+        "created",
+        "modified",
+    )
+
+
+class SoftDeleteTestModelRepository(
+    saritasa_sqlalchemy_tools.BaseSoftDeleteRepository[
+        models.SoftDeleteTestModel
+    ],
+):
+    """Repository for `SoftDeleteTestModel` model."""
+
+    model: typing.TypeAlias = models.SoftDeleteTestModel
+    default_exclude_bulk_create_fields = (
+        "created",
+        "modified",
+        "id",
+    )
+    default_exclude_bulk_update_fields = (
+        "created",
+        "modified",
+    )
+
+
+class M2MModelRepository(
+    saritasa_sqlalchemy_tools.BaseRepository[models.M2MModel],
+):
+    """Repository for `M2MModel` model."""
+
+    model: typing.TypeAlias = models.M2MModel
+    default_exclude_bulk_create_fields = (
+        "created",
+        "modified",
+        "id",
+    )
+    default_exclude_bulk_update_fields = (
+        "created",
+        "modified",
+    )
diff --git a/tests/test_auto_schema.py b/tests/test_auto_schema.py
new file mode 100644
index 0000000..6423ba0
--- /dev/null
+++ b/tests/test_auto_schema.py
@@ -0,0 +1,356 @@
+import datetime
+import decimal
+import re
+import typing
+
+import pydantic
+import pytest
+
+import saritasa_sqlalchemy_tools
+
+from . import models, repositories
+
+
+async def test_auto_schema_generation(
+    test_model: models.TestModel,
+) -> None:
+    """Test schema generation picks correct types from model for schema."""
+
+    class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema):
+        class Meta:
+            model = models.TestModel
+            model_config = pydantic.ConfigDict(
+                from_attributes=True,
+                validate_assignment=True,
+            )
+            fields = (
+                "id",
+                "created",
+                "modified",
+                "text",
+                "text_nullable",
+                "text_enum",
+                "text_enum_nullable",
+                "number",
+                "number_nullable",
+                "small_number",
+                "small_number_nullable",
+                "decimal_number",
+                "decimal_number_nullable",
+                "boolean",
+                "boolean_nullable",
+                "text_list",
+                "text_list_nullable",
+                "date_time",
+                "date_time_nullable",
+                "date",
+                "date_nullable",
+                "timedelta",
+                "timedelta_nullable",
+                "related_model_id",
+                "related_model_id_nullable",
+                "custom_property",
+                "custom_property_nullable",
+            )
+
+    schema = AutoSchema.get_schema()
+    model = schema.model_validate(test_model)
+    for field in AutoSchema.Meta.fields:
+        assert getattr(model, field) == getattr(test_model, field)
+        if "nullable" not in field and "property" not in field:
+            with pytest.raises(pydantic.ValidationError):
+                setattr(model, field, None)
+
+
+async def test_auto_schema_type_override_generation(
+    test_model: models.TestModel,
+) -> None:
+    """Test that in auto schema generation you can override type.
+
+    Check that if you specify a field as a (field, type) tuple, schema
+    generation will ignore the type from the model and use the specified one.
+ + """ + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = ( + ("text", str | None), + ("text_enum", models.TestModel.TextEnum | None), + ("number", int | None), + ("small_number", int | None), + ("decimal_number", decimal.Decimal | None), + ("boolean", bool | None), + ("text_list", list[str] | None), + ("date_time", datetime.datetime | None), + ("date", datetime.date | None), + ("timedelta", datetime.timedelta | None), + ("custom_property", str | None), + ("related_model_id", int | None), + ) + + schema = AutoSchema.get_schema() + model = schema.model_validate(test_model) + for field, _ in AutoSchema.Meta.fields: + if "property" not in field: + setattr(model, field, None) + + +async def test_auto_schema_type_invalid_field_config( + test_model: models.TestModel, +) -> None: + """Test than on invalid field config type there is an error. + + Check that if use any type other than str or tuple[str, type], schema + generation will raise an error. + + """ + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = (("id", int, 1),) + + with pytest.raises( + saritasa_sqlalchemy_tools.auto_schema.UnableProcessTypeError, + match=re.escape( + "Can't process the following field ('id', , 1)", + ), + ): + AutoSchema.get_schema() + + +async def test_auto_schema_related_field_with_no_schema( + test_model: models.TestModel, +) -> None: + """Test that in generation raise error on relationships without type. + + For relationship field developer must specify type, otherwise generation + must throw an error. + + """ + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = ( + "related_model", + "related_models", + ) + + with pytest.raises( + saritasa_sqlalchemy_tools.auto_schema.UnableProcessTypeError, + match=re.escape( + "Schema generation is not supported for relationship " + "fields(related_model), please use auto-schema or pydantic class", + ), + ): + AutoSchema.get_schema() + + +async def test_auto_schema_related_field_with_schema( + test_model: models.TestModel, + repository: repositories.TestModelRepository, +) -> None: + """Test that generation works correctly with related types auto. + + Verify that ModelAutoSchema can be used as a type. 
+ + """ + + class RelatedAutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.RelatedModel + fields = ( + "id", + "created", + "modified", + ) + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = ( + ("related_model", RelatedAutoSchema), + ("related_model_nullable", RelatedAutoSchema), + ("related_models", RelatedAutoSchema), + ("custom_property_related_model", RelatedAutoSchema), + ("custom_property_related_model_nullable", RelatedAutoSchema), + ("custom_property_related_models", RelatedAutoSchema), + ) + + schema = AutoSchema.get_schema() + instance = ( + await repository.fetch( + id=test_model.id, + select_in_load=( + models.TestModel.related_model, + models.TestModel.related_model_nullable, + models.TestModel.related_models, + ), + ) + ).first() + model = schema.model_validate(instance) + isinstance(model.related_models, list) + for field in RelatedAutoSchema.Meta.fields: + assert getattr( + model.related_model, + field, + ) == getattr( + test_model.related_model, + field, + ) + for related_model, test_related_model in zip( + model.related_models, + test_model.related_models, + strict=False, + ): + assert getattr( + related_model, + field, + ) == getattr( + test_related_model, + field, + ) + + +async def test_auto_schema_use_both_config_and_model() -> None: + """Test schema generation fails when both config and model are used. + + It's not allowed to specify both model config and base model in Meta. + + """ + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + model_config = pydantic.ConfigDict( + from_attributes=True, + validate_assignment=True, + ) + base_model = pydantic.BaseModel + fields = ( + "id", + "created", + "modified", + ) + + with pytest.raises( + ValueError, + match=re.escape( + "Only config or base model could be passed to create_model", + ), + ): + AutoSchema.get_schema() + + +def custom_validator( + cls, # noqa: ANN001 + value: typing.Any, + info: pydantic.ValidationInfo, +) -> None: + """Raise value error.""" + raise ValueError("This is custom validator") + + +def test_custom_field_validators( + test_model: models.TestModel, +) -> None: + """Test field validators for schema generation. + + Verify that schema generation would assign validators from + extra_fields_validators for each field. + + """ + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = ("id",) + extra_fields_validators: typing.ClassVar = { + "id": (custom_validator,), + } + + schema = AutoSchema.get_schema() + with pytest.raises( + pydantic.ValidationError, + match=re.escape("This is custom validator"), + ): + schema.model_validate(test_model) + + +def test_custom_field_validators_custom_type( + test_model: models.TestModel, +) -> None: + """Test field validators for schema generation(custom type case). + + Same as test_custom_field_validators, but in this case custom type is used. 
+ + """ + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = (("id", int),) + extra_fields_validators: typing.ClassVar = { + "id": (custom_validator,), + } + + schema = AutoSchema.get_schema() + with pytest.raises( + pydantic.ValidationError, + match=re.escape("This is custom validator"), + ): + schema.model_validate(test_model) + + +def test_custom_field_validators_property( + test_model: models.TestModel, +) -> None: + """Test field validators for property for schema generation. + + Verify that schema generation correctly works with models @property attrs + + """ + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = ("custom_property_related_model",) + extra_fields_validators: typing.ClassVar = { + "custom_property_related_model": (custom_validator,), + } + + schema = AutoSchema.get_schema() + with pytest.raises( + pydantic.ValidationError, + match=re.escape("This is custom validator"), + ): + schema.model_validate(test_model) + + +def test_custom_field_validators_property_custom_type( + test_model: models.TestModel, +) -> None: + """Test validators for property for schema generation(custom type case). + + Same as test_custom_field_validators_property, but custom type is + specified. + + """ + + class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema): + class Meta: + model = models.TestModel + fields = (("custom_property_related_model", typing.Any),) + extra_fields_validators: typing.ClassVar = { + "custom_property_related_model": (custom_validator,), + } + + schema = AutoSchema.get_schema() + with pytest.raises( + pydantic.ValidationError, + match=re.escape("This is custom validator"), + ): + schema.model_validate(test_model) diff --git a/tests/test_factories.py b/tests/test_factories.py new file mode 100644 index 0000000..1a1ca8e --- /dev/null +++ b/tests/test_factories.py @@ -0,0 +1,120 @@ +import pytest + +import saritasa_sqlalchemy_tools + +from . import factories, models, repositories + + +async def test_factory_missing_repository( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that error is raised if repository class missing.""" + + class TestModelFactory( + saritasa_sqlalchemy_tools.AsyncSQLAlchemyModelFactory[ + models.TestModel + ], + ): + """Factory to generate TestModel.""" + + class Meta: + model = models.TestModel + + with pytest.raises( + ValueError, + match="Repository class is not set in Meta class", + ): + await TestModelFactory.create_async(db_session) + + +async def test_factory_generation( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that model generation works as expected.""" + instance = await factories.TestModelFactory.create_async(db_session) + assert await repository.exists(id=instance.id) + + +async def test_factory_generation_sub_factories( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that model generation works with sub factories as expected. + + It should create and attach all related models. 
+ + """ + instance = await factories.TestModelFactory.create_async(db_session) + instance_created = ( + await repository.fetch( + id=instance.id, + select_in_load=( + models.TestModel.related_model, + models.TestModel.related_models, + ), + ) + ).first() + assert instance_created + assert instance_created.related_model_id + assert instance_created.related_model + assert len(instance_created.related_models) == 5 + + +async def test_factory_generation_skip_sub_factories( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that sub_factory will be skipped if value is passed. + + If passed value for fields which are generated via sub-factories, + sub-factories should be called. + + """ + await factories.TestModelFactory.create_async( + db_session, + related_model=await factories.RelatedModelFactory.create_async( + db_session, + ), + related_models=await factories.RelatedModelFactory.create_batch_async( + db_session, + size=3, + ), + ) + assert await repositories.RelatedModelRepository(db_session).count() == 4 + + +async def test_factory_generation_skip_sub_factories_id_passed( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that sub_factory will be skipped if fk id is passed. + + Same as test_factory_generation_skip_sub_factories, but in this case we + pass id of related object. + + """ + await factories.TestModelFactory.create_async( + db_session, + related_model_id=( + await factories.RelatedModelFactory.create_async( + db_session, + ) + ).id, + related_models=[], + ) + assert await repositories.RelatedModelRepository(db_session).count() == 1 + + +async def test_factory_generation_batch( + db_session: saritasa_sqlalchemy_tools.Session, + repository: repositories.TestModelRepository, +) -> None: + """Test that model batch generation works as expected.""" + instances = await factories.TestModelFactory.create_batch_async( + db_session, + size=5, + ) + for instance in instances: + assert await repository.exists(id=instance.id) diff --git a/tests/test_ordering.py b/tests/test_ordering.py new file mode 100644 index 0000000..910a6db --- /dev/null +++ b/tests/test_ordering.py @@ -0,0 +1,43 @@ +import pytest +import sqlalchemy + +import saritasa_sqlalchemy_tools + + +@pytest.fixture +def generated_enum() -> type[saritasa_sqlalchemy_tools.OrderingEnum]: + """Generate enum for testing.""" + return saritasa_sqlalchemy_tools.OrderingEnum( # type: ignore + "GeneratedEnum", + [ + "field", + ], + ) + + +def test_ordering_enum_generated_values( + generated_enum: type[saritasa_sqlalchemy_tools.OrderingEnum], +) -> None: + """Test ordering enum generation. + + Check that reverse option are present. 
+ + """ + assert hasattr(generated_enum, "field") + assert generated_enum.field == "field" + assert hasattr(generated_enum, "field_desc") + assert generated_enum.field_desc == "-field" + + +def test_ordering_enum_generated_sql_clause( + generated_enum: type[saritasa_sqlalchemy_tools.OrderingEnum], +) -> None: + """Test that ordering enum generates correct sql clause.""" + assert generated_enum.field.sql_clause == "field" + expected_db_clause: sqlalchemy.UnaryExpression[str] = sqlalchemy.desc( + "field", + ) + actual_db_clause = generated_enum.field_desc.sql_clause + assert actual_db_clause.__class__ == expected_db_clause.__class__ + assert actual_db_clause.modifier == expected_db_clause.modifier + assert actual_db_clause.operator == expected_db_clause.operator diff --git a/tests/test_repositories.py b/tests/test_repositories.py new file mode 100644 index 0000000..21535c3 --- /dev/null +++ b/tests/test_repositories.py @@ -0,0 +1,524 @@ +import pytest +import sqlalchemy + +import saritasa_sqlalchemy_tools + +from . import factories, models, repositories + + +async def test_init_other( + repository: repositories.TestModelRepository, +) -> None: + """Test init other.""" + new_repo = repository.init_other( + repositories.SoftDeleteTestModelRepository, + ) + assert isinstance(new_repo, repositories.SoftDeleteTestModelRepository) + assert new_repo.db_session is repository.db_session + + +async def test_soft_delete( + soft_delete_test_model: models.SoftDeleteTestModel, + soft_delete_repository: repositories.SoftDeleteTestModelRepository, +) -> None: + """Test soft deletion, model instance should be still present in db.""" + await soft_delete_repository.delete(soft_delete_test_model) + soft_deleted_object = await soft_delete_repository.get( + soft_delete_test_model.id, + ) + assert soft_deleted_object + assert soft_deleted_object.deleted + + +async def test_force_soft_delete( + soft_delete_test_model: models.SoftDeleteTestModel, + soft_delete_repository: repositories.SoftDeleteTestModelRepository, +) -> None: + """Test force soft delete.""" + await soft_delete_repository.force_delete(soft_delete_test_model) + assert not await soft_delete_repository.get(soft_delete_test_model.id) + + +async def test_insert_batch( + related_model: models.RelatedModel, + repository: repositories.TestModelRepository, +) -> None: + """Test batch insert.""" + instances = factories.TestModelFactory.build_batch( + related_model_id=related_model.id, + size=5, + ) + created_instances = await repository.insert_batch(instances) + for created_instance in created_instances: + assert created_instance.id + assert created_instance.created + assert created_instance.modified + assert created_instance.related_model_id == related_model.id + + +async def test_update_batch( + test_model_list: list[models.TestModel], + related_model: models.RelatedModel, + repository: repositories.TestModelRepository, +) -> None: + """Test batch update.""" + for test_model in test_model_list: + test_model.related_model_id = related_model.id + await repository.update_batch(test_model_list) + for test_model in test_model_list: + pk = test_model.id + repository.expire(test_model) + test_model = await repository.get(pk) # type: ignore + assert test_model.related_model_id == related_model.id + + +async def test_delete_batch( + test_model_list: list[models.TestModel], + repository: repositories.TestModelRepository, +) -> None: + """Test batch delete.""" + ids = [test_model.id for test_model in test_model_list] + await 
+    await repository.delete_batch(where=[models.TestModel.id.in_(ids)])
+    assert not await repository.count(where=[models.TestModel.id.in_(ids)])
+
+
+async def test_save(
+    related_model: models.RelatedModel,
+    repository: repositories.TestModelRepository,
+) -> None:
+    """Test that repository properly saves instances."""
+    new_test_model = factories.TestModelFactory.build(
+        related_model_id=related_model.id,
+    )
+    new_test_model = await repository.save(new_test_model, refresh=True)
+    assert (await repository.fetch(id=new_test_model.id)).first()
+
+
+async def test_expire(
+    test_model: models.TestModel,
+    repository: repositories.TestModelRepository,
+) -> None:
+    """Test that repository properly expires instances."""
+    repository.expire(test_model)
+    with pytest.raises(sqlalchemy.exc.MissingGreenlet):
+        _ = test_model.id
+
+
+@pytest.mark.parametrize(
+    "reuse_select_statement",
+    [
+        True,
+        False,
+    ],
+)
+async def test_ordering(
+    test_model_list: list[models.TestModel],
+    repository: repositories.TestModelRepository,
+    reuse_select_statement: bool,
+) -> None:
+    """Test ordering."""
+    ordered_models = sorted(
+        test_model_list,
+        key=lambda instance: instance.text.lower(),
+    )
+    if reuse_select_statement:
+        select_statement = repository.get_order_statement(
+            None,
+            *["text"],
+        )
+        models_from_db = (
+            await repository.fetch(statement=select_statement)
+        ).all()
+    else:
+        models_from_db = (
+            await repository.fetch(ordering_clauses=["text"])
+        ).all()
+    for actual, expected in zip(models_from_db, ordered_models, strict=False):
+        assert actual.id == expected.id
+
+
+async def test_ordering_with_enum(
+    test_model_list: list[models.TestModel],
+    repository: repositories.TestModelRepository,
+) -> None:
+    """Test ordering with enum."""
+    ordered_models = sorted(
+        test_model_list,
+        key=lambda instance: instance.text.lower(),
+        reverse=True,
+    )
+
+    class OrderingEnum(saritasa_sqlalchemy_tools.OrderingEnum):
+        text = "text"
+
+    models_from_db = (
+        await repository.fetch(ordering_clauses=[OrderingEnum.text_desc])
+    ).all()
+    for actual, expected in zip(models_from_db, ordered_models, strict=False):
+        assert actual.id == expected.id
+
+
+@pytest.mark.parametrize(
+    "reuse_select_statement",
+    [
+        True,
+        False,
+    ],
+)
+async def test_pagination(
+    test_model_list: list[models.TestModel],
+    repository: repositories.TestModelRepository,
+    reuse_select_statement: bool,
+) -> None:
+    """Test pagination."""
+    ordered_models = sorted(
+        test_model_list,
+        key=lambda instance: instance.id,
+    )
+    args = {
+        "limit": 2,
+        "offset": 1,
+    }
+    if reuse_select_statement:
+        select_statement = repository.get_pagination_statement(**args)  # type: ignore
+        models_from_db = (
+            await repository.fetch(
+                statement=select_statement,
+                ordering_clauses=["id"],
+            )
+        ).all()
+    else:
+        models_from_db = (
+            await repository.fetch(ordering_clauses=["id"], **args)  # type: ignore
+        ).all()
+
+    assert models_from_db[0].id == ordered_models[1].id
+    assert models_from_db[1].id == ordered_models[2].id
+
+
+@pytest.mark.parametrize(
+    "prefetch_type",
+    [
+        "select_in_load",
+        "joined_load",
+    ],
+)
+@pytest.mark.parametrize(
+    "reuse_select_statement",
+    [
+        True,
+        False,
+    ],
+)
+async def test_prefetch(
+    test_model: models.TestModel,
+    repository: repositories.TestModelRepository,
+    prefetch_type: str,
+    reuse_select_statement: bool,
+) -> None:
+    """Test prefetching of related fields of model."""
+    await factories.RelatedModelFactory.create_batch_async(
+        session=repository.db_session,
+        test_model_id=test_model.id,
+        size=5,
+    )
+    await factories.TestModelFactory.create_batch_async(
+        session=repository.db_session,
+        related_model_id=test_model.related_model_id,
+        size=3,
+    )
+    targets = (
+        models.TestModel.related_model,
+        models.TestModel.related_model_nullable,
+        models.TestModel.related_models,
+        (
+            models.TestModel.related_model,
+            models.RelatedModel.test_model_list,
+        ),
+    )
+    args = {
+        "id": test_model.id,
+        prefetch_type: targets,
+    }
+    if reuse_select_statement:
+        select_statement = getattr(
+            repository,
+            f"get_{prefetch_type}_statement",
+        )(None, *targets)
+        instance = (
+            await repository.fetch(
+                statement=select_statement,
+                id=test_model.id,
+            )
+        ).first()
+    else:
+        instance = (await repository.fetch(**args)).first()  # type: ignore
+    assert instance
+    assert instance.related_model
+    assert not instance.related_model_nullable
+    assert len(instance.related_models) == 5
+    # Original plus created.
+    assert len(instance.related_model.test_model_list) == 4
+
+
+@pytest.mark.parametrize(
+    "reuse_select_statement",
+    [
+        True,
+        False,
+    ],
+)
+async def test_filter_in(
+    test_model_list: list[models.TestModel],
+    repository: repositories.TestModelRepository,
+    reuse_select_statement: bool,
+) -> None:
+    """Test filter `in`."""
+    filtered_models = list(
+        filter(
+            lambda instance: instance.text
+            in [test_model_list[0].text, test_model_list[3].text],
+            test_model_list,
+        ),
+    )
+    args = {
+        "where": [
+            saritasa_sqlalchemy_tools.Filter(
+                field="text__in",
+                value=[test_model_list[0].text, test_model_list[3].text],
+            ),
+        ],
+    }
+    if reuse_select_statement:
+        select_statement = repository.get_filter_statement(
+            None,
+            *args["where"],
+        )
+        instances = (
+            await repository.fetch(
+                statement=select_statement,
+                ordering_clauses=["id"],
+            )
+        ).all()
+    else:
+        instances = (
+            await repository.fetch(
+                **args,  # type: ignore
+                ordering_clauses=["id"],
+            )
+        ).all()
+    assert instances[0].id == filtered_models[0].id
+    assert instances[1].id == filtered_models[1].id
+
+
+@pytest.mark.parametrize(
+    "reuse_select_statement",
+    [
+        True,
+        False,
+    ],
+)
+async def test_filter_gte(
+    test_model_list: list[models.TestModel],
+    repository: repositories.TestModelRepository,
+    reuse_select_statement: bool,
+) -> None:
+    """Test filter `gte`."""
+    max_num = max(test_model.number for test_model in test_model_list)
+
+    args = {
+        "where": [
+            saritasa_sqlalchemy_tools.Filter(
+                field="number__gte",
+                value=max_num,
+            ),
+        ],
+    }
+    if reuse_select_statement:
+        select_statement = repository.get_filter_statement(
+            None,
+            *args["where"],
+        )
+        instances = (
+            await repository.fetch(
+                statement=select_statement,
+                ordering_clauses=["id"],
+            )
+        ).all()
+    else:
+        instances = (
+            await repository.fetch(
+                **args,  # type: ignore
+                ordering_clauses=["id"],
+            )
+        ).all()
+    assert len(instances) == 1
+    assert instances[0].number == max_num
+
+
+@pytest.mark.parametrize(
+    "reuse_select_statement",
+    [
+        True,
+        False,
+    ],
+)
+async def test_filter_m2m(
+    test_model: models.TestModel,
+    repository: repositories.TestModelRepository,
+    reuse_select_statement: bool,
+) -> None:
+    """Test filter related to m2m fields."""
+    await factories.M2MModelFactory.create_batch_async(
+        session=repository.db_session,
+        size=5,
+        test_model_id=test_model.id,
+        related_model_id=test_model.related_model_id,
+    )
+    await factories.M2MModelFactory.create_batch_async(
+        session=repository.db_session,
+        size=5,
+    )
+
+    args = {
+        "where": [
+            saritasa_sqlalchemy_tools.Filter(
+                field="m2m_related_model_id__in",
+                value=[test_model.related_model_id],
+            ),
+        ],
+    }
+    if reuse_select_statement:
+        select_statement = repository.get_filter_statement(
+            None,
+            *args["where"],
+        )
+        instances = (
+            (
+                await repository.fetch(
+                    statement=select_statement,
+                    ordering_clauses=["id"],
+                )
+            )
+            .unique()
+            .all()
+        )
+    else:
+        instances = (
+            (
+                await repository.fetch(
+                    **args,  # type: ignore
+                    ordering_clauses=["id"],
+                )
+            )
+            .unique()
+            .all()
+        )
+    assert len(instances) == 1
+    assert instances[0].id == test_model.id
+
+
+async def test_search_filter(
+    test_model_list: list[models.TestModel],
+    repository: repositories.TestModelRepository,
+) -> None:
+    """Test search filter."""
+    search_text = test_model_list[0].text
+    instances = (
+        (
+            await repository.fetch(
+                where=[
+                    saritasa_sqlalchemy_tools.transform_search_filter(
+                        models.TestModel,
+                        search_fields=[
+                            "text",
+                            "id",
+                        ],
+                        value=search_text,
+                    ),
+                ],
+                ordering_clauses=["id"],
+            )
+        )
+        .unique()
+        .all()
+    )
+    assert len(instances) == 1
+    assert instances[0].id == test_model_list[0].id
+
+
+@pytest.mark.parametrize(
+    "reuse_select_statement",
+    [
+        True,
+        False,
+    ],
+)
+async def test_annotation(
+    test_model: models.TestModel,
+    repository: repositories.TestModelRepository,
+    reuse_select_statement: bool,
+) -> None:
+    """Test annotations loading."""
+    await factories.RelatedModelFactory.create_batch_async(
+        session=repository.db_session,
+        size=5,
+        test_model_id=test_model.id,
+    )
+    annotations = (models.TestModel.related_models_count,)
+    if reuse_select_statement:
+        select_statement = repository.get_annotated_statement(
+            None,
+            *annotations,
+        )
+        instance = (
+            await repository.fetch(
+                statement=select_statement,
+                id=test_model.id,
+            )
+        ).first()
+    else:
+        instance = (
+            await repository.fetch(
+                id=test_model.id,
+                annotations=annotations,
+            )
+        ).first()
+    assert instance
+    assert (
+        instance.related_models_count
+        == await repository.init_other(
+            repositories.RelatedModelRepository,
+        ).count()
+    )
+
+
+async def test_annotation_query(
+    test_model: models.TestModel,
+    repository: repositories.TestModelRepository,
+) -> None:
+    """Test annotations loading when using dynamic queries."""
+    instance = (
+        await repository.fetch(
+            id=test_model.id,
+            annotations=(
+                (
+                    models.TestModel.related_models_count_query,
+                    sqlalchemy.select(
+                        sqlalchemy.func.count(models.RelatedModel.id),
+                    )
+                    .where(
+                        models.RelatedModel.test_model_id
+                        == models.TestModel.id,
+                    )
+                    .correlate_except(models.RelatedModel)
+                    .scalar_subquery(),
+                ),
+            ),
+            select_in_load=(models.TestModel.related_models,),
+        )
+    ).first()
+    assert instance
+    assert instance.related_models_count_query == len(
+        test_model.related_models,
+    )
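
To make the intent of the schema tests concrete, here is a minimal usage sketch of the auto-schema API they exercise. `RelatedSchema` and `TestModelSchema` are illustrative names, not part of the patch; `ModelAutoSchema`, `Meta.fields`, tuple-based type overrides, and `get_schema()` all mirror the tests above.

# A usage sketch, assuming the models defined in tests/models.py.
import saritasa_sqlalchemy_tools

from tests import models


class RelatedSchema(saritasa_sqlalchemy_tools.ModelAutoSchema):
    class Meta:
        model = models.RelatedModel
        fields = ("id", "created", "modified")


class TestModelSchema(saritasa_sqlalchemy_tools.ModelAutoSchema):
    class Meta:
        model = models.TestModel
        fields = (
            "id",
            "text",
            ("number", int | None),  # override the type picked from the model
            ("related_model", RelatedSchema),  # relationships need an explicit type
        )


# get_schema() builds a pydantic model class from the Meta definition;
# instances loaded with their relationships can then be validated with
# schema.model_validate(instance).
schema = TestModelSchema.get_schema()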
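
A similarly hedged sketch of the factory surface covered by tests/test_factories.py; `make_fixtures` is a hypothetical helper, while `create_async` and `create_batch_async` are the calls the tests exercise.

# A usage sketch, assuming the factories defined in tests/factories.py.
import saritasa_sqlalchemy_tools

from tests import factories


async def make_fixtures(
    db_session: saritasa_sqlalchemy_tools.Session,
) -> None:
    """Create one TestModel plus a batch, persisted through the session."""
    instance = await factories.TestModelFactory.create_async(db_session)
    batch = await factories.TestModelFactory.create_batch_async(
        db_session,
        size=5,
    )
    assert instance.id
    assert len(batch) == 5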
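
And a sketch of the repository query surface covered by tests/test_repositories.py; `list_page` is a hypothetical helper, while `fetch`, `Filter`, `ordering_clauses`, `limit`, and `offset` all appear in the tests above.

# A usage sketch, assuming the repositories defined in tests/repositories.py.
import saritasa_sqlalchemy_tools

from tests import repositories


async def list_page(
    db_session: saritasa_sqlalchemy_tools.Session,
) -> list:
    """Fetch a filtered, ordered page of TestModel rows."""
    repository = repositories.TestModelRepository(db_session)
    return (
        await repository.fetch(
            where=[
                saritasa_sqlalchemy_tools.Filter(
                    field="number__gte",  # field__lookup naming, as in the tests
                    value=10,
                ),
            ],
            ordering_clauses=["text"],
            limit=2,
            offset=1,
        )
    ).all()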