From c55dd7b485bb6838be6631d5a1cba71f42c1e4ae Mon Sep 17 00:00:00 2001
From: Vlada Dusek
Date: Tue, 5 Dec 2023 13:20:44 +0100
Subject: [PATCH] Migrate from Flake8 & Autopep8 to Ruff (#133)

---
 .flake8 | 30 -
 .gitignore | 1 +
 .isort.cfg | 7 -
 CHANGELOG.md | 6 +-
 CONTRIBUTING.md | 4 +-
 Makefile | 18 +-
 pyproject.toml | 100 ++-
 scripts/check_version_availability.py | 3 +
 scripts/check_version_in_changelog.py | 4 +-
 scripts/print_current_package_version.py | 2 +
 scripts/update_version_for_prerelease.py | 2 +
 scripts/utils.py | 22 +-
 src/apify/_crypto.py | 26 +-
 .../_memory_storage/file_storage_utils.py | 17 +-
 .../_memory_storage/memory_storage_client.py | 62 +-
 .../resource_clients/__init__.py | 13 +-
 .../resource_clients/base_resource_client.py | 62 +-
 .../base_resource_collection_client.py | 64 +-
 .../resource_clients/dataset.py | 215 ++++---
 .../resource_clients/dataset_collection.py | 26 +-
 .../resource_clients/key_value_store.py | 223 ++++---
 .../key_value_store_collection.py | 24 +-
 .../resource_clients/request_queue.py | 182 +++---
 .../request_queue_collection.py | 26 +-
 src/apify/_utils.py | 149 +++--
 src/apify/actor.py | 596 +++++++++---------
 src/apify/config.py | 131 ++--
 src/apify/consts.py | 4 +-
 src/apify/event_manager.py | 65 +-
 src/apify/log.py | 25 +-
 src/apify/proxy_configuration.py | 156 +++--
 src/apify/scrapy/middlewares.py | 40 +-
 src/apify/scrapy/pipelines.py | 12 +-
 src/apify/scrapy/scheduler.py | 33 +-
 src/apify/scrapy/utils.py | 22 +-
 src/apify/storages/__init__.py | 7 +-
 src/apify/storages/base_storage.py | 74 ++-
 src/apify/storages/dataset.py | 218 ++++---
 src/apify/storages/key_value_store.py | 122 ++--
 src/apify/storages/request_queue.py | 227 ++++---
 src/apify/storages/storage_client_manager.py | 29 +-
 tests/integration/_utils.py | 2 +
 .../actor_source_base/src/__main__.py | 2 +
 .../integration/actor_source_base/src/main.py | 2 +
 tests/integration/conftest.py | 66 +-
 tests/integration/test_actor_api_helpers.py | 94 +--
 .../test_actor_create_proxy_configuration.py | 37 +-
 tests/integration/test_actor_dataset.py | 27 +-
 tests/integration/test_actor_events.py | 12 +-
 .../integration/test_actor_key_value_store.py | 45 +-
 tests/integration/test_actor_lifecycle.py | 43 +-
 tests/integration/test_actor_log.py | 9 +-
 tests/integration/test_actor_request_queue.py | 23 +-
 tests/integration/test_fixtures.py | 20 +-
 tests/integration/test_request_queue.py | 13 +-
 .../test_actor_create_proxy_configuration.py | 104 +--
 tests/unit/actor/test_actor_dataset.py | 18 +-
 tests/unit/actor/test_actor_env_helpers.py | 17 +-
 tests/unit/actor/test_actor_helpers.py | 56 +-
 .../unit/actor/test_actor_key_value_store.py | 29 +-
 tests/unit/actor/test_actor_lifecycle.py | 33 +-
 tests/unit/actor/test_actor_log.py | 10 +-
 .../actor/test_actor_memory_storage_e2e.py | 32 +-
 tests/unit/actor/test_actor_request_queue.py | 6 +-
 tests/unit/conftest.py | 51 +-
 .../resource_clients/test_dataset.py | 10 +-
 .../test_dataset_collection.py | 10 +-
 .../resource_clients/test_key_value_store.py | 224 ++++---
 .../test_key_value_store_collection.py | 11 +-
 .../resource_clients/test_request_queue.py | 150 +++--
 .../test_request_queue_collection.py | 11 +-
 .../memory_storage/test_memory_storage.py | 7 +-
 tests/unit/storages/test_dataset.py | 4 +-
 tests/unit/storages/test_key_value_store.py | 4 +-
 tests/unit/storages/test_request_queue.py | 50 +-
 tests/unit/test_config.py | 14 +-
 tests/unit/test_crypto.py | 61 +-
 tests/unit/test_event_manager.py | 41 +-
 tests/unit/test_lru_cache.py | 4 +-
 tests/unit/test_proxy_configuration.py | 200 +++--
 tests/unit/test_utils.py | 181 +++---
 81 files changed, 2797 insertions(+), 1985 deletions(-)
 delete mode 100644 .flake8
 delete mode 100644 .isort.cfg
 mode change 100644 => 100755 scripts/check_version_availability.py
 mode change 100644 => 100755 scripts/check_version_in_changelog.py
 mode change 100644 => 100755 scripts/print_current_package_version.py
 mode change 100644 => 100755 scripts/update_version_for_prerelease.py

diff --git a/.flake8 b/.flake8
deleted file mode 100644
index 0b2996e3..00000000
--- a/.flake8
+++ /dev/null
@@ -1,30 +0,0 @@
-[flake8]
-filename =
-    ./scripts/*.py,
-    ./src/*.py,
-    ./tests/*.py
-per-file-ignores =
-    scripts/*: D
-    tests/*: D
-    **/__init__.py: F401
-
-# Google docstring convention + D204 & D401
-docstring-convention = all
-ignore =
-    D100
-    D104
-    D203
-    D213
-    D215
-    D406
-    D407
-    D408
-    D409
-    D413
-    U101
-
-max_line_length = 150
-unused-arguments-ignore-overload-functions = True
-unused-arguments-ignore-stub-functions = True
-pytest-fixture-no-parentheses = True
-pytest-mark-no-parentheses = True
diff --git a/.gitignore b/.gitignore
index e05d411c..a89c2ed0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 __pycache__
 .mypy_cache
 .pytest_cache
+.ruff_cache
 .venv
 .direnv
diff --git a/.isort.cfg b/.isort.cfg
deleted file mode 100644
index 13919fae..00000000
--- a/.isort.cfg
+++ /dev/null
@@ -1,7 +0,0 @@
-[isort]
-include_trailing_comma = True
-line_length = 150
-use_parentheses = True
-multi_line_output = 3
-sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
-known_first_party = apify,apify_client,apify_shared
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 55182873..c3958929 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,10 +1,12 @@
 Changelog
 =========
 
-[1.3.1](../../releases/tag/v1.3.1) - Unreleased
+[1.4.0](../../releases/tag/v1.4.0) - Unreleased
 -----------------------------------------------
 
-...
+### Internal changes
+
+- Migrate from Autopep8 and Flake8 to Ruff
 
 [1.3.0](../../releases/tag/v1.3.0) - 2023-11-15
 -----------------------------------------------
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index cdbdfef7..2863e61a 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -24,11 +24,11 @@ To install this package and its development dependencies, run `make install-dev`
 
 ## Formatting
 
-We use `autopep8` and `isort` to automatically format the code to a common format. To run the formatting, just run `make format`.
+We use `ruff` to automatically format the code to a common format. To run the formatting, just run `make format`.
 
 ## Linting, type-checking and unit testing
 
-We use `flake8` for linting, `mypy` for type checking and `pytest` for unit testing. To run these tools, just run `make check-code`.
+We use `ruff` for linting, `mypy` for type checking and `pytest` for unit testing. To run these tools, just run `make check-code`.
 
 ## Integration tests
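For context, the `make format` and `make check-code` targets mentioned above reduce to plain Ruff invocations (see the Makefile diff below). A minimal sketch of the same calls driven from Python, assuming Ruff is installed in the active environment:

```python
# Sketch only: mirrors the `python3 -m ruff ...` Makefile targets introduced
# by this patch; not part of the patch itself.
import subprocess

DIRS_WITH_CODE = ['src', 'tests', 'scripts']

subprocess.run(['python3', '-m', 'ruff', 'check', '--fix', *DIRS_WITH_CODE], check=True)  # lint + autofix
subprocess.run(['python3', '-m', 'ruff', 'format', *DIRS_WITH_CODE], check=True)          # formatting
```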
diff --git a/Makefile b/Makefile
index b36245fb..d0cf1899 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,7 @@
 .PHONY: clean install-dev build publish twine-check lint unit-tests integration-tests type-check check-code format check-version-availability check-changelog-entry build-api-reference
 
+DIRS_WITH_CODE = src tests scripts
+
 # This is default for local testing, but GitHub workflows override it to a higher value in CI
 INTEGRATION_TESTS_CONCURRENCY = 1
 
@@ -7,21 +9,21 @@ clean:
 	rm -rf build dist .mypy_cache .pytest_cache src/*.egg-info __pycache__
 
 install-dev:
-	python -m pip install --upgrade pip
+	python3 -m pip install --upgrade pip
 	pip install --no-cache-dir -e ".[dev,scrapy]"
 	pre-commit install
 
 build:
-	python -m build
+	python3 -m build
 
 publish:
-	python -m twine upload dist/*
+	python3 -m twine upload dist/*
 
 twine-check:
-	python -m twine check dist/*
+	python3 -m twine check dist/*
 
 lint:
-	python3 -m flake8
+	python3 -m ruff check $(DIRS_WITH_CODE)
 
 unit-tests:
 	python3 -m pytest -n auto -ra tests/unit
 
@@ -30,13 +32,13 @@ integration-tests:
 	python3 -m pytest -n $(INTEGRATION_TESTS_CONCURRENCY) -ra tests/integration
 
 type-check:
-	python3 -m mypy
+	python3 -m mypy $(DIRS_WITH_CODE)
 
 check-code: lint type-check unit-tests
 
 format:
-	python3 -m isort src tests
-	python3 -m autopep8 --in-place --recursive src tests
+	python3 -m ruff check --fix $(DIRS_WITH_CODE)
+	python3 -m ruff format $(DIRS_WITH_CODE)
 
 check-version-availability:
 	python3 scripts/check_version_availability.py
diff --git a/pyproject.toml b/pyproject.toml
index 4415ff39..d2b55f35 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,12 +1,10 @@
 [project]
 name = "apify"
-version = "1.3.1"
+version = "1.4.0"
 description = "Apify SDK for Python"
 readme = "README.md"
-license = {text = "Apache Software License"}
-authors = [
-    {name = "Apify Technologies s.r.o.", email = "support@apify.com"},
-]
+license = { text = "Apache Software License" }
+authors = [{ name = "Apify Technologies s.r.o.", email = "support@apify.com" }]
 keywords = ["apify", "sdk", "actor", "scraping", "automation"]
 
 classifiers = [
@@ -26,8 +24,8 @@ requires-python = ">=3.8"
 dependencies = [
     "aiofiles >= 22.1.0",
     "aioshutil >= 1.0",
-    "apify-client ~= 1.5.0",
-    "apify-shared ~= 1.0.4",
+    "apify-client ~= 1.6.0",
+    "apify-shared ~= 1.1.0",
     "colorama >= 0.4.6",
     "cryptography >= 39.0.0",
     "httpx >= 0.24.1",
@@ -40,25 +38,9 @@ dependencies = [
 
 [project.optional-dependencies]
 dev = [
-    "autopep8 ~= 2.0.4",
     "build ~= 1.0.3",
     "filelock ~= 3.12.4",
-    "flake8 ~= 6.1.0",
-    "flake8-bugbear ~= 23.9.16",
-    "flake8-commas ~= 2.1.0; python_version < '3.12'",
-    "flake8-comprehensions ~= 3.14.0",
-    "flake8-datetimez ~= 20.10.0",
-    "flake8-docstrings ~= 1.7.0",
-    "flake8-encodings ~= 0.5.0",
-    "flake8-isort ~= 6.1.0",
-    "flake8-noqa ~= 1.3.1; python_version < '3.12'",
-    "flake8-pytest-style ~= 1.7.2",
-    "flake8-quotes ~= 3.3.2; python_version < '3.12'",
-    "flake8-simplify ~= 0.21.0",
-    "flake8-unused-arguments ~= 0.0.13",
-    "isort ~= 5.12.0",
-    "mypy ~= 1.5.1",
-    "pep8-naming ~= 0.13.3",
+    "mypy ~= 1.7.1",
     "pre-commit ~= 3.4.0",
     "pydoc-markdown ~= 4.8.2",
     "pytest ~= 7.4.2",
@@ -67,6 +49,7 @@ dev = [
     "pytest-timeout ~= 2.2.0",
     "pytest-xdist ~= 3.3.1",
     "respx ~= 0.20.1",
+    "ruff ~= 0.1.6",
     "twine ~= 4.0.2",
     "types-aiofiles ~= 23.2.0.0",
     "types-colorama ~= 0.4.15.11",
@@ -94,3 +77,72 @@ include = ["apify*"]
 
 [tool.setuptools.package-data]
 apify = ["py.typed"]
+
+[tool.ruff]
+line-length = 150
+select = ["ALL"]
+ignore = [
+    "ANN401",  # Dynamically typed expressions (typing.Any) are disallowed in {filename}
+    "BLE001",  # Do not catch blind exception
+    "C901",    # `{name}` is too complex
+    "COM812",  # This rule may cause conflicts when used with the formatter
+    "D100",    # Missing docstring in public module
+    "D104",    # Missing docstring in public package
+    "EM",      # flake8-errmsg
+    "G004",    # Logging statement uses f-string
+    "ISC001",  # This rule may cause conflicts when used with the formatter
+    "FIX",     # flake8-fixme
+    "PGH003",  # Use specific rule codes when ignoring type issues
+    "PLR0911", # Too many return statements
+    "PLR0913", # Too many arguments in function definition
+    "PLR0915", # Too many statements
+    "PTH",     # flake8-use-pathlib
+    "PYI034",  # `__aenter__` methods in classes like `{name}` usually return `self` at runtime
+    "PYI036",  # The second argument in `__aexit__` should be annotated with `object` or `BaseException | None`
+    "S102",    # Use of `exec` detected
+    "S105",    # Possible hardcoded password assigned to
+    "S106",    # Possible hardcoded password assigned to argument: "{name}"
+    "S301",    # `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
+    "S303",    # Use of insecure MD2, MD4, MD5, or SHA1 hash function
+    "S311",    # Standard pseudo-random generators are not suitable for cryptographic purposes
+    "TD002",   # Missing author in TODO; try: `# TODO(<author_name>): ...` or `# TODO @<author_name>: ...`
+    "TID252",  # Relative imports from parent modules are banned
+    "TRY003",  # Avoid specifying long messages outside the exception class
+
+    # TODO: Remove this once the following issue is fixed
+    # https://github.com/apify/apify-sdk-python/issues/150
+    "SLF001",  # Private member accessed: `{name}`
+]
+
+[tool.ruff.format]
+quote-style = "single"
+indent-style = "space"
+
+[tool.ruff.lint.per-file-ignores]
+"**/__init__.py" = [
+    "F401", # Unused imports
+]
+"**/{scripts}/*" = [
+    "D",       # Everything from the pydocstyle
+    "INP001",  # File {filename} is part of an implicit namespace package, add an __init__.py
+    "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable
+    "T20",     # flake8-print
+]
+"**/{tests}/*" = [
+    "D",       # Everything from the pydocstyle
+    "INP001",  # File {filename} is part of an implicit namespace package, add an __init__.py
+    "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable
+    "S101",    # Use of assert detected
+    "T20",     # flake8-print
+    "TRY301",  # Abstract `raise` to an inner function
+]
+
+[tool.ruff.lint.flake8-quotes]
+docstring-quotes = "double"
+inline-quotes = "single"
+
+[tool.ruff.lint.isort]
+known-first-party = ["apify", "apify_client", "apify_shared"]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
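To illustrate the configuration above: docstrings keep double quotes while inline strings use single quotes, docstrings follow the Google convention, and rules not globally ignored are suppressed per call site with `# noqa`, as done throughout this patch for A002 (an argument shadowing a builtin). The function below is a made-up example, not code from the repository:

```python
def get_item(id: str) -> dict:  # noqa: A002  # `id` deliberately shadows the builtin
    """Look up an item by its ID.

    Args:
        id (str): The identifier of the item.

    Returns:
        dict: The matching item.
    """
    return {'id': id}
```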
diff --git a/scripts/check_version_availability.py b/scripts/check_version_availability.py
old mode 100644
new mode 100755
index b133a9ae..372d3215
--- a/scripts/check_version_availability.py
+++ b/scripts/check_version_availability.py
@@ -1,4 +1,7 @@
 #!/usr/bin/env python3
+
+from __future__ import annotations
+
 from utils import get_current_package_version, get_published_package_versions
 
 # Checks whether the current package version number was not already used in a published release.
diff --git a/scripts/check_version_in_changelog.py b/scripts/check_version_in_changelog.py
old mode 100644
new mode 100755
index f72bee87..87aed57c
--- a/scripts/check_version_in_changelog.py
+++ b/scripts/check_version_in_changelog.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+from __future__ import annotations
+
 import re
 
 from utils import REPO_ROOT, get_current_package_version
@@ -16,7 +18,7 @@
 with open(CHANGELOG_PATH, encoding='utf-8') as changelog_file:
     for line in changelog_file:
         # The heading for the changelog entry for the given version can start with either the version number, or the version number in a link
-        if re.match(fr'\[?{current_package_version}([\] ]|$)', line):
+        if re.match(rf'\[?{current_package_version}([\] ]|$)', line):
             break
     else:
         raise RuntimeError(f'There is no entry in the changelog for the current package version ({current_package_version})')
diff --git a/scripts/print_current_package_version.py b/scripts/print_current_package_version.py
old mode 100644
new mode 100755
index 3c78a5b6..9cff474b
--- a/scripts/print_current_package_version.py
+++ b/scripts/print_current_package_version.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+from __future__ import annotations
+
 from utils import get_current_package_version
 
 # Print the current package version from the pyproject.toml file to stdout
diff --git a/scripts/update_version_for_prerelease.py b/scripts/update_version_for_prerelease.py
old mode 100644
new mode 100755
index 51ddf655..ee1882f4
--- a/scripts/update_version_for_prerelease.py
+++ b/scripts/update_version_for_prerelease.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+from __future__ import annotations
+
 import re
 import sys
 
diff --git a/scripts/utils.py b/scripts/utils.py
index 06a3ace7..ec575112 100644
--- a/scripts/utils.py
+++ b/scripts/utils.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
 import json
 import pathlib
-import urllib.request
+from urllib.error import HTTPError
+from urllib.request import urlopen
 
 PACKAGE_NAME = 'apify'
 REPO_ROOT = pathlib.Path(__file__).parent.resolve() / '..'
@@ -10,13 +13,12 @@
 # Load the current version number from pyproject.toml
 # It is on a line in the format `version = "1.2.3"`
 def get_current_package_version() -> str:
-    with open(PYPROJECT_TOML_FILE_PATH, 'r', encoding='utf-8') as pyproject_toml_file:
+    with open(PYPROJECT_TOML_FILE_PATH, encoding='utf-8') as pyproject_toml_file:
         for line in pyproject_toml_file:
             if line.startswith('version = '):
                 delim = '"' if '"' in line else "'"
-                version = line.split(delim)[1]
-                return version
-        else:
+                return line.split(delim)[1]
+        else:  # noqa: PLW0120
             raise RuntimeError('Unable to find version string.')
 
 
@@ -29,7 +31,7 @@ def set_current_package_version(version: str) -> None:
         for line in pyproject_toml_file:
             if line.startswith('version = '):
                 version_string_found = True
-                line = f'version = "{version}"\n'
+                line = f'version = "{version}"\n'  # noqa: PLW2901
 
             updated_pyproject_toml_file_lines.append(line)
 
     if not version_string_found:
@@ -44,11 +46,11 @@ def set_current_package_version(version: str) -> None:
 def get_published_package_versions() -> list:
     package_info_url = f'https://pypi.org/pypi/{PACKAGE_NAME}/json'
     try:
-        package_data = json.load(urllib.request.urlopen(package_info_url))
+        package_data = json.load(urlopen(package_info_url))  # noqa: S310
         published_versions = list(package_data['releases'].keys())
     # If the URL returns 404, it means the package has no releases yet (which is okay in our case)
-    except urllib.error.HTTPError as e:
-        if e.code != 404:
-            raise e
+    except HTTPError as exc:
+        if exc.code != 404:
+            raise
        published_versions = []
    return published_versions
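The body of `scripts/check_version_availability.py` is not shown in its hunk above; judging from its comment and the helpers in `scripts/utils.py`, it presumably looks something like this sketch (hypothetical, not the actual script):

```python
#!/usr/bin/env python3

from __future__ import annotations

from utils import get_current_package_version, get_published_package_versions

# Checks whether the current package version number was not already used in a published release.
current_version = get_current_package_version()
published_versions = get_published_package_versions()

if current_version in published_versions:
    raise RuntimeError(f'Version {current_version} has already been published!')
```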
diff --git a/src/apify/_crypto.py b/src/apify/_crypto.py
index 12f5798d..5108a912 100644
--- a/src/apify/_crypto.py
+++ b/src/apify/_crypto.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import base64
 import secrets
 from typing import Any
@@ -89,7 +91,7 @@ def private_decrypt(
 
     # Slice the encrypted into cypher and authentication tag
     authentication_tag_bytes = encrypted_value_bytes[-ENCRYPTION_AUTH_TAG_LENGTH:]
-    encrypted_data_bytes = encrypted_value_bytes[:len(encrypted_value_bytes) - ENCRYPTION_AUTH_TAG_LENGTH]
+    encrypted_data_bytes = encrypted_value_bytes[: (len(encrypted_value_bytes) - ENCRYPTION_AUTH_TAG_LENGTH)]
     encryption_key_bytes = password_bytes[:ENCRYPTION_KEY_LENGTH]
     initialization_vector_bytes = password_bytes[ENCRYPTION_KEY_LENGTH:]
 
@@ -97,19 +99,21 @@ def private_decrypt(
         cipher = Cipher(algorithms.AES(encryption_key_bytes), modes.GCM(initialization_vector_bytes, authentication_tag_bytes))
         decryptor = cipher.decryptor()
         decipher_bytes = decryptor.update(encrypted_data_bytes) + decryptor.finalize()
-    except InvalidTagException:
-        raise ValueError('Decryption failed, malformed encrypted value or password.')
-    except Exception as err:
-        raise err
+    except InvalidTagException as exc:
+        raise ValueError('Decryption failed, malformed encrypted value or password.') from exc
+    except Exception:
+        raise
 
     return decipher_bytes.decode('utf-8')
 
 
-def _load_private_key(private_key_file_base64: str, private_key_password: str) -> rsa.RSAPrivateKey:
-    private_key = serialization.load_pem_private_key(base64.b64decode(
-        private_key_file_base64.encode('utf-8')), password=private_key_password.encode('utf-8'))
+def load_private_key(private_key_file_base64: str, private_key_password: str) -> rsa.RSAPrivateKey:
+    private_key = serialization.load_pem_private_key(
+        base64.b64decode(private_key_file_base64.encode('utf-8')),
+        password=private_key_password.encode('utf-8'),
+    )
     if not isinstance(private_key, rsa.RSAPrivateKey):
-        raise ValueError('Invalid private key.')
+        raise TypeError('Invalid private key.')
 
     return private_key
 
@@ -117,7 +121,7 @@ def _load_private_key(private_key_file_base64: str, private_key_password: str) -
 def _load_public_key(public_key_file_base64: str) -> rsa.RSAPublicKey:
     public_key = serialization.load_pem_public_key(base64.b64decode(public_key_file_base64.encode('utf-8')))
     if not isinstance(public_key, rsa.RSAPublicKey):
-        raise ValueError('Invalid public key.')
+        raise TypeError('Invalid public key.')
 
     return public_key
 
@@ -128,7 +132,7 @@ def crypto_random_object_id(length: int = 17) -> str:
     return ''.join(secrets.choice(chars) for _ in range(length))
 
 
-def _decrypt_input_secrets(private_key: rsa.RSAPrivateKey, input: Any) -> Any:
+def decrypt_input_secrets(private_key: rsa.RSAPrivateKey, input: Any) -> Any:  # noqa: A002
     """Decrypt input secrets."""
     if not isinstance(input, dict):
         return input
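A round-trip check of the renamed `load_private_key` helper, using the `cryptography` package to generate a throwaway key. The key generation and password are illustrative; only the helper's signature comes from the diff above:

```python
import base64

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

from apify._crypto import load_private_key

# Generate a throwaway RSA key and serialize it as a password-protected PEM,
# which is the format load_private_key() expects (base64-encoded on top).
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
pem = key.private_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PrivateFormat.PKCS8,
    encryption_algorithm=serialization.BestAvailableEncryption(b'my-password'),
)

reloaded = load_private_key(base64.b64encode(pem).decode('utf-8'), 'my-password')
assert reloaded.key_size == 2048
```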
diff --git a/src/apify/_memory_storage/file_storage_utils.py b/src/apify/_memory_storage/file_storage_utils.py
index f3764a9f..c4591ba8 100644
--- a/src/apify/_memory_storage/file_storage_utils.py
+++ b/src/apify/_memory_storage/file_storage_utils.py
@@ -1,15 +1,16 @@
+from __future__ import annotations
+
 import os
-from typing import Dict, List, Tuple
 
 import aiofiles
 from aiofiles.os import makedirs
 
 from apify_shared.utils import json_dumps
 
-from .._utils import _force_remove
+from .._utils import force_remove
 
 
-async def _update_metadata(*, data: Dict, entity_directory: str, write_metadata: bool) -> None:
+async def update_metadata(*, data: dict, entity_directory: str, write_metadata: bool) -> None:
     # Skip writing the actual metadata file. This is done after ensuring the directory exists so we have the directory present
     if not write_metadata:
         return
@@ -25,7 +26,7 @@ async def _update_metadata(*, data: Dict, entity_directory: str, write_metadata:
 
 async def _update_dataset_items(
     *,
-    data: List[Tuple[str, Dict]],
+    data: list[tuple[str, dict]],
     entity_directory: str,
     persist_storage: bool,
 ) -> None:
@@ -43,10 +44,10 @@ async def _update_dataset_items(
             await f.write(json_dumps(item).encode('utf-8'))
 
 
-async def _update_request_queue_item(
+async def update_request_queue_item(
     *,
     request_id: str,
-    request: Dict,
+    request: dict,
     entity_directory: str,
     persist_storage: bool,
 ) -> None:
@@ -63,9 +64,9 @@ async def _update_request_queue_item(
         await f.write(json_dumps(request).encode('utf-8'))
 
 
-async def _delete_request(*, request_id: str, entity_directory: str) -> None:
+async def delete_request(*, request_id: str, entity_directory: str) -> None:
     # Ensure the directory for the entity exists
     await makedirs(entity_directory, exist_ok=True)
 
     file_path = os.path.join(entity_directory, f'{request_id}.json')
-    await _force_remove(file_path)
+    await force_remove(file_path)
diff --git a/src/apify/_memory_storage/memory_storage_client.py b/src/apify/_memory_storage/memory_storage_client.py
index ff2c7f4d..3c42cdad 100644
--- a/src/apify/_memory_storage/memory_storage_client.py
+++ b/src/apify/_memory_storage/memory_storage_client.py
@@ -1,8 +1,9 @@
+from __future__ import annotations
+
 import asyncio
 import contextlib
 import os
 from pathlib import Path
-from typing import List, Optional
 
 import aioshutil
 from aiofiles import ospath
@@ -11,7 +12,7 @@
 from apify_shared.consts import ApifyEnvVars
 from apify_shared.utils import ignore_docs
 
-from .._utils import _maybe_parse_bool
+from .._utils import maybe_parse_bool
 from .resource_clients.dataset import DatasetClient
 from .resource_clients.dataset_collection import DatasetCollectionClient
 from .resource_clients.key_value_store import KeyValueStoreClient
@@ -37,9 +38,9 @@ class MemoryStorageClient:
     _request_queues_directory: str
     _write_metadata: bool
     _persist_storage: bool
-    _datasets_handled: List[DatasetClient]
-    _key_value_stores_handled: List[KeyValueStoreClient]
-    _request_queues_handled: List[RequestQueueClient]
+    _datasets_handled: list[DatasetClient]
+    _key_value_stores_handled: list[KeyValueStoreClient]
+    _request_queues_handled: list[RequestQueueClient]
 
     _purged_on_start: bool = False
     _purge_lock: asyncio.Lock
@@ -47,7 +48,11 @@ class MemoryStorageClient:
     """Indicates whether a purge was already performed on this instance"""
 
     def __init__(
-        self, *, local_data_directory: Optional[str] = None, write_metadata: Optional[bool] = None, persist_storage: Optional[bool] = None,
+        self: MemoryStorageClient,
+        *,
+        local_data_directory: str | None = None,
+        write_metadata: bool | None = None,
+        persist_storage: bool | None = None,
     ) -> None:
         """Initialize the MemoryStorageClient.
 
@@ -61,17 +66,17 @@ def __init__(
         self._key_value_stores_directory = os.path.join(self._local_data_directory, 'key_value_stores')
         self._request_queues_directory = os.path.join(self._local_data_directory, 'request_queues')
         self._write_metadata = write_metadata if write_metadata is not None else '*' in os.getenv('DEBUG', '')
-        self._persist_storage = persist_storage if persist_storage is not None else _maybe_parse_bool(os.getenv(ApifyEnvVars.PERSIST_STORAGE, 'true'))
+        self._persist_storage = persist_storage if persist_storage is not None else maybe_parse_bool(os.getenv(ApifyEnvVars.PERSIST_STORAGE, 'true'))
         self._datasets_handled = []
         self._key_value_stores_handled = []
         self._request_queues_handled = []
         self._purge_lock = asyncio.Lock()
 
-    def datasets(self) -> DatasetCollectionClient:
+    def datasets(self: MemoryStorageClient) -> DatasetCollectionClient:
         """Retrieve the sub-client for manipulating datasets."""
         return DatasetCollectionClient(base_storage_directory=self._datasets_directory, memory_storage_client=self)
 
-    def dataset(self, dataset_id: str) -> DatasetClient:
+    def dataset(self: MemoryStorageClient, dataset_id: str) -> DatasetClient:
         """Retrieve the sub-client for manipulating a single dataset.
 
         Args:
@@ -79,11 +84,11 @@ def dataset(self, dataset_id: str) -> DatasetClient:
         """
         return DatasetClient(base_storage_directory=self._datasets_directory, memory_storage_client=self, id=dataset_id)
 
-    def key_value_stores(self) -> KeyValueStoreCollectionClient:
+    def key_value_stores(self: MemoryStorageClient) -> KeyValueStoreCollectionClient:
         """Retrieve the sub-client for manipulating key-value stores."""
         return KeyValueStoreCollectionClient(base_storage_directory=self._key_value_stores_directory, memory_storage_client=self)
 
-    def key_value_store(self, key_value_store_id: str) -> KeyValueStoreClient:
+    def key_value_store(self: MemoryStorageClient, key_value_store_id: str) -> KeyValueStoreClient:
         """Retrieve the sub-client for manipulating a single key-value store.
 
         Args:
@@ -91,11 +96,16 @@ def key_value_store(self, key_value_store_id: str) -> KeyValueStoreClient:
         """
         return KeyValueStoreClient(base_storage_directory=self._key_value_stores_directory, memory_storage_client=self, id=key_value_store_id)
 
-    def request_queues(self) -> RequestQueueCollectionClient:
+    def request_queues(self: MemoryStorageClient) -> RequestQueueCollectionClient:
         """Retrieve the sub-client for manipulating request queues."""
         return RequestQueueCollectionClient(base_storage_directory=self._request_queues_directory, memory_storage_client=self)
 
-    def request_queue(self, request_queue_id: str, *, client_key: Optional[str] = None) -> RequestQueueClient:  # noqa: U100
+    def request_queue(
+        self: MemoryStorageClient,
+        request_queue_id: str,
+        *,
+        client_key: str | None = None,  # noqa: ARG002
+    ) -> RequestQueueClient:
         """Retrieve the sub-client for manipulating a single request queue.
 
         Args:
@@ -104,7 +114,7 @@ def request_queue(self, request_queue_id: str, *, client_key: Optional[str] = No
         """
         return RequestQueueClient(base_storage_directory=self._request_queues_directory, memory_storage_client=self, id=request_queue_id)
 
-    async def _purge_on_start(self) -> None:
+    async def _purge_on_start(self: MemoryStorageClient) -> None:
         # Optimistic, non-blocking check
         if self._purged_on_start is True:
             return
@@ -117,7 +127,7 @@ async def _purge_on_start(self) -> None:
             await self._purge()
             self._purged_on_start = True
 
-    async def _purge(self) -> None:
+    async def _purge(self: MemoryStorageClient) -> None:
        """Clean up the default storage directories before the run starts.
 
         Specifically, `purge` cleans up:
@@ -147,7 +157,7 @@ async def _purge(self) -> None:
                 if request_queue_folder.name == 'default' or request_queue_folder.name.startswith('__APIFY_TEMPORARY'):
                     await self._batch_remove_files(request_queue_folder.path)
 
-    async def _handle_default_key_value_store(self, folder: str) -> None:
+    async def _handle_default_key_value_store(self: MemoryStorageClient, folder: str) -> None:
         """Remove everything from the default key-value store folder except `possible_input_keys`."""
         folder_exists = await ospath.exists(folder)
         temporary_path = os.path.normpath(os.path.join(folder, '../__APIFY_MIGRATING_KEY_VALUE_STORE__'))
@@ -175,13 +185,13 @@ async def _handle_default_key_value_store(self, folder: str) -> None:
             counter = 0
             temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__'))
             done = False
-            while not done:
-                try:
+            try:
+                while not done:
                     await rename(folder, temp_path_for_old_folder)
                     done = True
-                except Exception:
-                    counter += 1
-                    temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__'))
+            except Exception:
+                counter += 1
+                temp_path_for_old_folder = os.path.normpath(os.path.join(folder, f'../__OLD_DEFAULT_{counter}__'))
 
             # Replace the temporary folder with the original folder
             await rename(temporary_path, folder)
@@ -189,12 +199,15 @@ async def _handle_default_key_value_store(self, folder: str) -> None:
         # Remove the old folder
         await self._batch_remove_files(temp_path_for_old_folder)
 
-    async def _batch_remove_files(self, folder: str, counter: int = 0) -> None:
+    async def _batch_remove_files(self: MemoryStorageClient, folder: str, counter: int = 0) -> None:
         folder_exists = await ospath.exists(folder)
 
         if folder_exists:
-            temporary_folder = folder if os.path.basename(folder).startswith('__APIFY_TEMPORARY_') else os.path.normpath(
-                os.path.join(folder, f'../__APIFY_TEMPORARY_{counter}__'))
+            temporary_folder = (
+                folder
+                if os.path.basename(folder).startswith('__APIFY_TEMPORARY_')
+                else os.path.normpath(os.path.join(folder, f'../__APIFY_TEMPORARY_{counter}__'))
+            )
 
             try:
                 # Rename the old folder to the new one to allow background deletions
@@ -204,3 +217,4 @@ async def _batch_remove_files(self, folder: str, counter: int = 0) -> None:
                 return await self._batch_remove_files(folder, counter + 1)
 
             await aioshutil.rmtree(temporary_folder, ignore_errors=True)
+        return None
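A hypothetical usage sketch of the client above. `MemoryStorageClient` lives in a private module (`apify._memory_storage`), so this is for illustration only; it mirrors the constructor and accessors shown in the diff:

```python
import asyncio

from apify._memory_storage.memory_storage_client import MemoryStorageClient


async def main() -> None:
    client = MemoryStorageClient(
        local_data_directory='./storage',  # illustrative path
        write_metadata=True,
        persist_storage=True,
    )
    # Collection clients create or look up storages; resource clients manipulate one storage.
    store_info = await client.key_value_stores().get_or_create(name='my-store')
    store = client.key_value_store(store_info['id'])
    keys = await store.list_keys()
    print(keys['items'])


asyncio.run(main())
```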
diff --git a/src/apify/_memory_storage/resource_clients/__init__.py b/src/apify/_memory_storage/resource_clients/__init__.py
index b0cb4dfe..0a79ebe3 100644
--- a/src/apify/_memory_storage/resource_clients/__init__.py
+++ b/src/apify/_memory_storage/resource_clients/__init__.py
@@ -7,6 +7,13 @@
 from .request_queue import RequestQueueClient
 from .request_queue_collection import RequestQueueCollectionClient
 
-__all__ = ['BaseResourceClient', 'BaseResourceCollectionClient',
-           'DatasetClient', 'DatasetCollectionClient', 'KeyValueStoreClient',
-           'KeyValueStoreCollectionClient', 'RequestQueueClient', 'RequestQueueCollectionClient']
+__all__ = [
+    'BaseResourceClient',
+    'BaseResourceCollectionClient',
+    'DatasetClient',
+    'DatasetCollectionClient',
+    'KeyValueStoreClient',
+    'KeyValueStoreCollectionClient',
+    'RequestQueueClient',
+    'RequestQueueCollectionClient',
+]
diff --git a/src/apify/_memory_storage/resource_clients/base_resource_client.py b/src/apify/_memory_storage/resource_clients/base_resource_client.py
index d555330c..d68b0584 100644
--- a/src/apify/_memory_storage/resource_clients/base_resource_client.py
+++ b/src/apify/_memory_storage/resource_clients/base_resource_client.py
@@ -1,9 +1,9 @@
+from __future__ import annotations
+
 import json
 import os
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Dict, List, Optional
-
-from typing_extensions import Self
+from typing import TYPE_CHECKING
 
 from apify_shared.utils import ignore_docs
 
@@ -16,23 +16,23 @@ class BaseResourceClient(ABC):
     """Base class for resource clients."""
 
     _id: str
-    _name: Optional[str]
+    _name: str | None
     _resource_directory: str
 
     @abstractmethod
     def __init__(
-        self,
+        self: BaseResourceClient,
         *,
         base_storage_directory: str,
-        memory_storage_client: 'MemoryStorageClient',
-        id: Optional[str] = None,
-        name: Optional[str] = None,
+        memory_storage_client: MemoryStorageClient,
+        id: str | None = None,  # noqa: A002
+        name: str | None = None,
     ) -> None:
         """Initialize the BaseResourceClient."""
         raise NotImplementedError('You must override this method in the subclass!')
 
     @abstractmethod
-    async def get(self) -> Optional[Dict]:
+    async def get(self: BaseResourceClient) -> dict | None:
         """Retrieve the storage.
 
         Returns:
@@ -42,44 +42,53 @@ async def get(self) -> Optional[Dict]:
 
     @classmethod
     @abstractmethod
-    def _get_storages_dir(cls, memory_storage_client: 'MemoryStorageClient') -> str:
+    def _get_storages_dir(cls: type[BaseResourceClient], memory_storage_client: MemoryStorageClient) -> str:
         raise NotImplementedError('You must override this method in the subclass!')
 
     @classmethod
     @abstractmethod
-    def _get_storage_client_cache(cls, memory_storage_client: 'MemoryStorageClient') -> List[Self]:
+    def _get_storage_client_cache(
+        cls: type[BaseResourceClient],
+        memory_storage_client: MemoryStorageClient,
+    ) -> list[BaseResourceClient]:
         raise NotImplementedError('You must override this method in the subclass!')
 
     @abstractmethod
-    def _to_resource_info(self) -> Dict:
+    def _to_resource_info(self: BaseResourceClient) -> dict:
         raise NotImplementedError('You must override this method in the subclass!')
 
     @classmethod
     @abstractmethod
     def _create_from_directory(
-        cls,
+        cls: type[BaseResourceClient],
         storage_directory: str,
-        memory_storage_client: 'MemoryStorageClient',
-        id: Optional[str] = None,
-        name: Optional[str] = None,
-    ) -> Self:
+        memory_storage_client: MemoryStorageClient,
+        id: str | None = None,  # noqa: A002
+        name: str | None = None,
+    ) -> BaseResourceClient:
         raise NotImplementedError('You must override this method in the subclass!')
 
     @classmethod
     def _find_or_create_client_by_id_or_name(
-        cls,
-        memory_storage_client: 'MemoryStorageClient',
-        id: Optional[str] = None,
-        name: Optional[str] = None,
-    ) -> Optional[Self]:
-        assert id is not None or name is not None
+        cls: type[BaseResourceClient],
+        memory_storage_client: MemoryStorageClient,
+        id: str | None = None,  # noqa: A002
+        name: str | None = None,
+    ) -> BaseResourceClient | None:
+        assert id is not None or name is not None  # noqa: S101
 
         storage_client_cache = cls._get_storage_client_cache(memory_storage_client)
         storages_dir = cls._get_storages_dir(memory_storage_client)
 
         # First check memory cache
-        found = next((storage_client for storage_client in storage_client_cache
-                      if storage_client._id == id or (storage_client._name and name and storage_client._name.lower() == name.lower())), None)
+        found = next(
+            (
+                storage_client
+                for storage_client in storage_client_cache
+                if storage_client._id == id or (storage_client._name and name and storage_client._name.lower() == name.lower())
+            ),
+            None,
+        )
 
         if found is not None:
             return found
@@ -108,12 +117,13 @@ def _find_or_create_client_by_id_or_name(
                         break
                     if name and name == metadata.get('name'):
                         storage_path = entry.path
-                        id = metadata.get(id)
+                        id = metadata.get(id)  # noqa: A001
                         break
 
         # As a last resort, try to check if the accessed storage is the default one,
         # and the folder has no metadata
         # TODO: make this respect the APIFY_DEFAULT_XXX_ID env var
+        # https://github.com/apify/apify-sdk-python/issues/149
        if id == 'default':
            possible_storage_path = os.path.join(storages_dir, id)
            if os.access(possible_storage_path, os.F_OK):
diff --git a/src/apify/_memory_storage/resource_clients/base_resource_collection_client.py b/src/apify/_memory_storage/resource_clients/base_resource_collection_client.py
index 8c80e849..2b003e52 100644
--- a/src/apify/_memory_storage/resource_clients/base_resource_collection_client.py
+++ b/src/apify/_memory_storage/resource_clients/base_resource_collection_client.py
@@ -1,18 +1,20 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from operator import itemgetter
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Type, TypeVar
+from typing import TYPE_CHECKING, Generic, TypeVar, cast
 
 from apify_shared.models import ListPage
 from apify_shared.utils import ignore_docs
 
-from ..file_storage_utils import _update_metadata
+from ..file_storage_utils import update_metadata
 from .base_resource_client import BaseResourceClient
 
 if TYPE_CHECKING:
     from ..memory_storage_client import MemoryStorageClient
 
-ResourceClientType = TypeVar('ResourceClientType', bound=BaseResourceClient, contravariant=True)
+ResourceClientType = TypeVar('ResourceClientType', bound=BaseResourceClient, contravariant=True)  # noqa: PLC0105
 
 
 @ignore_docs
@@ -20,23 +22,28 @@ class BaseResourceCollectionClient(ABC, Generic[ResourceClientType]):
     """Base class for resource collection clients."""
 
     _base_storage_directory: str
-    _memory_storage_client: 'MemoryStorageClient'
+    _memory_storage_client: MemoryStorageClient
 
-    def __init__(self, *, base_storage_directory: str, memory_storage_client: 'MemoryStorageClient') -> None:
+    def __init__(
+        self: BaseResourceCollectionClient,
+        *,
+        base_storage_directory: str,
+        memory_storage_client: MemoryStorageClient,
+    ) -> None:
         """Initialize the DatasetCollectionClient with the passed arguments."""
         self._base_storage_directory = base_storage_directory
         self._memory_storage_client = memory_storage_client
 
     @abstractmethod
-    def _get_storage_client_cache(self) -> List[ResourceClientType]:
+    def _get_storage_client_cache(self: BaseResourceCollectionClient) -> list[ResourceClientType]:
         raise NotImplementedError('You must override this method in the subclass!')
 
     @abstractmethod
-    def _get_resource_client_class(self) -> Type[ResourceClientType]:
+    def _get_resource_client_class(self: BaseResourceCollectionClient) -> type[ResourceClientType]:
         raise NotImplementedError('You must override this method in the subclass!')
 
     @abstractmethod
-    async def list(self) -> ListPage:
+    async def list(self: BaseResourceCollectionClient) -> ListPage:  # noqa: A003
         """List the available storages.
 
         Returns:
@@ -46,23 +53,25 @@ async def list(self) -> ListPage:
 
         items = [storage._to_resource_info() for storage in storage_client_cache]
 
-        return ListPage({
-            'total': len(items),
-            'count': len(items),
-            'offset': 0,
-            'limit': len(items),
-            'desc': False,
-            'items': sorted(items, key=itemgetter('createdAt')),
-        })
+        return ListPage(
+            {
+                'total': len(items),
+                'count': len(items),
+                'offset': 0,
+                'limit': len(items),
+                'desc': False,
+                'items': sorted(items, key=itemgetter('createdAt')),
+            }
+        )
 
     @abstractmethod
     async def get_or_create(
-        self,
+        self: BaseResourceCollectionClient,
         *,
-        name: Optional[str] = None,
-        schema: Optional[Dict] = None,  # noqa: U100
-        _id: Optional[str] = None,
-    ) -> Dict:
+        name: str | None = None,
+        schema: dict | None = None,
+        _id: str | None = None,
+    ) -> dict:
         """Retrieve a named storage, or create a new one when it doesn't exist.
 
         Args:
@@ -76,9 +85,14 @@ async def get_or_create(
         storage_client_cache = self._get_storage_client_cache()
 
         if name or _id:
-            found = resource_client_class._find_or_create_client_by_id_or_name(memory_storage_client=self._memory_storage_client, name=name, id=_id)
+            found = resource_client_class._find_or_create_client_by_id_or_name(
+                memory_storage_client=self._memory_storage_client,
+                name=name,
+                id=_id,
+            )
             if found:
-                return found._to_resource_info()
+                resource_info = found._to_resource_info()
+                return cast(dict, resource_info)
 
         new_resource = resource_client_class(
             id=_id,
@@ -91,10 +105,10 @@ async def get_or_create(
         resource_info = new_resource._to_resource_info()
 
         # Write to the disk
-        await _update_metadata(
+        await update_metadata(
             data=resource_info,
             entity_directory=new_resource._resource_directory,
             write_metadata=self._memory_storage_client._write_metadata,
         )
 
-        return resource_info
+        return cast(dict, resource_info)
diff --git a/src/apify/_memory_storage/resource_clients/dataset.py b/src/apify/_memory_storage/resource_clients/dataset.py
index ed21c0d3..f0c9d119 100644
--- a/src/apify/_memory_storage/resource_clients/dataset.py
+++ b/src/apify/_memory_storage/resource_clients/dataset.py
@@ -1,22 +1,25 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import os
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, AsyncIterator
 
 import aioshutil
 
 from apify_shared.models import ListPage
-from apify_shared.types import JSONSerializable
 from apify_shared.utils import ignore_docs
 
 from ..._crypto import crypto_random_object_id
-from ..._utils import _force_rename, _raise_on_duplicate_storage, _raise_on_non_existing_storage
-from ...consts import _StorageTypes
-from ..file_storage_utils import _update_dataset_items, _update_metadata
+from ..._utils import force_rename, raise_on_duplicate_storage, raise_on_non_existing_storage
+from ...consts import StorageTypes
+from ..file_storage_utils import _update_dataset_items, update_metadata
 from .base_resource_client import BaseResourceClient
 
 if TYPE_CHECKING:
+    from apify_shared.types import JSONSerializable
+
     from ..memory_storage_client import MemoryStorageClient
 
 # This is what API returns in the x-apify-pagination-limit
@@ -34,9 +37,9 @@ class DatasetClient(BaseResourceClient):
 
     _id: str
     _resource_directory: str
-    _memory_storage_client: 'MemoryStorageClient'
-    _name: Optional[str]
-    _dataset_entries: Dict[str, Dict]
+    _memory_storage_client: MemoryStorageClient
+    _name: str | None
+    _dataset_entries: dict[str, dict]
     _created_at: datetime
     _accessed_at: datetime
     _modified_at: datetime
@@ -44,12 +47,12 @@ class DatasetClient(BaseResourceClient):
     _file_operation_lock: asyncio.Lock
 
     def __init__(
-        self,
+        self: DatasetClient,
         *,
         base_storage_directory: str,
-        memory_storage_client: 'MemoryStorageClient',
-        id: Optional[str] = None,
-        name: Optional[str] = None,
+        memory_storage_client: MemoryStorageClient,
+        id: str | None = None,  # noqa: A002
+        name: str | None = None,
     ) -> None:
         """Initialize the DatasetClient."""
         self._id = id or crypto_random_object_id()
@@ -62,7 +65,7 @@ def __init__(
         self._modified_at = datetime.now(timezone.utc)
         self._file_operation_lock = asyncio.Lock()
 
-    async def get(self) -> Optional[Dict]:
+    async def get(self: DatasetClient) -> dict | None:
         """Retrieve the dataset.
 
         Returns:
@@ -71,13 +74,13 @@ async def get(self) -> Optional[Dict]:
         found = self._find_or_create_client_by_id_or_name(memory_storage_client=self._memory_storage_client, id=self._id, name=self._name)
 
         if found:
-            async with found._file_operation_lock:
-                await found._update_timestamps(False)
+            async with found._file_operation_lock:  # type: ignore
+                await found._update_timestamps(has_been_modified=False)  # type: ignore
                 return found._to_resource_info()
 
         return None
 
-    async def update(self, *, name: Optional[str] = None) -> Dict:
+    async def update(self: DatasetClient, *, name: str | None = None) -> dict:
         """Update the dataset with specified fields.
 
         Args:
@@ -94,13 +97,13 @@ async def update(self, *, name: Optional[str] = None) -> Dict:
         )
 
         if existing_dataset_by_id is None:
-            _raise_on_non_existing_storage(_StorageTypes.DATASET, self._id)
+            raise_on_non_existing_storage(StorageTypes.DATASET, self._id)
 
         # Skip if no changes
         if name is None:
             return existing_dataset_by_id._to_resource_info()
 
-        async with existing_dataset_by_id._file_operation_lock:
+        async with existing_dataset_by_id._file_operation_lock:  # type: ignore
             # Check that name is not in use already
             existing_dataset_by_name = next(
                 (dataset for dataset in self._memory_storage_client._datasets_handled if dataset._name and dataset._name.lower() == name.lower()),
@@ -108,7 +111,7 @@ async def update(self, *, name: Optional[str] = None) -> Dict:
             )
 
             if existing_dataset_by_name is not None:
-                _raise_on_duplicate_storage(_StorageTypes.DATASET, 'name', name)
+                raise_on_duplicate_storage(StorageTypes.DATASET, 'name', name)
 
             existing_dataset_by_id._name = name
 
@@ -116,14 +119,14 @@ async def update(self, *, name: Optional[str] = None) -> Dict:
 
             existing_dataset_by_id._resource_directory = os.path.join(self._memory_storage_client._datasets_directory, name)
 
-            await _force_rename(previous_dir, existing_dataset_by_id._resource_directory)
+            await force_rename(previous_dir, existing_dataset_by_id._resource_directory)
 
             # Update timestamps
-            await existing_dataset_by_id._update_timestamps(True)
+            await existing_dataset_by_id._update_timestamps(has_been_modified=True)  # type: ignore
 
         return existing_dataset_by_id._to_resource_info()
 
-    async def delete(self) -> None:
+    async def delete(self: DatasetClient) -> None:
         """Delete the dataset."""
         dataset = next((dataset for dataset in self._memory_storage_client._datasets_handled if dataset._id == self._id), None)
 
@@ -137,19 +140,19 @@ async def delete(self) -> None:
             await aioshutil.rmtree(dataset._resource_directory)
 
     async def list_items(
-        self,
+        self: DatasetClient,
         *,
-        offset: Optional[int] = 0,
-        limit: Optional[int] = LIST_ITEMS_LIMIT,
-        clean: Optional[bool] = None,  # noqa: U100
-        desc: Optional[bool] = None,
-        fields: Optional[List[str]] = None,  # noqa: U100
-        omit: Optional[List[str]] = None,  # noqa: U100
-        unwind: Optional[str] = None,  # noqa: U100
-        skip_empty: Optional[bool] = None,  # noqa: U100
-        skip_hidden: Optional[bool] = None,  # noqa: U100
-        flatten: Optional[List[str]] = None,  # noqa: U100
-        view: Optional[str] = None,  # noqa: U100
+        offset: int | None = 0,
+        limit: int | None = LIST_ITEMS_LIMIT,
+        clean: bool | None = None,  # noqa: ARG002
+        desc: bool | None = None,
+        fields: list[str] | None = None,  # noqa: ARG002
+        omit: list[str] | None = None,  # noqa: ARG002
+        unwind: str | None = None,  # noqa: ARG002
+        skip_empty: bool | None = None,  # noqa: ARG002
+        skip_hidden: bool | None = None,  # noqa: ARG002
+        flatten: list[str] | None = None,  # noqa: ARG002
+        view: str | None = None,  # noqa: ARG002
     ) -> ListPage:
         """List the items of the dataset.
 
@@ -188,11 +191,11 @@ async def list_items(
         )
 
         if existing_dataset_by_id is None:
-            _raise_on_non_existing_storage(_StorageTypes.DATASET, self._id)
+            raise_on_non_existing_storage(StorageTypes.DATASET, self._id)
 
-        async with existing_dataset_by_id._file_operation_lock:
-            start, end = existing_dataset_by_id._get_start_and_end_indexes(
-                max(existing_dataset_by_id._item_count - (offset or 0) - (limit or LIST_ITEMS_LIMIT), 0) if desc else offset or 0,
+        async with existing_dataset_by_id._file_operation_lock:  # type: ignore
+            start, end = existing_dataset_by_id._get_start_and_end_indexes(  # type: ignore
+                max(existing_dataset_by_id._item_count - (offset or 0) - (limit or LIST_ITEMS_LIMIT), 0) if desc else offset or 0,  # type: ignore
                 limit,
             )
 
@@ -200,35 +203,37 @@ async def list_items(
 
             for idx in range(start, end):
                 entry_number = self._generate_local_entry_name(idx)
-                items.append(existing_dataset_by_id._dataset_entries[entry_number])
+                items.append(existing_dataset_by_id._dataset_entries[entry_number])  # type: ignore
 
-            await existing_dataset_by_id._update_timestamps(False)
+            await existing_dataset_by_id._update_timestamps(has_been_modified=False)  # type: ignore
 
             if desc:
                 items.reverse()
 
-            return ListPage({
-                'count': len(items),
-                'desc': desc or False,
-                'items': items,
-                'limit': limit or LIST_ITEMS_LIMIT,
-                'offset': offset or 0,
-                'total': existing_dataset_by_id._item_count,
-            })
+            return ListPage(
+                {
+                    'count': len(items),
+                    'desc': desc or False,
+                    'items': items,
+                    'limit': limit or LIST_ITEMS_LIMIT,
+                    'offset': offset or 0,
+                    'total': existing_dataset_by_id._item_count,  # type: ignore
+                }
+            )
 
     async def iterate_items(
-        self,
+        self: DatasetClient,
         *,
         offset: int = 0,
-        limit: Optional[int] = None,
-        clean: Optional[bool] = None,  # noqa: U100
-        desc: Optional[bool] = None,
-        fields: Optional[List[str]] = None,  # noqa: U100
-        omit: Optional[List[str]] = None,  # noqa: U100
-        unwind: Optional[str] = None,  # noqa: U100
-        skip_empty: Optional[bool] = None,  # noqa: U100
-        skip_hidden: Optional[bool] = None,  # noqa: U100
-    ) -> AsyncIterator[Dict]:
+        limit: int | None = None,
+        clean: bool | None = None,  # noqa: ARG002
+        desc: bool | None = None,
+        fields: list[str] | None = None,  # noqa: ARG002
+        omit: list[str] | None = None,  # noqa: ARG002
+        unwind: str | None = None,  # noqa: ARG002
+        skip_empty: bool | None = None,  # noqa: ARG002
+        skip_hidden: bool | None = None,  # noqa: ARG002
+    ) -> AsyncIterator[dict]:
         """Iterate over the items in the dataset.
 
         Args:
@@ -260,17 +265,11 @@ async def iterate_items(
         first_item = offset
 
         # If there is no limit, set last_item to None until we get the total from the first API response
-        if limit is None:
-            last_item = None
-        else:
-            last_item = offset + limit
+        last_item = None if limit is None else offset + limit
 
         current_offset = first_item
 
         while last_item is None or current_offset < last_item:
-            if last_item is None:
-                current_limit = cache_size
-            else:
-                current_limit = min(cache_size, last_item - current_offset)
+            current_limit = cache_size if last_item is None else min(cache_size, last_item - current_offset)
 
             current_items_page = await self.list_items(
                 offset=current_offset,
@@ -285,13 +284,13 @@ async def iterate_items(
             for item in current_items_page.items:
                 yield item
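The paging logic in `iterate_items` above is easier to follow extracted into a standalone sketch; the numbers and the fixed total are made up for illustration:

```python
from __future__ import annotations


def page_bounds(offset: int, limit: int | None, cache_size: int, total: int) -> list[tuple[int, int]]:
    """Reproduce the (offset, limit) pairs iterate_items() would request page by page."""
    pages = []
    last_item = None if limit is None else offset + limit
    current_offset = offset
    while last_item is None or current_offset < last_item:
        current_limit = cache_size if last_item is None else min(cache_size, last_item - current_offset)
        pages.append((current_offset, current_limit))
        if last_item is None:
            last_item = total  # the real code learns the total from the first response
        current_offset += current_limit
    return pages


print(page_bounds(5, 25, 1000, total=2500))    # [(5, 25)] - a single page suffices
print(page_bounds(0, None, 1000, total=2500))  # [(0, 1000), (1000, 1000), (2000, 500)]
```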
-    async def get_items_as_bytes(self, *_args: Any, **_kwargs: Any) -> bytes:  # noqa: D102
+    async def get_items_as_bytes(self: DatasetClient, *_args: Any, **_kwargs: Any) -> bytes:
         raise NotImplementedError('This method is not supported in local memory storage.')
 
-    async def stream_items(self, *_args: Any, **_kwargs: Any) -> AsyncIterator:  # noqa: D102
+    async def stream_items(self: DatasetClient, *_args: Any, **_kwargs: Any) -> AsyncIterator:
         raise NotImplementedError('This method is not supported in local memory storage.')
 
-    async def push_items(self, items: JSONSerializable) -> None:
+    async def push_items(self: DatasetClient, items: JSONSerializable) -> None:
         """Push items to the dataset.
 
         Args:
@@ -299,27 +298,26 @@ async def push_items(self, items: JSONSerializable) -> None:
         """
         # Check by id
         existing_dataset_by_id = self._find_or_create_client_by_id_or_name(
-            memory_storage_client=self._memory_storage_client, id=self._id, name=self._name)
+            memory_storage_client=self._memory_storage_client, id=self._id, name=self._name
+        )
 
         if existing_dataset_by_id is None:
-            _raise_on_non_existing_storage(_StorageTypes.DATASET, self._id)
+            raise_on_non_existing_storage(StorageTypes.DATASET, self._id)
 
         normalized = self._normalize_items(items)
 
-        added_ids: List[str] = []
+        added_ids: list[str] = []
         for entry in normalized:
-            existing_dataset_by_id._item_count += 1
-            idx = self._generate_local_entry_name(existing_dataset_by_id._item_count)
+            existing_dataset_by_id._item_count += 1  # type: ignore
+            idx = self._generate_local_entry_name(existing_dataset_by_id._item_count)  # type: ignore
 
-            existing_dataset_by_id._dataset_entries[idx] = entry
+            existing_dataset_by_id._dataset_entries[idx] = entry  # type: ignore
             added_ids.append(idx)
 
-        data_entries: List[Tuple[str, Dict]] = []
-        for id in added_ids:
-            data_entries.append((id, existing_dataset_by_id._dataset_entries[id]))
+        data_entries = [(id, existing_dataset_by_id._dataset_entries[id]) for id in added_ids]  # type: ignore  # noqa: A001
 
-        async with existing_dataset_by_id._file_operation_lock:
-            await existing_dataset_by_id._update_timestamps(True)
+        async with existing_dataset_by_id._file_operation_lock:  # type: ignore
+            await existing_dataset_by_id._update_timestamps(has_been_modified=True)  # type: ignore
 
             await _update_dataset_items(
                 data=data_entries,
@@ -327,7 +325,7 @@ async def push_items(self, items: JSONSerializable) -> None:
                 persist_storage=self._memory_storage_client._persist_storage,
             )
 
-    def _to_resource_info(self) -> Dict:
+    def _to_resource_info(self: DatasetClient) -> dict:
         """Retrieve the dataset info."""
         return {
             'id': self._id,
@@ -338,7 +336,7 @@ def _to_resource_info(self) -> Dict:
             'modifiedAt': self._modified_at,
         }
 
-    async def _update_timestamps(self, has_been_modified: bool) -> None:
+    async def _update_timestamps(self: DatasetClient, has_been_modified: bool) -> None:  # noqa: FBT001
         """Update the timestamps of the dataset."""
         self._accessed_at = datetime.now(timezone.utc)
 
@@ -346,63 +344,66 @@ async def _update_timestamps(self, has_been_modified: bool) -> None:
             self._modified_at = datetime.now(timezone.utc)
 
         dataset_info = self._to_resource_info()
-        await _update_metadata(
+        await update_metadata(
             data=dataset_info,
             entity_directory=self._resource_directory,
             write_metadata=self._memory_storage_client._write_metadata,
         )
 
-    def _get_start_and_end_indexes(self, offset: int, limit: Optional[int] = None) -> Tuple[int, int]:
+    def _get_start_and_end_indexes(self: DatasetClient, offset: int, limit: int | None = None) -> tuple[int, int]:
         actual_limit = limit or self._item_count
         start = offset + 1
         end = min(offset + actual_limit, self._item_count) + 1
         return (start, end)
 
-    def _generate_local_entry_name(self, idx: int) -> str:
+    def _generate_local_entry_name(self: DatasetClient, idx: int) -> str:
         return str(idx).zfill(LOCAL_ENTRY_NAME_DIGITS)
 
-    def _normalize_items(self, items: JSONSerializable) -> List[Dict]:
-        def normalize_item(item: Any) -> Optional[Dict]:
-            if type(item) is str:
+    def _normalize_items(self: DatasetClient, items: JSONSerializable) -> list[dict]:
+        def normalize_item(item: Any) -> dict | None:
+            if isinstance(item, str):
                 item = json.loads(item)
 
-            if type(item) is list:
+            if isinstance(item, list):
                 received = ',\n'.join(item)
-                raise ValueError(f'Each dataset item can only be a single JSON object, not an array. Received: [{received}]')
+                raise TypeError(f'Each dataset item can only be a single JSON object, not an array. Received: [{received}]')
 
-            if type(item) is not dict and item is not None:
-                raise ValueError(f'Each dataset item must be a JSON object. Received: {item}')
+            if (not isinstance(item, dict)) and item is not None:
+                raise TypeError(f'Each dataset item must be a JSON object. Received: {item}')
 
             return item
 
-        if type(items) is str:
+        if isinstance(items, str):
             items = json.loads(items)
 
-        result = list(map(normalize_item, items)) if type(items) is list else [normalize_item(items)]
+        result = list(map(normalize_item, items)) if isinstance(items, list) else [normalize_item(items)]
         # filter(None, ..) returns items that are True
         return list(filter(None, result))
 
     @classmethod
-    def _get_storages_dir(cls, memory_storage_client: 'MemoryStorageClient') -> str:
+    def _get_storages_dir(cls: type[DatasetClient], memory_storage_client: MemoryStorageClient) -> str:
         return memory_storage_client._datasets_directory
 
     @classmethod
-    def _get_storage_client_cache(cls, memory_storage_client: 'MemoryStorageClient') -> List['DatasetClient']:
+    def _get_storage_client_cache(  # type: ignore
+        cls: type[DatasetClient],
+        memory_storage_client: MemoryStorageClient,
+    ) -> list[DatasetClient]:
         return memory_storage_client._datasets_handled
 
     @classmethod
     def _create_from_directory(
-        cls,
+        cls: type[DatasetClient],
         storage_directory: str,
-        memory_storage_client: 'MemoryStorageClient',
-        id: Optional[str] = None,
-        name: Optional[str] = None,
-    ) -> 'DatasetClient':
+        memory_storage_client: MemoryStorageClient,
+        id: str | None = None,  # noqa: A002
+        name: str | None = None,
+    ) -> DatasetClient:
         item_count = 0
         created_at = datetime.now(timezone.utc)
         accessed_at = datetime.now(timezone.utc)
         modified_at = datetime.now(timezone.utc)
-        entries: Dict[str, Dict] = {}
+        entries: dict[str, dict] = {}
 
         has_seen_metadata_file = False
 
@@ -415,7 +416,7 @@ def _create_from_directory(
                 # We have found the dataset's metadata file, build out information based on it
                 with open(os.path.join(storage_directory, entry.name), encoding='utf-8') as f:
                     metadata = json.load(f)
-                id = metadata['id']
+                id = metadata['id']  # noqa: A001
                 name = metadata['name']
                 item_count = metadata['itemCount']
                 created_at = datetime.fromisoformat(metadata['createdAt'])
@@ -433,8 +434,12 @@ def _create_from_directory(
             if not has_seen_metadata_file:
                 item_count += 1
 
-        new_client = DatasetClient(base_storage_directory=memory_storage_client._datasets_directory,
-                                   memory_storage_client=memory_storage_client, id=id, name=name)
+        new_client = DatasetClient(
+            base_storage_directory=memory_storage_client._datasets_directory,
+            memory_storage_client=memory_storage_client,
+            id=id,
+            name=name,
+        )
 
         # Overwrite properties
         new_client._accessed_at = accessed_at
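A quick demonstration of the `_normalize_items` behaviour exercised through `push_items` above: single objects, JSON-encoded strings, and lists are all accepted, and `None` entries are dropped. Hypothetical usage of the private client, mirroring the code shown in the diff:

```python
import asyncio

from apify._memory_storage.memory_storage_client import MemoryStorageClient


async def main() -> None:
    client = MemoryStorageClient(local_data_directory='./storage')  # illustrative path
    info = await client.datasets().get_or_create(name='normalization-demo')
    dataset = client.dataset(info['id'])

    await dataset.push_items({'a': 1})          # a single JSON object
    await dataset.push_items('{"b": 2}')        # a JSON-encoded object is parsed first
    await dataset.push_items([{'c': 3}, None])  # None entries are silently dropped
    # push_items([['a'], ['b']]) would raise TypeError - each item must be a single JSON object

    page = await dataset.list_items()
    print(page.total)  # 3


asyncio.run(main())
```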
diff --git a/src/apify/_memory_storage/resource_clients/dataset_collection.py b/src/apify/_memory_storage/resource_clients/dataset_collection.py
index 260043bf..24b5cd95 100644
--- a/src/apify/_memory_storage/resource_clients/dataset_collection.py
+++ b/src/apify/_memory_storage/resource_clients/dataset_collection.py
@@ -1,23 +1,27 @@
-from typing import Dict, List, Optional, Type
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
 
-from apify_shared.models import ListPage
 from apify_shared.utils import ignore_docs
 
 from .base_resource_collection_client import BaseResourceCollectionClient
 from .dataset import DatasetClient
 
+if TYPE_CHECKING:
+    from apify_shared.models import ListPage
+
 
 @ignore_docs
 class DatasetCollectionClient(BaseResourceCollectionClient):
     """Sub-client for manipulating datasets."""
 
-    def _get_storage_client_cache(self) -> List[DatasetClient]:
+    def _get_storage_client_cache(self: DatasetCollectionClient) -> list[DatasetClient]:
         return self._memory_storage_client._datasets_handled
 
-    def _get_resource_client_class(self) -> Type[DatasetClient]:
+    def _get_resource_client_class(self: DatasetCollectionClient) -> type[DatasetClient]:
         return DatasetClient
 
-    async def list(self) -> ListPage:
+    async def list(self: DatasetCollectionClient) -> ListPage:  # noqa: A003
         """List the available datasets.
 
         Returns:
@@ -26,17 +30,17 @@ async def list(self) -> ListPage:
         return await super().list()
 
     async def get_or_create(
-        self,
+        self: DatasetCollectionClient,
         *,
-        name: Optional[str] = None,
-        schema: Optional[Dict] = None,
-        _id: Optional[str] = None,
-    ) -> Dict:
+        name: str | None = None,
+        schema: dict | None = None,
+        _id: str | None = None,
+    ) -> dict:
         """Retrieve a named dataset, or create a new one when it doesn't exist.
 
         Args:
             name (str, optional): The name of the dataset to retrieve or create.
-            schema (Dict, optional): The schema of the dataset
+            schema (dict, optional): The schema of the dataset
 
         Returns:
             dict: The retrieved or newly-created dataset.
diff --git a/src/apify/_memory_storage/resource_clients/key_value_store.py b/src/apify/_memory_storage/resource_clients/key_value_store.py
index cd33fca9..a77949a4 100644
--- a/src/apify/_memory_storage/resource_clients/key_value_store.py
+++ b/src/apify/_memory_storage/resource_clients/key_value_store.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import io
 import json
@@ -6,33 +8,39 @@
 import pathlib
 from datetime import datetime, timezone
 from operator import itemgetter
-from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, TypedDict
+from typing import TYPE_CHECKING, Any, AsyncIterator, TypedDict
 
 import aiofiles
 import aioshutil
 from aiofiles.os import makedirs
-from typing_extensions import NotRequired
 
 from apify_shared.utils import ignore_docs, is_file_or_bytes, json_dumps
 
 from ..._crypto import crypto_random_object_id
 from ..._utils import (
-    _force_remove,
-    _force_rename,
-    _guess_file_extension,
-    _maybe_parse_body,
-    _raise_on_duplicate_storage,
-    _raise_on_non_existing_storage,
+    force_remove,
+    force_rename,
+    guess_file_extension,
+    maybe_parse_body,
+    raise_on_duplicate_storage,
+    raise_on_non_existing_storage,
 )
-from ...consts import DEFAULT_API_PARAM_LIMIT, _StorageTypes
+from ...consts import DEFAULT_API_PARAM_LIMIT, StorageTypes
 from ...log import logger
-from ..file_storage_utils import _update_metadata
+from ..file_storage_utils import update_metadata
 from .base_resource_client import BaseResourceClient
 
 if TYPE_CHECKING:
+    from typing_extensions import NotRequired
+
     from ..memory_storage_client import MemoryStorageClient
 
-KeyValueStoreRecord = TypedDict('KeyValueStoreRecord', {'key': str, 'value': Any, 'contentType': Optional[str], 'filename': NotRequired[str]})
+
+class KeyValueStoreRecord(TypedDict):
+    key: str
+    value: Any
+    contentType: str | None
+    filename: NotRequired[str]
 
 
 def _filename_from_record(record: KeyValueStoreRecord) -> str:
@@ -42,12 +50,12 @@ def _filename_from_record(record: KeyValueStoreRecord) -> str:
     content_type = record.get('contentType')
     if not content_type or content_type == 'application/octet-stream':
         return record['key']
-    else:
-        extension = _guess_file_extension(content_type)
-        if record['key'].endswith(f'.{extension}'):
-            return record['key']
-        else:
-            return f'{record["key"]}.{extension}'
+
+    extension = guess_file_extension(content_type)
+    if record['key'].endswith(f'.{extension}'):
+        return record['key']
+
+    return f'{record["key"]}.{extension}'
 
 
 @ignore_docs
@@ -56,21 +64,21 @@ class KeyValueStoreClient(BaseResourceClient):
 
     _id: str
     _resource_directory: str
-    _memory_storage_client: 'MemoryStorageClient'
-    _name: Optional[str]
-    _records: Dict[str, KeyValueStoreRecord]
+    _memory_storage_client: MemoryStorageClient
+    _name: str | None
+    _records: dict[str, KeyValueStoreRecord]
     _created_at: datetime
     _accessed_at: datetime
     _modified_at: datetime
     _file_operation_lock: asyncio.Lock
 
     def __init__(
-        self,
+        self: KeyValueStoreClient,
         *,
         base_storage_directory: str,
-        memory_storage_client: 'MemoryStorageClient',
-        id: Optional[str] = None,
-        name: Optional[str] = None,
+        memory_storage_client: MemoryStorageClient,
+        id: str | None = None,  # noqa: A002
+        name: str | None = None,
     ) -> None:
         """Initialize the KeyValueStoreClient."""
         self._id = id or crypto_random_object_id()
@@ -83,7 +91,7 @@ def __init__(
         self._modified_at = datetime.now(timezone.utc)
         self._file_operation_lock = asyncio.Lock()
 
-    async def get(self) -> Optional[Dict]:
+    async def get(self: KeyValueStoreClient) -> dict | None:
         """Retrieve the key-value store.
 
         Returns:
@@ -92,13 +100,13 @@ async def get(self) -> Optional[Dict]:
         found = self._find_or_create_client_by_id_or_name(memory_storage_client=self._memory_storage_client, id=self._id, name=self._name)
 
         if found:
-            async with found._file_operation_lock:
-                await found._update_timestamps(False)
+            async with found._file_operation_lock:  # type: ignore
+                await found._update_timestamps(has_been_modified=False)  # type: ignore
                 return found._to_resource_info()
 
         return None
 
-    async def update(self, *, name: Optional[str] = None) -> Dict:
+    async def update(self: KeyValueStoreClient, *, name: str | None = None) -> dict:
         """Update the key-value store with specified fields.
 
         Args:
@@ -109,16 +117,17 @@ async def update(self, *, name: Optional[str] = None) -> Dict:
         """
         # Check by id
         existing_store_by_id = self._find_or_create_client_by_id_or_name(
-            memory_storage_client=self._memory_storage_client, id=self._id, name=self._name)
+            memory_storage_client=self._memory_storage_client, id=self._id, name=self._name
+        )
 
         if existing_store_by_id is None:
-            _raise_on_non_existing_storage(_StorageTypes.KEY_VALUE_STORE, self._id)
+            raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self._id)
 
         # Skip if no changes
         if name is None:
             return existing_store_by_id._to_resource_info()
 
-        async with existing_store_by_id._file_operation_lock:
+        async with existing_store_by_id._file_operation_lock:  # type: ignore
             # Check that name is not in use already
             existing_store_by_name = next(
                 (store for store in self._memory_storage_client._key_value_stores_handled if store._name and store._name.lower() == name.lower()),
@@ -126,7 +135,7 @@ async def update(self, *, name: Optional[str] = None) -> Dict:
             )
 
             if existing_store_by_name is not None:
-                _raise_on_duplicate_storage(_StorageTypes.KEY_VALUE_STORE, 'name', name)
+                raise_on_duplicate_storage(StorageTypes.KEY_VALUE_STORE, 'name', name)
 
             existing_store_by_id._name = name
 
@@ -134,14 +143,14 @@ async def update(self, *, name: Optional[str] = None) -> Dict:
 
             existing_store_by_id._resource_directory = os.path.join(self._memory_storage_client._key_value_stores_directory, name)
 
-            await _force_rename(previous_dir, existing_store_by_id._resource_directory)
+            await force_rename(previous_dir, existing_store_by_id._resource_directory)
 
             # Update timestamps
-            await existing_store_by_id._update_timestamps(True)
+            await existing_store_by_id._update_timestamps(has_been_modified=True)  # type: ignore
 
         return existing_store_by_id._to_resource_info()
 
-    async def delete(self) -> None:
+    async def delete(self: KeyValueStoreClient) -> None:
         """Delete the key-value store."""
         store = next((store for store in self._memory_storage_client._key_value_stores_handled if store._id == self._id), None)
 
@@ -153,7 +162,12 @@ async def delete(self) -> None:
             if os.path.exists(store._resource_directory):
                 await aioshutil.rmtree(store._resource_directory)
 
-    async def list_keys(self, *, limit: int = DEFAULT_API_PARAM_LIMIT, exclusive_start_key: Optional[str] = None) -> Dict:
+    async def list_keys(
+        self: KeyValueStoreClient,
+        *,
+        limit: int = DEFAULT_API_PARAM_LIMIT,
+        exclusive_start_key: str | None = None,
+    ) -> dict:
         """List the keys in the key-value store.
 
         Args:
@@ -165,19 +179,22 @@ async def list_keys(self, *, limit: int = DEFAULT_API_PARAM_LIMIT, exclusive_sta
         """
         # Check by id
         existing_store_by_id = self._find_or_create_client_by_id_or_name(
-            memory_storage_client=self._memory_storage_client, id=self._id, name=self._name)
+            memory_storage_client=self._memory_storage_client, id=self._id, name=self._name
+        )
 
         if existing_store_by_id is None:
-            _raise_on_non_existing_storage(_StorageTypes.KEY_VALUE_STORE, self._id)
+            raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self._id)
 
         items = []
 
-        for record in existing_store_by_id._records.values():
+        for record in existing_store_by_id._records.values():  # type: ignore
             size = len(record['value'])
-            items.append({
-                'key': record['key'],
-                'size': size,
-            })
+            items.append(
+                {
+                    'key': record['key'],
+                    'size': size,
+                }
+            )
 
         if len(items) == 0:
             return {
@@ -196,7 +213,7 @@ async def list_keys(self, *, limit: int = DEFAULT_API_PARAM_LIMIT, exclusive_sta
         if exclusive_start_key is not None:
             key_pos = next((idx for idx, i in enumerate(items) if i['key'] == exclusive_start_key), None)
             if key_pos is not None:
-                truncated_items = items[key_pos + 1:]
+                truncated_items = items[(key_pos + 1) :]
 
         limited_items = truncated_items[:limit]
 
@@ -205,8 +222,8 @@ async def list_keys(self, *, limit: int = DEFAULT_API_PARAM_LIMIT, exclusive_sta
         is_last_selected_item_absolutely_last = last_item_in_store == last_selected_item
         next_exclusive_start_key = None if is_last_selected_item_absolutely_last else last_selected_item['key']
 
-        async with existing_store_by_id._file_operation_lock:
-            await existing_store_by_id._update_timestamps(False)
+        async with existing_store_by_id._file_operation_lock:  # type: ignore
+            await existing_store_by_id._update_timestamps(has_been_modified=False)  # type: ignore
 
         return {
             'count': len(items),
@@ -217,15 +234,20 @@ async def list_keys(self, *, limit: int = DEFAULT_API_PARAM_LIMIT, exclusive_sta
             'items': limited_items,
         }
 
-    async def _get_record_internal(self, key: str, as_bytes: bool = False) -> Optional[Dict]:
+    async def _get_record_internal(
+        self: KeyValueStoreClient,
+        key: str,
+        as_bytes: bool = False,  # noqa: FBT001, FBT002
+    ) -> dict | None:
         # Check by id
         existing_store_by_id = self._find_or_create_client_by_id_or_name(
-            memory_storage_client=self._memory_storage_client, id=self._id, name=self._name)
+            memory_storage_client=self._memory_storage_client, id=self._id, name=self._name
+        )
 
         if existing_store_by_id is None:
-            _raise_on_non_existing_storage(_StorageTypes.KEY_VALUE_STORE, self._id)
+            raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self._id)
 
-        stored_record = existing_store_by_id._records.get(key)
+        stored_record = existing_store_by_id._records.get(key)  # type: ignore
 
         if stored_record is None:
             return None
@@ -238,16 +260,16 @@ async def _get_record_internal(self, key: str, as_bytes: bool = False) -> Option
 
         if not as_bytes:
             try:
-                record['value'] = _maybe_parse_body(record['value'], record['contentType'])
+                record['value'] = maybe_parse_body(record['value'], record['contentType'])
             except ValueError:
                 logger.exception('Error parsing key-value store record')
 
         async with
existing_store_by_id._file_operation_lock: - await existing_store_by_id._update_timestamps(False) + async with existing_store_by_id._file_operation_lock: # type: ignore + await existing_store_by_id._update_timestamps(has_been_modified=False) # type: ignore return record - async def get_record(self, key: str) -> Optional[Dict]: + async def get_record(self: KeyValueStoreClient, key: str) -> dict | None: """Retrieve the given record from the key-value store. Args: @@ -258,7 +280,7 @@ async def get_record(self, key: str) -> Optional[Dict]: """ return await self._get_record_internal(key) - async def get_record_as_bytes(self, key: str) -> Optional[Dict]: + async def get_record_as_bytes(self: KeyValueStoreClient, key: str) -> dict | None: """Retrieve the given record from the key-value store, without parsing it. Args: @@ -269,10 +291,10 @@ async def get_record_as_bytes(self, key: str) -> Optional[Dict]: """ return await self._get_record_internal(key, as_bytes=True) - async def stream_record(self, _key: str) -> AsyncIterator[Optional[Dict]]: # noqa: D102 + async def stream_record(self: KeyValueStoreClient, _key: str) -> AsyncIterator[dict | None]: raise NotImplementedError('This method is not supported in local memory storage.') - async def set_record(self, key: str, value: Any, content_type: Optional[str] = None) -> None: + async def set_record(self: KeyValueStoreClient, key: str, value: Any, content_type: str | None = None) -> None: """Set a value to the given record in the key-value store. Args: @@ -282,10 +304,11 @@ async def set_record(self, key: str, value: Any, content_type: Optional[str] = N """ # Check by id existing_store_by_id = self._find_or_create_client_by_id_or_name( - memory_storage_client=self._memory_storage_client, id=self._id, name=self._name) + memory_storage_client=self._memory_storage_client, id=self._id, name=self._name + ) if existing_store_by_id is None: - _raise_on_non_existing_storage(_StorageTypes.KEY_VALUE_STORE, self._id) + raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self._id) if isinstance(value, io.IOBase): raise NotImplementedError('File-like values are not supported in local memory storage') @@ -301,24 +324,24 @@ async def set_record(self, key: str, value: Any, content_type: Optional[str] = N if 'application/json' in content_type and not is_file_or_bytes(value) and not isinstance(value, str): value = json_dumps(value).encode('utf-8') - async with existing_store_by_id._file_operation_lock: - await existing_store_by_id._update_timestamps(True) + async with existing_store_by_id._file_operation_lock: # type: ignore + await existing_store_by_id._update_timestamps(has_been_modified=True) # type: ignore record: KeyValueStoreRecord = { 'key': key, 'value': value, 'contentType': content_type, } - old_record = existing_store_by_id._records.get(key) - existing_store_by_id._records[key] = record + old_record = existing_store_by_id._records.get(key) # type: ignore + existing_store_by_id._records[key] = record # type: ignore if self._memory_storage_client._persist_storage: if old_record is not None and _filename_from_record(old_record) != _filename_from_record(record): - await existing_store_by_id._delete_persisted_record(old_record) + await existing_store_by_id._delete_persisted_record(old_record) # type: ignore - await existing_store_by_id._persist_record(record) + await existing_store_by_id._persist_record(record) # type: ignore - async def _persist_record(self, record: KeyValueStoreRecord) -> None: + async def _persist_record(self: KeyValueStoreClient, 
record: KeyValueStoreRecord) -> None: store_directory = self._resource_directory record_filename = _filename_from_record(record) record['filename'] = record_filename @@ -339,12 +362,16 @@ async def _persist_record(self, record: KeyValueStoreRecord) -> None: if self._memory_storage_client._write_metadata: async with aiofiles.open(record_metadata_path, mode='wb') as f: - await f.write(json_dumps({ - 'key': record['key'], - 'contentType': record['contentType'], - }).encode('utf-8')) + await f.write( + json_dumps( + { + 'key': record['key'], + 'contentType': record['contentType'], + } + ).encode('utf-8') + ) - async def delete_record(self, key: str) -> None: + async def delete_record(self: KeyValueStoreClient, key: str) -> None: """Delete the specified record from the key-value store. Args: @@ -352,21 +379,22 @@ async def delete_record(self, key: str) -> None: """ # Check by id existing_store_by_id = self._find_or_create_client_by_id_or_name( - memory_storage_client=self._memory_storage_client, id=self._id, name=self._name) + memory_storage_client=self._memory_storage_client, id=self._id, name=self._name + ) if existing_store_by_id is None: - _raise_on_non_existing_storage(_StorageTypes.KEY_VALUE_STORE, self._id) + raise_on_non_existing_storage(StorageTypes.KEY_VALUE_STORE, self._id) - record = existing_store_by_id._records.get(key) + record = existing_store_by_id._records.get(key) # type: ignore if record is not None: - async with existing_store_by_id._file_operation_lock: - del existing_store_by_id._records[key] - await existing_store_by_id._update_timestamps(True) + async with existing_store_by_id._file_operation_lock: # type: ignore + del existing_store_by_id._records[key] # type: ignore + await existing_store_by_id._update_timestamps(has_been_modified=True) # type: ignore if self._memory_storage_client._persist_storage: - await existing_store_by_id._delete_persisted_record(record) + await existing_store_by_id._delete_persisted_record(record) # type: ignore - async def _delete_persisted_record(self, record: KeyValueStoreRecord) -> None: + async def _delete_persisted_record(self: KeyValueStoreClient, record: KeyValueStoreRecord) -> None: store_directory = self._resource_directory record_filename = _filename_from_record(record) @@ -377,10 +405,10 @@ async def _delete_persisted_record(self, record: KeyValueStoreRecord) -> None: record_path = os.path.join(store_directory, record_filename) record_metadata_path = os.path.join(store_directory, record_filename + '.__metadata__.json') - await _force_remove(record_path) - await _force_remove(record_metadata_path) + await force_remove(record_path) + await force_remove(record_metadata_path) - def _to_resource_info(self) -> Dict: + def _to_resource_info(self: KeyValueStoreClient) -> dict: """Retrieve the key-value store info.""" return { 'id': self._id, @@ -391,35 +419,38 @@ def _to_resource_info(self) -> Dict: 'userId': '1', } - async def _update_timestamps(self, has_been_modified: bool) -> None: + async def _update_timestamps(self: KeyValueStoreClient, has_been_modified: bool) -> None: # noqa: FBT001 self._accessed_at = datetime.now(timezone.utc) if has_been_modified: self._modified_at = datetime.now(timezone.utc) kv_store_info = self._to_resource_info() - await _update_metadata( + await update_metadata( data=kv_store_info, entity_directory=self._resource_directory, write_metadata=self._memory_storage_client._write_metadata, ) @classmethod - def _get_storages_dir(cls, memory_storage_client: 'MemoryStorageClient') -> str: + def _get_storages_dir(cls: 
type[KeyValueStoreClient], memory_storage_client: MemoryStorageClient) -> str: return memory_storage_client._key_value_stores_directory @classmethod - def _get_storage_client_cache(cls, memory_storage_client: 'MemoryStorageClient') -> List['KeyValueStoreClient']: + def _get_storage_client_cache( # type: ignore + cls: type[KeyValueStoreClient], + memory_storage_client: MemoryStorageClient, + ) -> list[KeyValueStoreClient]: return memory_storage_client._key_value_stores_handled @classmethod def _create_from_directory( - cls, + cls: type[KeyValueStoreClient], storage_directory: str, - memory_storage_client: 'MemoryStorageClient', - id: Optional[str] = None, - name: Optional[str] = None, - ) -> 'KeyValueStoreClient': + memory_storage_client: MemoryStorageClient, + id: str | None = None, # noqa: A002 + name: str | None = None, + ) -> KeyValueStoreClient: created_at = datetime.now(timezone.utc) accessed_at = datetime.now(timezone.utc) modified_at = datetime.now(timezone.utc) @@ -428,7 +459,7 @@ def _create_from_directory( if os.path.exists(store_metadata_path): with open(store_metadata_path, encoding='utf-8') as f: metadata = json.load(f) - id = metadata['id'] + id = metadata['id'] # noqa: A001 name = metadata['name'] created_at = datetime.fromisoformat(metadata['createdAt']) accessed_at = datetime.fromisoformat(metadata['accessedAt']) @@ -460,12 +491,12 @@ def _create_from_directory( # Try checking if this file has a metadata file associated with it metadata = None - if (os.path.exists(os.path.join(storage_directory, entry.name + '.__metadata__.json'))): + if os.path.exists(os.path.join(storage_directory, entry.name + '.__metadata__.json')): with open(os.path.join(storage_directory, entry.name + '.__metadata__.json'), encoding='utf-8') as metadata_file: try: metadata = json.load(metadata_file) - assert metadata.get('key') is not None - assert metadata.get('contentType') is not None + assert metadata.get('key') is not None # noqa: S101 + assert metadata.get('contentType') is not None # noqa: S101 except Exception: logger.warning( f"""Metadata of key-value store entry "{entry.name}" for store {name or id} could not be parsed.""" @@ -484,7 +515,7 @@ def _create_from_directory( } try: - _maybe_parse_body(file_content, metadata['contentType']) + maybe_parse_body(file_content, metadata['contentType']) except Exception: metadata['contentType'] = 'application/octet-stream' logger.warning( diff --git a/src/apify/_memory_storage/resource_clients/key_value_store_collection.py b/src/apify/_memory_storage/resource_clients/key_value_store_collection.py index 4f6470be..7dd31fed 100644 --- a/src/apify/_memory_storage/resource_clients/key_value_store_collection.py +++ b/src/apify/_memory_storage/resource_clients/key_value_store_collection.py @@ -1,23 +1,27 @@ -from typing import Dict, List, Optional, Type +from __future__ import annotations + +from typing import TYPE_CHECKING -from apify_shared.models import ListPage from apify_shared.utils import ignore_docs from .base_resource_collection_client import BaseResourceCollectionClient from .key_value_store import KeyValueStoreClient +if TYPE_CHECKING: + from apify_shared.models import ListPage + @ignore_docs class KeyValueStoreCollectionClient(BaseResourceCollectionClient): """Sub-client for manipulating key-value stores.""" - def _get_storage_client_cache(self) -> List[KeyValueStoreClient]: + def _get_storage_client_cache(self: KeyValueStoreCollectionClient) -> list[KeyValueStoreClient]: return self._memory_storage_client._key_value_stores_handled - def 
_get_resource_client_class(self) -> Type[KeyValueStoreClient]: + def _get_resource_client_class(self: KeyValueStoreCollectionClient) -> type[KeyValueStoreClient]: return KeyValueStoreClient - async def list(self) -> ListPage: + async def list(self: KeyValueStoreCollectionClient) -> ListPage: # noqa: A003 """List the available key-value stores. Returns: @@ -26,12 +30,12 @@ async def list(self) -> ListPage: return await super().list() async def get_or_create( - self, + self: KeyValueStoreCollectionClient, *, - name: Optional[str] = None, - schema: Optional[Dict] = None, - _id: Optional[str] = None, - ) -> Dict: + name: str | None = None, + schema: dict | None = None, + _id: str | None = None, + ) -> dict: """Retrieve a named key-value store, or create a new one when it doesn't exist. Args: diff --git a/src/apify/_memory_storage/resource_clients/request_queue.py b/src/apify/_memory_storage/resource_clients/request_queue.py index 2c574da5..caabf57f 100644 --- a/src/apify/_memory_storage/resource_clients/request_queue.py +++ b/src/apify/_memory_storage/resource_clients/request_queue.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import asyncio import json import os from datetime import datetime, timezone from decimal import Decimal -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING import aioshutil from sortedcollections import ValueSortedDict @@ -11,9 +13,9 @@ from apify_shared.utils import filter_out_none_values_recursively, ignore_docs, json_dumps from ..._crypto import crypto_random_object_id -from ..._utils import _force_rename, _raise_on_duplicate_storage, _raise_on_non_existing_storage, _unique_key_to_request_id -from ...consts import _StorageTypes -from ..file_storage_utils import _delete_request, _update_metadata, _update_request_queue_item +from ..._utils import force_rename, raise_on_duplicate_storage, raise_on_non_existing_storage, unique_key_to_request_id +from ...consts import StorageTypes +from ..file_storage_utils import delete_request, update_metadata, update_request_queue_item from .base_resource_client import BaseResourceClient if TYPE_CHECKING: @@ -26,8 +28,8 @@ class RequestQueueClient(BaseResourceClient): _id: str _resource_directory: str - _memory_storage_client: 'MemoryStorageClient' - _name: Optional[str] + _memory_storage_client: MemoryStorageClient + _name: str | None _requests: ValueSortedDict _created_at: datetime _accessed_at: datetime @@ -38,12 +40,12 @@ class RequestQueueClient(BaseResourceClient): _file_operation_lock: asyncio.Lock def __init__( - self, + self: RequestQueueClient, *, base_storage_directory: str, - memory_storage_client: 'MemoryStorageClient', - id: Optional[str] = None, - name: Optional[str] = None, + memory_storage_client: MemoryStorageClient, + id: str | None = None, # noqa: A002 + name: str | None = None, ) -> None: """Initialize the RequestQueueClient.""" self._id = id or crypto_random_object_id() @@ -56,7 +58,7 @@ def __init__( self._modified_at = datetime.now(timezone.utc) self._file_operation_lock = asyncio.Lock() - async def get(self) -> Optional[Dict]: + async def get(self: RequestQueueClient) -> dict | None: """Retrieve the request queue. 
Returns: @@ -65,13 +67,13 @@ async def get(self) -> Optional[Dict]: found = self._find_or_create_client_by_id_or_name(memory_storage_client=self._memory_storage_client, id=self._id, name=self._name) if found: - async with found._file_operation_lock: - await found._update_timestamps(False) + async with found._file_operation_lock: # type: ignore + await found._update_timestamps(has_been_modified=False) # type: ignore return found._to_resource_info() return None - async def update(self, *, name: Optional[str] = None) -> Dict: + async def update(self: RequestQueueClient, *, name: str | None = None) -> dict: """Update the request queue with specified fields. Args: @@ -82,22 +84,24 @@ async def update(self, *, name: Optional[str] = None) -> Dict: """ # Check by id existing_queue_by_id = self._find_or_create_client_by_id_or_name( - memory_storage_client=self._memory_storage_client, id=self._id, name=self._name) + memory_storage_client=self._memory_storage_client, id=self._id, name=self._name + ) if existing_queue_by_id is None: - _raise_on_non_existing_storage(_StorageTypes.REQUEST_QUEUE, self._id) + raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self._id) # Skip if no changes if name is None: return existing_queue_by_id._to_resource_info() - async with existing_queue_by_id._file_operation_lock: + async with existing_queue_by_id._file_operation_lock: # type: ignore # Check that name is not in use already existing_queue_by_name = next( - (queue for queue in self._memory_storage_client._request_queues_handled if queue._name and queue._name.lower() == name.lower()), None) + (queue for queue in self._memory_storage_client._request_queues_handled if queue._name and queue._name.lower() == name.lower()), None + ) if existing_queue_by_name is not None: - _raise_on_duplicate_storage(_StorageTypes.REQUEST_QUEUE, 'name', name) + raise_on_duplicate_storage(StorageTypes.REQUEST_QUEUE, 'name', name) existing_queue_by_id._name = name @@ -105,14 +109,14 @@ async def update(self, *, name: Optional[str] = None) -> Dict: existing_queue_by_id._resource_directory = os.path.join(self._memory_storage_client._request_queues_directory, name) - await _force_rename(previous_dir, existing_queue_by_id._resource_directory) + await force_rename(previous_dir, existing_queue_by_id._resource_directory) # Update timestamps - await existing_queue_by_id._update_timestamps(True) + await existing_queue_by_id._update_timestamps(has_been_modified=True) # type: ignore return existing_queue_by_id._to_resource_info() - async def delete(self) -> None: + async def delete(self: RequestQueueClient) -> None: """Delete the request queue.""" queue = next((queue for queue in self._memory_storage_client._request_queues_handled if queue._id == self._id), None) @@ -126,7 +130,7 @@ async def delete(self) -> None: if os.path.exists(queue._resource_directory): await aioshutil.rmtree(queue._resource_directory) - async def list_head(self, *, limit: Optional[int] = None) -> Dict: + async def list_head(self: RequestQueueClient, *, limit: int | None = None) -> dict: """Retrieve a given number of requests from the beginning of the queue. Args: @@ -136,23 +140,24 @@ async def list_head(self, *, limit: Optional[int] = None) -> Dict: dict: The desired number of requests from the beginning of the queue. 
""" existing_queue_by_id = self._find_or_create_client_by_id_or_name( - memory_storage_client=self._memory_storage_client, id=self._id, name=self._name) + memory_storage_client=self._memory_storage_client, id=self._id, name=self._name + ) if existing_queue_by_id is None: - _raise_on_non_existing_storage(_StorageTypes.REQUEST_QUEUE, self._id) + raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self._id) - async with existing_queue_by_id._file_operation_lock: - await existing_queue_by_id._update_timestamps(False) + async with existing_queue_by_id._file_operation_lock: # type: ignore + await existing_queue_by_id._update_timestamps(has_been_modified=False) # type: ignore - items: List[Dict] = [] + items: list[dict] = [] # Iterate all requests in the queue which have sorted key larger than infinity, which means `orderNo` is not `None` # This will iterate them in order of `orderNo` - for request_key in existing_queue_by_id._requests.irange_key(min_key=-float('inf'), inclusive=(False, True)): + for request_key in existing_queue_by_id._requests.irange_key(min_key=-float('inf'), inclusive=(False, True)): # type: ignore if len(items) == limit: break - request = existing_queue_by_id._requests.get(request_key) + request = existing_queue_by_id._requests.get(request_key) # type: ignore # Check that the request still exists and was not handled, # in case something deleted it or marked it as handled concurrenctly @@ -162,11 +167,11 @@ async def list_head(self, *, limit: Optional[int] = None) -> Dict: return { 'limit': limit, 'hadMultipleClients': False, - 'queueModifiedAt': existing_queue_by_id._modified_at, + 'queueModifiedAt': existing_queue_by_id._modified_at, # type: ignore 'items': [self._json_to_request(item['json']) for item in items], } - async def add_request(self, request: Dict, *, forefront: Optional[bool] = None) -> Dict: + async def add_request(self: RequestQueueClient, request: dict, *, forefront: bool | None = None) -> dict: """Add a request to the queue. Args: @@ -177,19 +182,20 @@ async def add_request(self, request: Dict, *, forefront: Optional[bool] = None) dict: The added request. 
""" existing_queue_by_id = self._find_or_create_client_by_id_or_name( - memory_storage_client=self._memory_storage_client, id=self._id, name=self._name) + memory_storage_client=self._memory_storage_client, id=self._id, name=self._name + ) if existing_queue_by_id is None: - _raise_on_non_existing_storage(_StorageTypes.REQUEST_QUEUE, self._id) + raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self._id) request_model = self._create_internal_request(request, forefront) - async with existing_queue_by_id._file_operation_lock: - existing_request_with_id = existing_queue_by_id._requests.get(request_model['id']) + async with existing_queue_by_id._file_operation_lock: # type: ignore + existing_request_with_id = existing_queue_by_id._requests.get(request_model['id']) # type: ignore # We already have the request present, so we return information about it if existing_request_with_id is not None: - await existing_queue_by_id._update_timestamps(False) + await existing_queue_by_id._update_timestamps(has_been_modified=False) # type: ignore return { 'requestId': existing_request_with_id['id'], @@ -197,13 +203,13 @@ async def add_request(self, request: Dict, *, forefront: Optional[bool] = None) 'wasAlreadyPresent': True, } - existing_queue_by_id._requests[request_model['id']] = request_model + existing_queue_by_id._requests[request_model['id']] = request_model # type: ignore if request_model['orderNo'] is None: - existing_queue_by_id._handled_request_count += 1 + existing_queue_by_id._handled_request_count += 1 # type: ignore else: - existing_queue_by_id._pending_request_count += 1 - await existing_queue_by_id._update_timestamps(True) - await _update_request_queue_item( + existing_queue_by_id._pending_request_count += 1 # type: ignore + await existing_queue_by_id._update_timestamps(has_been_modified=True) # type: ignore + await update_request_queue_item( request=request_model, request_id=request_model['id'], entity_directory=existing_queue_by_id._resource_directory, @@ -218,7 +224,7 @@ async def add_request(self, request: Dict, *, forefront: Optional[bool] = None) 'wasAlreadyPresent': False, } - async def get_request(self, request_id: str) -> Optional[Dict]: + async def get_request(self: RequestQueueClient, request_id: str) -> dict | None: """Retrieve a request from the queue. Args: @@ -228,18 +234,19 @@ async def get_request(self, request_id: str) -> Optional[Dict]: dict, optional: The retrieved request, or None, if it did not exist. 
""" existing_queue_by_id = self._find_or_create_client_by_id_or_name( - memory_storage_client=self._memory_storage_client, id=self._id, name=self._name) + memory_storage_client=self._memory_storage_client, id=self._id, name=self._name + ) if existing_queue_by_id is None: - _raise_on_non_existing_storage(_StorageTypes.REQUEST_QUEUE, self._id) + raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self._id) - async with existing_queue_by_id._file_operation_lock: - await existing_queue_by_id._update_timestamps(False) + async with existing_queue_by_id._file_operation_lock: # type: ignore + await existing_queue_by_id._update_timestamps(has_been_modified=False) # type: ignore - request = existing_queue_by_id._requests.get(request_id) + request = existing_queue_by_id._requests.get(request_id) # type: ignore return self._json_to_request(request['json'] if request is not None else None) - async def update_request(self, request: Dict, *, forefront: Optional[bool] = None) -> Dict: + async def update_request(self: RequestQueueClient, request: dict, *, forefront: bool | None = None) -> dict: """Update a request in the queue. Args: @@ -250,40 +257,41 @@ async def update_request(self, request: Dict, *, forefront: Optional[bool] = Non dict: The updated request """ existing_queue_by_id = self._find_or_create_client_by_id_or_name( - memory_storage_client=self._memory_storage_client, id=self._id, name=self._name) + memory_storage_client=self._memory_storage_client, id=self._id, name=self._name + ) if existing_queue_by_id is None: - _raise_on_non_existing_storage(_StorageTypes.REQUEST_QUEUE, self._id) + raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self._id) request_model = self._create_internal_request(request, forefront) # First we need to check the existing request to be # able to return information about its handled state. - existing_request = existing_queue_by_id._requests.get(request_model['id']) + existing_request = existing_queue_by_id._requests.get(request_model['id']) # type: ignore # Undefined means that the request is not present in the queue. # We need to insert it, to behave the same as API. if existing_request is None: return await self.add_request(request, forefront=forefront) - async with existing_queue_by_id._file_operation_lock: + async with existing_queue_by_id._file_operation_lock: # type: ignore # When updating the request, we need to make sure that # the handled counts are updated correctly in all cases. 
- existing_queue_by_id._requests[request_model['id']] = request_model + existing_queue_by_id._requests[request_model['id']] = request_model # type: ignore pending_count_adjustment = 0 - is_request_handled_state_changing = type(existing_request['orderNo']) != type(request_model['orderNo']) # noqa + is_request_handled_state_changing = not isinstance(existing_request['orderNo'], type(request_model['orderNo'])) request_was_handled_before_update = existing_request['orderNo'] is None # We add 1 pending request if previous state was handled if is_request_handled_state_changing: pending_count_adjustment = 1 if request_was_handled_before_update else -1 - existing_queue_by_id._pending_request_count += pending_count_adjustment - existing_queue_by_id._handled_request_count -= pending_count_adjustment - await existing_queue_by_id._update_timestamps(True) - await _update_request_queue_item( + existing_queue_by_id._pending_request_count += pending_count_adjustment # type: ignore + existing_queue_by_id._handled_request_count -= pending_count_adjustment # type: ignore + await existing_queue_by_id._update_timestamps(has_been_modified=True) # type: ignore + await update_request_queue_item( request=request_model, request_id=request_model['id'], entity_directory=existing_queue_by_id._resource_directory, @@ -296,31 +304,32 @@ async def update_request(self, request: Dict, *, forefront: Optional[bool] = Non 'wasAlreadyPresent': True, } - async def delete_request(self, request_id: str) -> None: + async def delete_request(self: RequestQueueClient, request_id: str) -> None: """Delete a request from the queue. Args: request_id (str): ID of the request to delete. """ existing_queue_by_id = self._find_or_create_client_by_id_or_name( - memory_storage_client=self._memory_storage_client, id=self._id, name=self._name) + memory_storage_client=self._memory_storage_client, id=self._id, name=self._name + ) if existing_queue_by_id is None: - _raise_on_non_existing_storage(_StorageTypes.REQUEST_QUEUE, self._id) + raise_on_non_existing_storage(StorageTypes.REQUEST_QUEUE, self._id) - async with existing_queue_by_id._file_operation_lock: - request = existing_queue_by_id._requests.get(request_id) + async with existing_queue_by_id._file_operation_lock: # type: ignore + request = existing_queue_by_id._requests.get(request_id) # type: ignore if request: - del existing_queue_by_id._requests[request_id] + del existing_queue_by_id._requests[request_id] # type: ignore if request['orderNo'] is None: - existing_queue_by_id._handled_request_count -= 1 + existing_queue_by_id._handled_request_count -= 1 # type: ignore else: - existing_queue_by_id._pending_request_count -= 1 - await existing_queue_by_id._update_timestamps(True) - await _delete_request(entity_directory=existing_queue_by_id._resource_directory, request_id=request_id) + existing_queue_by_id._pending_request_count -= 1 # type: ignore + await existing_queue_by_id._update_timestamps(has_been_modified=True) # type: ignore + await delete_request(entity_directory=existing_queue_by_id._resource_directory, request_id=request_id) - def _to_resource_info(self) -> Dict: + def _to_resource_info(self: RequestQueueClient) -> dict: """Retrieve the request queue store info.""" return { 'accessedAt': self._accessed_at, @@ -336,28 +345,28 @@ def _to_resource_info(self) -> Dict: 'userId': '1', } - async def _update_timestamps(self, has_been_modified: bool) -> None: + async def _update_timestamps(self: RequestQueueClient, has_been_modified: bool) -> None: # noqa: FBT001 self._accessed_at = 
datetime.now(timezone.utc) if has_been_modified: self._modified_at = datetime.now(timezone.utc) request_queue_info = self._to_resource_info() - await _update_metadata( + await update_metadata( data=request_queue_info, entity_directory=self._resource_directory, write_metadata=self._memory_storage_client._write_metadata, ) - def _json_to_request(self, request_json: Optional[str]) -> Optional[dict]: + def _json_to_request(self: RequestQueueClient, request_json: str | None) -> dict | None: if request_json is None: return None request = json.loads(request_json) return filter_out_none_values_recursively(request) - def _create_internal_request(self, request: Dict, forefront: Optional[bool]) -> Dict: + def _create_internal_request(self: RequestQueueClient, request: dict, forefront: bool | None) -> dict: order_no = self._calculate_order_no(request, forefront) - id = _unique_key_to_request_id(request['uniqueKey']) + id = unique_key_to_request_id(request['uniqueKey']) # noqa: A001 if request.get('id') is not None and request['id'] != id: raise ValueError('Request ID does not match its unique_key.') @@ -373,7 +382,7 @@ def _create_internal_request(self, request: Dict, forefront: Optional[bool]) -> 'url': request['url'], } - def _calculate_order_no(self, request: Dict, forefront: Optional[bool]) -> Optional[Decimal]: + def _calculate_order_no(self: RequestQueueClient, request: dict, forefront: bool | None) -> Decimal | None: if request.get('handledAt') is not None: return None @@ -390,27 +399,30 @@ def _calculate_order_no(self, request: Dict, forefront: Optional[bool]) -> Optio return -timestamp if forefront else timestamp @classmethod - def _get_storages_dir(cls, memory_storage_client: 'MemoryStorageClient') -> str: + def _get_storages_dir(cls: type[RequestQueueClient], memory_storage_client: MemoryStorageClient) -> str: return memory_storage_client._request_queues_directory @classmethod - def _get_storage_client_cache(cls, memory_storage_client: 'MemoryStorageClient') -> List['RequestQueueClient']: + def _get_storage_client_cache( # type: ignore + cls: type[RequestQueueClient], + memory_storage_client: MemoryStorageClient, + ) -> list[RequestQueueClient]: return memory_storage_client._request_queues_handled @classmethod def _create_from_directory( - cls, + cls: type[RequestQueueClient], storage_directory: str, - memory_storage_client: 'MemoryStorageClient', - id: Optional[str] = None, - name: Optional[str] = None, - ) -> 'RequestQueueClient': + memory_storage_client: MemoryStorageClient, + id: str | None = None, # noqa: A002 + name: str | None = None, + ) -> RequestQueueClient: created_at = datetime.now(timezone.utc) accessed_at = datetime.now(timezone.utc) modified_at = datetime.now(timezone.utc) handled_request_count = 0 pending_request_count = 0 - entries: List[Dict] = [] + entries: list[dict] = [] # Access the request queue folder for entry in os.scandir(storage_directory): @@ -419,7 +431,7 @@ def _create_from_directory( # We have found the queue's metadata file, build out information based on it with open(os.path.join(storage_directory, entry.name), encoding='utf-8') as f: metadata = json.load(f) - id = metadata['id'] + id = metadata['id'] # noqa: A001 name = metadata['name'] created_at = datetime.fromisoformat(metadata['createdAt']) accessed_at = datetime.fromisoformat(metadata['accessedAt']) diff --git a/src/apify/_memory_storage/resource_clients/request_queue_collection.py b/src/apify/_memory_storage/resource_clients/request_queue_collection.py index 1ec9a761..5265d14b 100644 --- 
a/src/apify/_memory_storage/resource_clients/request_queue_collection.py +++ b/src/apify/_memory_storage/resource_clients/request_queue_collection.py @@ -1,23 +1,27 @@ -from typing import Dict, List, Optional, Type +from __future__ import annotations + +from typing import TYPE_CHECKING -from apify_shared.models import ListPage from apify_shared.utils import ignore_docs from .base_resource_collection_client import BaseResourceCollectionClient from .request_queue import RequestQueueClient +if TYPE_CHECKING: + from apify_shared.models import ListPage + @ignore_docs class RequestQueueCollectionClient(BaseResourceCollectionClient): """Sub-client for manipulating request queues.""" - def _get_storage_client_cache(self) -> List[RequestQueueClient]: + def _get_storage_client_cache(self: RequestQueueCollectionClient) -> list[RequestQueueClient]: return self._memory_storage_client._request_queues_handled - def _get_resource_client_class(self) -> Type[RequestQueueClient]: + def _get_resource_client_class(self: RequestQueueCollectionClient) -> type[RequestQueueClient]: return RequestQueueClient - async def list(self) -> ListPage: + async def list(self: RequestQueueCollectionClient) -> ListPage: # noqa: A003 """List the available request queues. Returns: @@ -26,17 +30,17 @@ async def list(self) -> ListPage: return await super().list() async def get_or_create( - self, + self: RequestQueueCollectionClient, *, - name: Optional[str] = None, - schema: Optional[Dict] = None, - _id: Optional[str] = None, - ) -> Dict: + name: str | None = None, + schema: dict | None = None, + _id: str | None = None, + ) -> dict: """Retrieve a named request queue, or create a new one when it doesn't exist. Args: name (str, optional): The name of the request queue to retrieve or create. - schema (Dict, optional): The schema of the request queue + schema (dict, optional): The schema of the request queue Returns: dict: The retrieved or newly-created request queue. 
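The per-module changes above all repeat one typing pattern: add `from __future__ import annotations`, rewrite `Optional[X]`/`Union[X, Y]` as `X | None`/`X | Y`, and move imports that are needed only for annotations under an `if TYPE_CHECKING:` guard. The following minimal, self-contained sketch (using a hypothetical `ExampleClient` that is not part of this patch) illustrates why this combination is safe on the older Python versions the SDK still supports:

from __future__ import annotations  # PEP 563: annotations are stored as strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Needed only by the type checker; never imported at runtime.
    from datetime import datetime


class ExampleClient:
    _created_at: datetime  # never evaluated at runtime, so the guarded import suffices

    def rename(self: ExampleClient, name: str | None = None) -> ExampleClient:
        # Both `str | None` (PEP 604) and the forward reference to `ExampleClient`
        # work on Python 3.8+, because evaluation of the annotations is deferred.
        return self

The one trade-off is that anything resolving such annotations at runtime (for example `typing.get_type_hints`) would still need the real imports, which is why only purely type-checking imports are moved under the guard.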
diff --git a/src/apify/_utils.py b/src/apify/_utils.py index 22947332..f1d3e595 100644 --- a/src/apify/_utils.py +++ b/src/apify/_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import base64 import builtins @@ -15,9 +17,19 @@ from collections.abc import MutableMapping from datetime import datetime, timezone from importlib import metadata -from typing import Any, Callable, Dict, Generic, ItemsView, Iterator, List, NoReturn, Optional +from typing import ( + Any, + Callable, + Generic, + ItemsView, + Iterator, + NoReturn, + TypeVar, + ValuesView, + cast, + overload, +) from typing import OrderedDict as OrderedDictType -from typing import Tuple, Type, TypeVar, Union, ValuesView, cast, overload import aioshutil import psutil @@ -37,24 +49,30 @@ ActorEnvVars, ApifyEnvVars, ) -from apify_shared.utils import ignore_docs, is_content_type_json, is_content_type_text, is_content_type_xml, maybe_extract_enum_member_value +from apify_shared.utils import ( + ignore_docs, + is_content_type_json, + is_content_type_text, + is_content_type_xml, + maybe_extract_enum_member_value, +) -from .consts import REQUEST_ID_LENGTH, _StorageTypes +from .consts import REQUEST_ID_LENGTH, StorageTypes T = TypeVar('T') -def _get_system_info() -> Dict: +def get_system_info() -> dict: python_version = '.'.join([str(x) for x in sys.version_info[:3]]) - system_info: Dict[str, Union[str, bool]] = { + system_info: dict[str, str | bool] = { 'apify_sdk_version': metadata.version('apify'), 'apify_client_version': metadata.version('apify-client'), 'python_version': python_version, 'os': sys.platform, } - if _is_running_in_ipython(): + if is_running_in_ipython(): system_info['is_running_in_ipython'] = True return system_info @@ -72,7 +90,7 @@ class dualproperty(Generic[DualPropertyType]): # noqa: N801 and when accessing it on a class, it calls the getter with the class as the first argument. """ - def __init__(self, getter: Callable[..., DualPropertyType]) -> None: + def __init__(self: dualproperty, getter: Callable[..., DualPropertyType]) -> None: """Initialize the dualproperty. Args: @@ -81,75 +99,76 @@ def __init__(self, getter: Callable[..., DualPropertyType]) -> None: """ self.getter = getter - def __get__(self, obj: Optional[DualPropertyOwner], owner: Type[DualPropertyOwner]) -> DualPropertyType: + def __get__(self: dualproperty, obj: DualPropertyOwner | None, owner: type[DualPropertyOwner]) -> DualPropertyType: """Call the getter with the right object. Args: - obj (Optional[T]): The instance of class T on which the getter will be called - owner (Type[T]): The class object of class T on which the getter will be called, if obj is None + obj (T | None): The instance of class T on which the getter will be called + owner (type[T]): The class object of class T on which the getter will be called, if obj is None Returns: The result of the getter. """ - return self.getter(obj or owner) + val = self.getter(obj or owner) + return cast(DualPropertyType, val) @overload -def _fetch_and_parse_env_var(env_var: BOOL_ENV_VARS_TYPE) -> Optional[bool]: +def fetch_and_parse_env_var(env_var: BOOL_ENV_VARS_TYPE) -> bool | None: ... @overload -def _fetch_and_parse_env_var(env_var: BOOL_ENV_VARS_TYPE, default: bool) -> bool: +def fetch_and_parse_env_var(env_var: BOOL_ENV_VARS_TYPE, default: bool) -> bool: # noqa: FBT001 ... @overload -def _fetch_and_parse_env_var(env_var: DATETIME_ENV_VARS_TYPE) -> Optional[Union[datetime, str]]: +def fetch_and_parse_env_var(env_var: DATETIME_ENV_VARS_TYPE) -> datetime | str | None: ... 
@overload -def _fetch_and_parse_env_var(env_var: DATETIME_ENV_VARS_TYPE, default: datetime) -> Union[datetime, str]: +def fetch_and_parse_env_var(env_var: DATETIME_ENV_VARS_TYPE, default: datetime) -> datetime | str: ... @overload -def _fetch_and_parse_env_var(env_var: FLOAT_ENV_VARS_TYPE) -> Optional[float]: +def fetch_and_parse_env_var(env_var: FLOAT_ENV_VARS_TYPE) -> float | None: ... @overload -def _fetch_and_parse_env_var(env_var: FLOAT_ENV_VARS_TYPE, default: float) -> float: +def fetch_and_parse_env_var(env_var: FLOAT_ENV_VARS_TYPE, default: float) -> float: ... @overload -def _fetch_and_parse_env_var(env_var: INTEGER_ENV_VARS_TYPE) -> Optional[int]: +def fetch_and_parse_env_var(env_var: INTEGER_ENV_VARS_TYPE) -> int | None: ... @overload -def _fetch_and_parse_env_var(env_var: INTEGER_ENV_VARS_TYPE, default: int) -> int: +def fetch_and_parse_env_var(env_var: INTEGER_ENV_VARS_TYPE, default: int) -> int: ... @overload -def _fetch_and_parse_env_var(env_var: STRING_ENV_VARS_TYPE, default: str) -> str: +def fetch_and_parse_env_var(env_var: STRING_ENV_VARS_TYPE, default: str) -> str: ... @overload -def _fetch_and_parse_env_var(env_var: STRING_ENV_VARS_TYPE) -> Optional[str]: +def fetch_and_parse_env_var(env_var: STRING_ENV_VARS_TYPE) -> str | None: ... @overload -def _fetch_and_parse_env_var(env_var: Union[ActorEnvVars, ApifyEnvVars]) -> Optional[Any]: +def fetch_and_parse_env_var(env_var: ActorEnvVars | ApifyEnvVars) -> Any: ... -def _fetch_and_parse_env_var(env_var: Any, default: Any = None) -> Any: +def fetch_and_parse_env_var(env_var: Any, default: Any = None) -> Any: env_var_name = str(maybe_extract_enum_member_value(env_var)) val = os.getenv(env_var_name) @@ -157,27 +176,27 @@ def _fetch_and_parse_env_var(env_var: Any, default: Any = None) -> Any: return default if env_var in BOOL_ENV_VARS: - return _maybe_parse_bool(val) + return maybe_parse_bool(val) if env_var in FLOAT_ENV_VARS: - parsed_float = _maybe_parse_float(val) + parsed_float = maybe_parse_float(val) if parsed_float is None: return default return parsed_float if env_var in INTEGER_ENV_VARS: - parsed_int = _maybe_parse_int(val) + parsed_int = maybe_parse_int(val) if parsed_int is None: return default return parsed_int if env_var in DATETIME_ENV_VARS: - return _maybe_parse_datetime(val) + return maybe_parse_datetime(val) return val -def _get_cpu_usage_percent() -> float: +def get_cpu_usage_percent() -> float: return psutil.cpu_percent() -def _get_memory_usage_bytes() -> int: +def get_memory_usage_bytes() -> int: current_process = psutil.Process(os.getpid()) mem = int(current_process.memory_info().rss or 0) for child in current_process.children(recursive=True): @@ -186,34 +205,34 @@ def _get_memory_usage_bytes() -> int: return mem -def _maybe_parse_bool(val: Optional[str]) -> bool: - if val == 'true' or val == 'True' or val == '1': +def maybe_parse_bool(val: str | None) -> bool: + if val in {'true', 'True', '1'}: return True return False -def _maybe_parse_datetime(val: str) -> Union[datetime, str]: +def maybe_parse_datetime(val: str) -> datetime | str: try: return datetime.strptime(val, '%Y-%m-%dT%H:%M:%S.%fZ').replace(tzinfo=timezone.utc) except ValueError: return val -def _maybe_parse_float(val: str) -> Optional[float]: +def maybe_parse_float(val: str) -> float | None: try: return float(val) except ValueError: return None -def _maybe_parse_int(val: str) -> Optional[int]: +def maybe_parse_int(val: str) -> int | None: try: return int(val) except ValueError: return None -async def _run_func_at_interval_async(func: Callable, 
interval_secs: float) -> None: +async def run_func_at_interval_async(func: Callable, interval_secs: float) -> None: started_at = time.perf_counter() sleep_until = started_at while True: @@ -230,23 +249,23 @@ async def _run_func_at_interval_async(func: Callable, interval_secs: float) -> N await res -async def _force_remove(filename: str) -> None: +async def force_remove(filename: str) -> None: """JS-like rm(filename, { force: true }).""" with contextlib.suppress(FileNotFoundError): await remove(filename) -def _raise_on_non_existing_storage(client_type: _StorageTypes, id: str) -> NoReturn: +def raise_on_non_existing_storage(client_type: StorageTypes, id: str) -> NoReturn: # noqa: A002 client_type = maybe_extract_enum_member_value(client_type) raise ValueError(f'{client_type} with id "{id}" does not exist.') -def _raise_on_duplicate_storage(client_type: _StorageTypes, key_name: str, value: str) -> NoReturn: +def raise_on_duplicate_storage(client_type: StorageTypes, key_name: str, value: str) -> NoReturn: client_type = maybe_extract_enum_member_value(client_type) raise ValueError(f'{client_type} with {key_name} "{value}" already exists.') -def _guess_file_extension(content_type: str) -> Optional[str]: +def guess_file_extension(content_type: str) -> str | None: """Guess the file extension based on content type.""" # e.g. mimetypes.guess_extension('application/json ') does not work... actual_content_type = content_type.split(';')[0].strip() @@ -264,22 +283,22 @@ def _guess_file_extension(content_type: str) -> Optional[str]: return ext[1:] if ext is not None else ext -def _maybe_parse_body(body: bytes, content_type: str) -> Any: +def maybe_parse_body(body: bytes, content_type: str) -> Any: if is_content_type_json(content_type): return json.loads(body.decode('utf-8')) # Returns any - elif is_content_type_xml(content_type) or is_content_type_text(content_type): + if is_content_type_xml(content_type) or is_content_type_text(content_type): return body.decode('utf-8') return body -def _unique_key_to_request_id(unique_key: str) -> str: +def unique_key_to_request_id(unique_key: str) -> str: """Generate request ID based on unique key in a deterministic way.""" - id = re.sub(r'(\+|\/|=)', '', base64.b64encode(hashlib.sha256(unique_key.encode('utf-8')).digest()).decode('utf-8')) + id = re.sub(r'(\+|\/|=)', '', base64.b64encode(hashlib.sha256(unique_key.encode('utf-8')).digest()).decode('utf-8')) # noqa: A001 return id[:REQUEST_ID_LENGTH] if len(id) > REQUEST_ID_LENGTH else id -async def _force_rename(src_dir: str, dst_dir: str) -> None: +async def force_rename(src_dir: str, dst_dir: str) -> None: """Rename a directory. 
Checks for existence of source directory and removes destination directory if it exists.""" # Make sure source directory exists if await ospath.exists(src_dir): @@ -288,11 +307,12 @@ await aioshutil.rmtree(dst_dir, ignore_errors=True) await rename(src_dir, dst_dir) + ImplementationType = TypeVar('ImplementationType', bound=Callable) MetadataType = TypeVar('MetadataType', bound=Callable) -def _wrap_internal(implementation: ImplementationType, metadata_source: MetadataType) -> MetadataType: +def wrap_internal(implementation: ImplementationType, metadata_source: MetadataType) -> MetadataType: @functools.wraps(metadata_source) def wrapper(*args: Any, **kwargs: Any) -> Any: return implementation(*args, **kwargs) @@ -308,67 +328,68 @@ class LRUCache(MutableMapping, Generic[T]): _max_length: int - def __init__(self, max_length: int) -> None: + def __init__(self: LRUCache, max_length: int) -> None: """Create a LRUCache with a specific max_length.""" self._cache = OrderedDict() self._max_length = max_length - def __getitem__(self, key: str) -> T: + def __getitem__(self: LRUCache, key: str) -> T: """Get an item from the cache. Move it to the end if present.""" val = self._cache[key] # No 'key in cache' condition since the previous line would raise KeyError self._cache.move_to_end(key) - return val + return cast(T, val) # Sadly TS impl returns bool indicating whether the key was already present or not - def __setitem__(self, key: str, value: T) -> None: + def __setitem__(self: LRUCache, key: str, value: T) -> None: """Add an item to the cache. Remove least used item if max_length exceeded.""" self._cache[key] = value if len(self._cache) > self._max_length: self._cache.popitem(last=False) - def __delitem__(self, key: str) -> None: + def __delitem__(self: LRUCache, key: str) -> None: """Remove an item from the cache.""" del self._cache[key] - def __iter__(self) -> Iterator[str]: + def __iter__(self: LRUCache) -> Iterator[str]: """Iterate over the keys of the cache in order of insertion.""" return self._cache.__iter__() - def __len__(self) -> int: + def __len__(self: LRUCache) -> int: """Get the number of items in the cache.""" return len(self._cache) - def values(self) -> ValuesView[T]: # Needed so we don't mutate the cache by __getitem__ + def values(self: LRUCache) -> ValuesView[T]: # Needed so we don't mutate the cache by __getitem__ """Iterate over the values in the cache in order of insertion.""" return self._cache.values() - def items(self) -> ItemsView[str, T]: # Needed so we don't mutate the cache by __getitem__ + def items(self: LRUCache) -> ItemsView[str, T]: # Needed so we don't mutate the cache by __getitem__ """Iterate over the pairs of (key, value) in the cache in order of insertion.""" return self._cache.items() -def _is_running_in_ipython() -> bool: +def is_running_in_ipython() -> bool: return getattr(builtins, '__IPYTHON__', False) @overload -def _budget_ow(value: Union[str, int, float, bool], predicate: Tuple[Type, bool], value_name: str) -> None: +def budget_ow(value: str | float | bool, predicate: tuple[type, bool], value_name: str) -> None: ... @overload -def _budget_ow(value: Dict, predicate: Dict[str, Tuple[Type, bool]]) -> None: +def budget_ow(value: dict, predicate: dict[str, tuple[type, bool]]) -> None: ...
-def _budget_ow( - value: Union[Dict, str, int, float, bool], - predicate: Union[Dict[str, Tuple[Type, bool]], Tuple[Type, bool]], - value_name: Optional[str] = None, +def budget_ow( + value: dict | str | float | bool, + predicate: dict[str, tuple[type, bool]] | tuple[type, bool], + value_name: str | None = None, ) -> None: """Budget version of ow.""" - def validate_single(field_value: Any, expected_type: Type, required: bool, name: str) -> None: + + def validate_single(field_value: Any, expected_type: type, required: bool, name: str) -> None: # noqa: FBT001 if field_value is None and required: raise ValueError(f'"{name}" is required!') if (field_value is not None or required) and not isinstance(field_value, expected_type): @@ -389,4 +410,4 @@ def validate_single(field_value: Any, expected_type: Type, required: bool, name: PARSE_DATE_FIELDS_MAX_DEPTH = 3 PARSE_DATE_FIELDS_KEY_SUFFIX = 'At' -ListOrDictOrAny = TypeVar('ListOrDictOrAny', List, Dict, Any) +ListOrDictOrAny = TypeVar('ListOrDictOrAny', list, dict, Any) diff --git a/src/apify/actor.py b/src/apify/actor.py index 4dc727e1..1afd3956 100644 --- a/src/apify/actor.py +++ b/src/apify/actor.py @@ -1,28 +1,27 @@ +from __future__ import annotations + import asyncio import contextlib import inspect -import logging import os import sys from datetime import datetime, timezone -from types import TracebackType -from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, cast from apify_client import ApifyClientAsync from apify_shared.consts import ActorEnvVars, ActorEventTypes, ActorExitCodes, ApifyEnvVars, WebhookEventType from apify_shared.utils import ignore_docs, maybe_extract_enum_member_value -from ._crypto import _decrypt_input_secrets, _load_private_key -from ._memory_storage import MemoryStorageClient +from ._crypto import decrypt_input_secrets, load_private_key from ._utils import ( - _fetch_and_parse_env_var, - _get_cpu_usage_percent, - _get_memory_usage_bytes, - _get_system_info, - _is_running_in_ipython, - _run_func_at_interval_async, - _wrap_internal, dualproperty, + fetch_and_parse_env_var, + get_cpu_usage_percent, + get_memory_usage_bytes, + get_system_info, + is_running_in_ipython, + run_func_at_interval_async, + wrap_internal, ) from .config import Configuration from .consts import EVENT_LISTENERS_TIMEOUT_SECS @@ -31,6 +30,12 @@ from .proxy_configuration import ProxyConfiguration from .storages import Dataset, KeyValueStore, RequestQueue, StorageClientManager +if TYPE_CHECKING: + import logging + from types import TracebackType + + from ._memory_storage import MemoryStorageClient + T = TypeVar('T') MainReturnType = TypeVar('MainReturnType') @@ -40,15 +45,15 @@ class _ActorContextManager(type): @staticmethod - async def __aenter__() -> Type['Actor']: + async def __aenter__() -> type[Actor]: await Actor.init() return Actor @staticmethod async def __aexit__( - _exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - _exc_traceback: Optional[TracebackType], + _exc_type: type[BaseException] | None, + exc_value: BaseException | None, + _exc_traceback: TracebackType | None, ) -> None: if not Actor._get_default_instance()._is_exiting: if exc_value: @@ -63,17 +68,17 @@ async def __aexit__( class Actor(metaclass=_ActorContextManager): """The main class of the SDK, through which all the actor operations should be done.""" - _default_instance: Optional['Actor'] = None + _default_instance: Actor | None = None 
_apify_client: ApifyClientAsync _memory_storage_client: MemoryStorageClient _config: Configuration _event_manager: EventManager - _send_system_info_interval_task: Optional[asyncio.Task] = None - _send_persist_state_interval_task: Optional[asyncio.Task] = None + _send_system_info_interval_task: asyncio.Task | None = None + _send_persist_state_interval_task: asyncio.Task | None = None _is_exiting = False _was_final_persist_state_emitted = False - def __init__(self, config: Optional[Configuration] = None) -> None: + def __init__(self: Actor, config: Configuration | None = None) -> None: """Create an Actor instance. Note that you don't have to do this, all the methods on this class function as classmethods too, @@ -87,35 +92,35 @@ def __init__(self, config: Optional[Configuration] = None) -> None: # we need to have an `_xxx_internal` instance method which contains the actual implementation of the method, # and then in the instance constructor overwrite the `xxx` classmethod with the `_xxx_internal` instance method, # while copying the annotations, types and so on. - self.init = _wrap_internal(self._init_internal, self.init) # type: ignore - self.exit = _wrap_internal(self._exit_internal, self.exit) # type: ignore - self.fail = _wrap_internal(self._fail_internal, self.fail) # type: ignore - self.main = _wrap_internal(self._main_internal, self.main) # type: ignore - self.new_client = _wrap_internal(self._new_client_internal, self.new_client) # type: ignore - - self.open_dataset = _wrap_internal(self._open_dataset_internal, self.open_dataset) # type: ignore - self.open_key_value_store = _wrap_internal(self._open_key_value_store_internal, self.open_key_value_store) # type: ignore - self.open_request_queue = _wrap_internal(self._open_request_queue_internal, self.open_request_queue) # type: ignore - self.push_data = _wrap_internal(self._push_data_internal, self.push_data) # type: ignore - self.get_input = _wrap_internal(self._get_input_internal, self.get_input) # type: ignore - self.get_value = _wrap_internal(self._get_value_internal, self.get_value) # type: ignore - self.set_value = _wrap_internal(self._set_value_internal, self.set_value) # type: ignore - - self.on = _wrap_internal(self._on_internal, self.on) # type: ignore - self.off = _wrap_internal(self._off_internal, self.off) # type: ignore - - self.is_at_home = _wrap_internal(self._is_at_home_internal, self.is_at_home) # type: ignore - self.get_env = _wrap_internal(self._get_env_internal, self.get_env) # type: ignore - - self.start = _wrap_internal(self._start_internal, self.start) # type: ignore - self.call = _wrap_internal(self._call_internal, self.call) # type: ignore - self.call_task = _wrap_internal(self._call_task_internal, self.call_task) # type: ignore - self.abort = _wrap_internal(self._abort_internal, self.abort) # type: ignore - self.metamorph = _wrap_internal(self._metamorph_internal, self.metamorph) # type: ignore - self.reboot = _wrap_internal(self._reboot_internal, self.reboot) # type: ignore - self.add_webhook = _wrap_internal(self._add_webhook_internal, self.add_webhook) # type: ignore - self.set_status_message = _wrap_internal(self._set_status_message_internal, self.set_status_message) # type: ignore - self.create_proxy_configuration = _wrap_internal(self._create_proxy_configuration_internal, self.create_proxy_configuration) # type: ignore + self.init = wrap_internal(self._init_internal, self.init) # type: ignore + self.exit = wrap_internal(self._exit_internal, self.exit) # type: ignore + self.fail = 
wrap_internal(self._fail_internal, self.fail) # type: ignore + self.main = wrap_internal(self._main_internal, self.main) # type: ignore + self.new_client = wrap_internal(self._new_client_internal, self.new_client) # type: ignore + + self.open_dataset = wrap_internal(self._open_dataset_internal, self.open_dataset) # type: ignore + self.open_key_value_store = wrap_internal(self._open_key_value_store_internal, self.open_key_value_store) # type: ignore + self.open_request_queue = wrap_internal(self._open_request_queue_internal, self.open_request_queue) # type: ignore + self.push_data = wrap_internal(self._push_data_internal, self.push_data) # type: ignore + self.get_input = wrap_internal(self._get_input_internal, self.get_input) # type: ignore + self.get_value = wrap_internal(self._get_value_internal, self.get_value) # type: ignore + self.set_value = wrap_internal(self._set_value_internal, self.set_value) # type: ignore + + self.on = wrap_internal(self._on_internal, self.on) # type: ignore + self.off = wrap_internal(self._off_internal, self.off) # type: ignore + + self.is_at_home = wrap_internal(self._is_at_home_internal, self.is_at_home) # type: ignore + self.get_env = wrap_internal(self._get_env_internal, self.get_env) # type: ignore + + self.start = wrap_internal(self._start_internal, self.start) # type: ignore + self.call = wrap_internal(self._call_internal, self.call) # type: ignore + self.call_task = wrap_internal(self._call_task_internal, self.call_task) # type: ignore + self.abort = wrap_internal(self._abort_internal, self.abort) # type: ignore + self.metamorph = wrap_internal(self._metamorph_internal, self.metamorph) # type: ignore + self.reboot = wrap_internal(self._reboot_internal, self.reboot) # type: ignore + self.add_webhook = wrap_internal(self._add_webhook_internal, self.add_webhook) # type: ignore + self.set_status_message = wrap_internal(self._set_status_message_internal, self.set_status_message) # type: ignore + self.create_proxy_configuration = wrap_internal(self._create_proxy_configuration_internal, self.create_proxy_configuration) # type: ignore self._config: Configuration = config or Configuration() self._apify_client = self.new_client() @@ -124,7 +129,7 @@ def __init__(self, config: Optional[Configuration] = None) -> None: self._is_initialized = False @ignore_docs - async def __aenter__(self) -> 'Actor': + async def __aenter__(self: Actor) -> Actor: """Initialize the Actor. Automatically initializes the Actor instance when you use it in an `async with ...` statement. @@ -138,10 +143,10 @@ async def __aenter__(self) -> 'Actor': @ignore_docs async def __aexit__( - self, - _exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - _exc_traceback: Optional[TracebackType], + self: Actor, + _exc_type: type[BaseException] | None, + exc_value: BaseException | None, + _exc_traceback: TracebackType | None, ) -> None: """Exit the Actor, handling any exceptions properly. 
@@ -159,47 +164,45 @@ async def __aexit__( await self.exit() @classmethod - def _get_default_instance(cls) -> 'Actor': + def _get_default_instance(cls: type[Actor]) -> Actor: if not cls._default_instance: cls._default_instance = cls(config=Configuration.get_global_configuration()) return cls._default_instance @dualproperty - def apify_client(self_or_cls) -> ApifyClientAsync: # noqa: N805 - """The ApifyClientAsync instance the Actor instance uses.""" # noqa: D401 + def apify_client(self_or_cls: type[Actor] | Actor) -> ApifyClientAsync: # noqa: N805 + """The ApifyClientAsync instance the Actor instance uses.""" if isinstance(self_or_cls, type): return self_or_cls._get_default_instance()._apify_client - else: - return self_or_cls._apify_client + return self_or_cls._apify_client @dualproperty - def config(self_or_cls) -> Configuration: # noqa: N805 - """The Configuration instance the Actor instance uses.""" # noqa: D401 + def config(self_or_cls: type[Actor] | Actor) -> Configuration: # noqa: N805 + """The Configuration instance the Actor instance uses.""" if isinstance(self_or_cls, type): return self_or_cls._get_default_instance()._config - else: - return self_or_cls._config + return self_or_cls._config @dualproperty - def event_manager(self_or_cls) -> EventManager: # noqa: N805 - """The EventManager instance the Actor instance uses.""" # noqa: D401 + def event_manager(self_or_cls: type[Actor] | Actor) -> EventManager: # noqa: N805 + """The EventManager instance the Actor instance uses.""" if isinstance(self_or_cls, type): return self_or_cls._get_default_instance()._event_manager - else: - return self_or_cls._event_manager + + return self_or_cls._event_manager @dualproperty - def log(_self_or_cls) -> logging.Logger: # noqa: N805 - """The logging.Logger instance the Actor uses.""" # noqa: D401 + def log(_self_or_cls: type[Actor] | Actor) -> logging.Logger: # noqa: N805 + """The logging.Logger instance the Actor uses.""" return logger - def _raise_if_not_initialized(self) -> None: + def _raise_if_not_initialized(self: Actor) -> None: if not self._is_initialized: raise RuntimeError('The actor was not initialized!') @classmethod - async def init(cls) -> None: + async def init(cls: type[Actor]) -> None: """Initialize the actor instance. This initializes the Actor instance. 
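For illustration, a minimal sketch of how `Actor.init()` is usually triggered in practice, via the `async with Actor:` form handled by the `_ActorContextManager` metaclass above; the actor body is illustrative:

    import asyncio

    from apify import Actor

    async def main() -> None:
        # Entering the context manager calls Actor.init() under the hood;
        # leaving it calls Actor.exit(), or Actor.fail() on an unhandled exception.
        async with Actor:
            Actor.log.info('Actor initialized')

    asyncio.run(main())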
@@ -213,7 +216,7 @@ async def init(cls) -> None:
         """
         return await cls._get_default_instance().init()
 
-    async def _init_internal(self) -> None:
+    async def _init_internal(self: Actor) -> None:
         if self._is_initialized:
             raise RuntimeError('The actor was already initialized!')
 
@@ -221,9 +224,10 @@ async def _init_internal(self) -> None:
         self._was_final_persist_state_emitted = False
 
         self.log.info('Initializing actor...')
-        self.log.info('System info', extra=_get_system_info())
+        self.log.info('System info', extra=get_system_info())
 
         # TODO: Print outdated SDK version warning (we need a new env var for this)
+        # https://github.com/apify/apify-sdk-python/issues/146
 
         StorageClientManager.set_config(self._config)
         if self._config.token:
@@ -232,7 +236,7 @@ async def _init_internal(self) -> None:
         await self._event_manager.init()
 
         self._send_persist_state_interval_task = asyncio.create_task(
-            _run_func_at_interval_async(
+            run_func_at_interval_async(
                 lambda: self._event_manager.emit(ActorEventTypes.PERSIST_STATE, {'isMigrating': False}),
                 self._config.persist_state_interval_millis / 1000,
             ),
@@ -240,8 +244,8 @@ async def _init_internal(self) -> None:
 
         if not self.is_at_home():
             self._send_system_info_interval_task = asyncio.create_task(
-                _run_func_at_interval_async(
-                    lambda: self._event_manager.emit(ActorEventTypes.SYSTEM_INFO, self._get_system_info()),
+                run_func_at_interval_async(
+                    lambda: self._event_manager.emit(ActorEventTypes.SYSTEM_INFO, self.get_system_info()),
                     self._config.system_info_interval_millis / 1000,
                 ),
             )
@@ -250,13 +254,14 @@ async def _init_internal(self) -> None:
 
         # The CPU usage is calculated as an average between two last calls to psutil
         # We need to make a first, dummy call, so the next calls have something to compare themselves against
-        _get_cpu_usage_percent()
+        get_cpu_usage_percent()
 
         self._is_initialized = True
 
-    def _get_system_info(self) -> Dict:
-        cpu_usage_percent = _get_cpu_usage_percent()
-        memory_usage_bytes = _get_memory_usage_bytes()
+    def get_system_info(self: Actor) -> dict:
+        """Get the current system info."""
+        cpu_usage_percent = get_cpu_usage_percent()
+        memory_usage_bytes = get_memory_usage_bytes()
         # This is in camel case to be compatible with the events from the platform
         result = {
             'createdAt': datetime.now(timezone.utc),
@@ -264,11 +269,11 @@ def _get_system_info(self) -> Dict:
             'memCurrentBytes': memory_usage_bytes,
         }
         if self._config.max_used_cpu_ratio:
-            result['isCpuOverloaded'] = (cpu_usage_percent > 100 * self._config.max_used_cpu_ratio)
+            result['isCpuOverloaded'] = cpu_usage_percent > 100 * self._config.max_used_cpu_ratio
 
         return result
 
-    async def _respond_to_migrating_event(self, _event_data: Any) -> None:
+    async def _respond_to_migrating_event(self: Actor, _event_data: Any) -> None:
         # Don't emit any more regular persist state events
         if self._send_persist_state_interval_task and not self._send_persist_state_interval_task.cancelled():
             self._send_persist_state_interval_task.cancel()
@@ -278,7 +283,7 @@ async def _respond_to_migrating_event(self, _event_data: Any) -> None:
         self._event_manager.emit(ActorEventTypes.PERSIST_STATE, {'isMigrating': True})
         self._was_final_persist_state_emitted = True
 
-    async def _cancel_event_emitting_intervals(self) -> None:
+    async def _cancel_event_emitting_intervals(self: Actor) -> None:
         if self._send_persist_state_interval_task and not self._send_persist_state_interval_task.cancelled():
             self._send_persist_state_interval_task.cancel()
             with contextlib.suppress(asyncio.CancelledError):
@@ -290,12 +295,12 @@ async def 
_cancel_event_emitting_intervals(self) -> None:
                 await self._send_system_info_interval_task
 
     @classmethod
-    async def exit(
-        cls,
+    async def exit(  # noqa: A003
+        cls: type[Actor],
         *,
         exit_code: int = 0,
-        event_listeners_timeout_secs: Optional[float] = EVENT_LISTENERS_TIMEOUT_SECS,
-        status_message: Optional[str] = None,
+        event_listeners_timeout_secs: float | None = EVENT_LISTENERS_TIMEOUT_SECS,
+        status_message: str | None = None,
     ) -> None:
         """Exit the actor instance.
 
@@ -317,11 +322,11 @@ async def exit(
         )
 
     async def _exit_internal(
-        self,
+        self: Actor,
         *,
         exit_code: int = 0,
-        event_listeners_timeout_secs: Optional[float] = EVENT_LISTENERS_TIMEOUT_SECS,
-        status_message: Optional[str] = None,
+        event_listeners_timeout_secs: float | None = EVENT_LISTENERS_TIMEOUT_SECS,
+        status_message: str | None = None,
     ) -> None:
         self._raise_if_not_initialized()
 
@@ -348,9 +353,9 @@ async def _exit_internal(
 
         self._is_initialized = False
 
-        if _is_running_in_ipython():
+        if is_running_in_ipython():
             self.log.debug(f'Not calling sys.exit({exit_code}) because actor is running in IPython')
-        elif os.getenv('PYTEST_CURRENT_TEST', False):
+        elif os.getenv('PYTEST_CURRENT_TEST', default=False):  # noqa: PLW1508
             self.log.debug(f'Not calling sys.exit({exit_code}) because actor is running in a unit test')
         elif hasattr(asyncio, '_nest_patched'):
             self.log.debug(f'Not calling sys.exit({exit_code}) because actor is running in a nested event loop')
@@ -359,11 +364,11 @@
 
     @classmethod
     async def fail(
-        cls,
+        cls: type[Actor],
         *,
         exit_code: int = 1,
-        exception: Optional[BaseException] = None,
-        status_message: Optional[str] = None,
+        exception: BaseException | None = None,
+        status_message: str | None = None,
     ) -> None:
         """Fail the actor instance.
 
@@ -382,23 +387,23 @@ async def fail(
         )
 
     async def _fail_internal(
-        self,
+        self: Actor,
         *,
         exit_code: int = 1,
-        exception: Optional[BaseException] = None,
-        status_message: Optional[str] = None,
+        exception: BaseException | None = None,
+        status_message: str | None = None,
     ) -> None:
         self._raise_if_not_initialized()
 
         # In IPython, we don't run `sys.exit()` during actor exits,
         # so the exception traceback will be printed on its own
-        if exception and not _is_running_in_ipython():
+        if exception and not is_running_in_ipython():
             self.log.exception('Actor failed with an exception', exc_info=exception)
 
         await self.exit(exit_code=exit_code, status_message=status_message)
 
     @classmethod
-    async def main(cls, main_actor_function: Callable[[], MainReturnType]) -> Optional[MainReturnType]:
+    async def main(cls: type[Actor], main_actor_function: Callable[[], MainReturnType]) -> MainReturnType | None:
         """Initialize the actor, run the passed function and finish the actor cleanly.
 
         **The `Actor.main()` function is optional** and is provided merely for your convenience.
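A minimal sketch of the `Actor.main()` convenience wrapper documented above; the function name `run` and its body are illustrative:

    import asyncio

    from apify import Actor

    def run() -> None:
        # Actor.main() wraps this call with init(), and with exit() / fail()
        # depending on whether the function returns or raises.
        Actor.log.info('Running inside Actor.main()')

    asyncio.run(Actor.main(run))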
@@ -423,7 +428,7 @@ async def main(cls, main_actor_function: Callable[[], MainReturnType]) -> Option main_actor_function=main_actor_function, ) - async def _main_internal(self, main_actor_function: Callable[[], MainReturnType]) -> Optional[MainReturnType]: + async def _main_internal(self: Actor, main_actor_function: Callable[[], MainReturnType]) -> MainReturnType | None: if not inspect.isfunction(main_actor_function): raise TypeError(f'First argument passed to Actor.main() must be a function, but instead it was {type(main_actor_function)}') @@ -435,22 +440,22 @@ async def _main_internal(self, main_actor_function: Callable[[], MainReturnType] res = main_actor_function() await self.exit() return cast(MainReturnType, res) - except Exception as e: + except Exception as exc: await self.fail( exit_code=ActorExitCodes.ERROR_USER_FUNCTION_THREW.value, - exception=e, + exception=exc, ) return None @classmethod def new_client( - cls, + cls: type[Actor], *, - token: Optional[str] = None, - api_url: Optional[str] = None, - max_retries: Optional[int] = None, - min_delay_between_retries_millis: Optional[int] = None, - timeout_secs: Optional[int] = None, + token: str | None = None, + api_url: str | None = None, + max_retries: int | None = None, + min_delay_between_retries_millis: int | None = None, + timeout_secs: int | None = None, ) -> ApifyClientAsync: """Return a new instance of the Apify API client. @@ -477,13 +482,13 @@ def new_client( ) def _new_client_internal( - self, + self: Actor, *, - token: Optional[str] = None, - api_url: Optional[str] = None, - max_retries: Optional[int] = None, - min_delay_between_retries_millis: Optional[int] = None, - timeout_secs: Optional[int] = None, + token: str | None = None, + api_url: str | None = None, + max_retries: int | None = None, + min_delay_between_retries_millis: int | None = None, + timeout_secs: int | None = None, ) -> ApifyClientAsync: token = token or self._config.token api_url = api_url or self._config.api_base_url @@ -495,11 +500,17 @@ def _new_client_internal( timeout_secs=timeout_secs, ) - def _get_storage_client(self, force_cloud: bool) -> Optional[ApifyClientAsync]: + def _get_storage_client(self: Actor, force_cloud: bool) -> ApifyClientAsync | None: # noqa: FBT001 return self._apify_client if force_cloud else None @classmethod - async def open_dataset(cls, *, id: Optional[str] = None, name: Optional[str] = None, force_cloud: bool = False) -> Dataset: + async def open_dataset( + cls: type[Actor], + *, + id: str | None = None, # noqa: A002 + name: str | None = None, + force_cloud: bool = False, + ) -> Dataset: """Open a dataset. 
Datasets are used to store structured data where each object stored has the same attributes, @@ -520,13 +531,25 @@ async def open_dataset(cls, *, id: Optional[str] = None, name: Optional[str] = N """ return await cls._get_default_instance().open_dataset(id=id, name=name, force_cloud=force_cloud) - async def _open_dataset_internal(self, *, id: Optional[str] = None, name: Optional[str] = None, force_cloud: bool = False) -> Dataset: + async def _open_dataset_internal( + self: Actor, + *, + id: str | None = None, # noqa: A002 + name: str | None = None, + force_cloud: bool = False, + ) -> Dataset: self._raise_if_not_initialized() return await Dataset.open(id=id, name=name, force_cloud=force_cloud, config=self._config) @classmethod - async def open_key_value_store(cls, *, id: Optional[str] = None, name: Optional[str] = None, force_cloud: bool = False) -> KeyValueStore: + async def open_key_value_store( + cls: type[Actor], + *, + id: str | None = None, # noqa: A002 + name: str | None = None, + force_cloud: bool = False, + ) -> KeyValueStore: """Open a key-value store. Key-value stores are used to store records or files, along with their MIME content type. @@ -547,10 +570,10 @@ async def open_key_value_store(cls, *, id: Optional[str] = None, name: Optional[ return await cls._get_default_instance().open_key_value_store(id=id, name=name, force_cloud=force_cloud) async def _open_key_value_store_internal( - self, + self: Actor, *, - id: Optional[str] = None, - name: Optional[str] = None, + id: str | None = None, # noqa: A002 + name: str | None = None, force_cloud: bool = False, ) -> KeyValueStore: self._raise_if_not_initialized() @@ -558,7 +581,13 @@ async def _open_key_value_store_internal( return await KeyValueStore.open(id=id, name=name, force_cloud=force_cloud, config=self._config) @classmethod - async def open_request_queue(cls, *, id: Optional[str] = None, name: Optional[str] = None, force_cloud: bool = False) -> RequestQueue: + async def open_request_queue( + cls: type[Actor], + *, + id: str | None = None, # noqa: A002 + name: str | None = None, + force_cloud: bool = False, + ) -> RequestQueue: """Open a request queue. Request queue represents a queue of URLs to crawl, which is stored either on local filesystem or in the Apify cloud. @@ -580,10 +609,10 @@ async def open_request_queue(cls, *, id: Optional[str] = None, name: Optional[st return await cls._get_default_instance().open_request_queue(id=id, name=name, force_cloud=force_cloud) async def _open_request_queue_internal( - self, + self: Actor, *, - id: Optional[str] = None, - name: Optional[str] = None, + id: str | None = None, # noqa: A002 + name: str | None = None, force_cloud: bool = False, ) -> RequestQueue: self._raise_if_not_initialized() @@ -591,7 +620,7 @@ async def _open_request_queue_internal( return await RequestQueue.open(id=id, name=name, force_cloud=force_cloud, config=self._config) @classmethod - async def push_data(cls, data: Any) -> None: + async def push_data(cls: type[Actor], data: Any) -> None: """Store an object or a list of objects to the default dataset of the current actor run. 
Args: @@ -599,7 +628,7 @@ async def push_data(cls, data: Any) -> None: """ return await cls._get_default_instance().push_data(data=data) - async def _push_data_internal(self, data: Any) -> None: + async def _push_data_internal(self: Actor, data: Any) -> None: self._raise_if_not_initialized() if not data: @@ -609,27 +638,27 @@ async def _push_data_internal(self, data: Any) -> None: await dataset.push_data(data) @classmethod - async def get_input(cls) -> Any: + async def get_input(cls: type[Actor]) -> Any: """Get the actor input value from the default key-value store associated with the current actor run.""" return await cls._get_default_instance().get_input() - async def _get_input_internal(self) -> Any: + async def _get_input_internal(self: Actor) -> Any: self._raise_if_not_initialized() input_value = await self.get_value(self._config.input_key) input_secrets_private_key = self._config.input_secrets_private_key_file input_secrets_key_passphrase = self._config.input_secrets_private_key_passphrase if input_secrets_private_key and input_secrets_key_passphrase: - private_key = _load_private_key( + private_key = load_private_key( input_secrets_private_key, input_secrets_key_passphrase, ) - input_value = _decrypt_input_secrets(private_key, input_value) + input_value = decrypt_input_secrets(private_key, input_value) return input_value @classmethod - async def get_value(cls, key: str, default_value: Optional[T] = None) -> Any: + async def get_value(cls: type[Actor], key: str, default_value: Any = None) -> Any: """Get a value from the default key-value store associated with the current actor run. Args: @@ -638,20 +667,19 @@ async def get_value(cls, key: str, default_value: Optional[T] = None) -> Any: """ return await cls._get_default_instance().get_value(key=key, default_value=default_value) - async def _get_value_internal(self, key: str, default_value: Optional[T] = None) -> Any: + async def _get_value_internal(self: Actor, key: str, default_value: Any = None) -> Any: self._raise_if_not_initialized() key_value_store = await self.open_key_value_store() - value = await key_value_store.get_value(key, default_value) - return value + return await key_value_store.get_value(key, default_value) @classmethod async def set_value( - cls, + cls: type[Actor], key: str, value: Any, *, - content_type: Optional[str] = None, + content_type: str | None = None, ) -> None: """Set or delete a value in the default key-value store associated with the current actor run. @@ -667,11 +695,11 @@ async def set_value( ) async def _set_value_internal( - self, + self: Actor, key: str, value: Any, *, - content_type: Optional[str] = None, + content_type: str | None = None, ) -> None: self._raise_if_not_initialized() @@ -679,7 +707,7 @@ async def _set_value_internal( return await key_value_store.set_value(key, value, content_type=content_type) @classmethod - def on(cls, event_name: ActorEventTypes, listener: Callable) -> Callable: + def on(cls: type[Actor], event_name: ActorEventTypes, listener: Callable) -> Callable: """Add an event listener to the actor's event manager. 
The following events can be emitted: @@ -707,13 +735,13 @@ def on(cls, event_name: ActorEventTypes, listener: Callable) -> Callable: """ return cls._get_default_instance().on(event_name, listener) - def _on_internal(self, event_name: ActorEventTypes, listener: Callable) -> Callable: + def _on_internal(self: Actor, event_name: ActorEventTypes, listener: Callable) -> Callable: self._raise_if_not_initialized() return self._event_manager.on(event_name, listener) @classmethod - def off(cls, event_name: ActorEventTypes, listener: Optional[Callable] = None) -> None: + def off(cls: type[Actor], event_name: ActorEventTypes, listener: Callable | None = None) -> None: """Remove a listener, or all listeners, from an actor event. Args: @@ -722,21 +750,21 @@ def off(cls, event_name: ActorEventTypes, listener: Optional[Callable] = None) - """ return cls._get_default_instance().off(event_name, listener) - def _off_internal(self, event_name: ActorEventTypes, listener: Optional[Callable] = None) -> None: + def _off_internal(self: Actor, event_name: ActorEventTypes, listener: Callable | None = None) -> None: self._raise_if_not_initialized() return self._event_manager.off(event_name, listener) @classmethod - def is_at_home(cls) -> bool: + def is_at_home(cls: type[Actor]) -> bool: """Return `True` when the actor is running on the Apify platform, and `False` otherwise (for example when running locally).""" return cls._get_default_instance().is_at_home() - def _is_at_home_internal(self) -> bool: + def _is_at_home_internal(self: Actor) -> bool: return self._config.is_at_home @classmethod - def get_env(cls) -> Dict: + def get_env(cls: type[Actor]) -> dict: """Return a dictionary with information parsed from all the `APIFY_XXX` environment variables. For a list of all the environment variables, @@ -745,27 +773,25 @@ def get_env(cls) -> Dict: """ return cls._get_default_instance().get_env() - def _get_env_internal(self) -> Dict: + def _get_env_internal(self: Actor) -> dict: self._raise_if_not_initialized() - return { - env_var.name.lower(): _fetch_and_parse_env_var(env_var) for env_var in [*ActorEnvVars, *ApifyEnvVars] - } + return {env_var.name.lower(): fetch_and_parse_env_var(env_var) for env_var in [*ActorEnvVars, *ApifyEnvVars]} @classmethod async def start( - cls, + cls: type[Actor], actor_id: str, - run_input: Optional[Any] = None, + run_input: Any = None, *, - token: Optional[str] = None, - content_type: Optional[str] = None, - build: Optional[str] = None, - memory_mbytes: Optional[int] = None, - timeout_secs: Optional[int] = None, - wait_for_finish: Optional[int] = None, - webhooks: Optional[List[Dict]] = None, - ) -> Dict: + token: str | None = None, + content_type: str | None = None, + build: str | None = None, + memory_mbytes: int | None = None, + timeout_secs: int | None = None, + wait_for_finish: int | None = None, + webhooks: list[dict] | None = None, + ) -> dict: """Run an actor on the Apify platform. Unlike `Actor.call`, this method just starts the run without waiting for finish. 
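A sketch of `Actor.start()` as documented above; the actor ID, input shape and memory limit are hypothetical, and an Apify API token (the APIFY_TOKEN environment variable) is assumed to be configured:

    from apify import Actor

    async def main() -> None:
        async with Actor:
            # Starts the other run on the platform and returns immediately
            # with the run object describing it.
            run_info = await Actor.start(
                'apify/hello-world',             # hypothetical actor ID
                run_input={'message': 'hello'},  # hypothetical input
                memory_mbytes=256,
            )
            Actor.log.info('Started run', extra={'run': run_info})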
@@ -808,24 +834,21 @@ async def start( ) async def _start_internal( - self, + self: Actor, actor_id: str, - run_input: Optional[Any] = None, + run_input: Any = None, *, - token: Optional[str] = None, - content_type: Optional[str] = None, - build: Optional[str] = None, - memory_mbytes: Optional[int] = None, - timeout_secs: Optional[int] = None, - wait_for_finish: Optional[int] = None, - webhooks: Optional[List[Dict]] = None, - ) -> Dict: + token: str | None = None, + content_type: str | None = None, + build: str | None = None, + memory_mbytes: int | None = None, + timeout_secs: int | None = None, + wait_for_finish: int | None = None, + webhooks: list[dict] | None = None, + ) -> dict: self._raise_if_not_initialized() - if token: - client = self.new_client(token=token) - else: - client = self._apify_client + client = self.new_client(token=token) if token else self._apify_client return await client.actor(actor_id).start( run_input=run_input, @@ -839,12 +862,12 @@ async def _start_internal( @classmethod async def abort( - cls, + cls: type[Actor], run_id: str, *, - token: Optional[str] = None, - gracefully: Optional[bool] = None, - ) -> Dict: + token: str | None = None, + gracefully: bool | None = None, + ) -> dict: """Abort given actor run on the Apify platform using the current user account (determined by the `APIFY_TOKEN` environment variable). Args: @@ -864,19 +887,16 @@ async def abort( ) async def _abort_internal( - self, + self: Actor, run_id: str, *, - token: Optional[str] = None, - status_message: Optional[str] = None, - gracefully: Optional[bool] = None, - ) -> Dict: + token: str | None = None, + status_message: str | None = None, + gracefully: bool | None = None, + ) -> dict: self._raise_if_not_initialized() - if token: - client = self.new_client(token=token) - else: - client = self._apify_client + client = self.new_client(token=token) if token else self._apify_client if status_message: await client.run(run_id).update(status_message=status_message) @@ -885,18 +905,18 @@ async def _abort_internal( @classmethod async def call( - cls, + cls: type[Actor], actor_id: str, - run_input: Optional[Any] = None, + run_input: Any = None, *, - token: Optional[str] = None, - content_type: Optional[str] = None, - build: Optional[str] = None, - memory_mbytes: Optional[int] = None, - timeout_secs: Optional[int] = None, - webhooks: Optional[List[Dict]] = None, - wait_secs: Optional[int] = None, - ) -> Optional[Dict]: + token: str | None = None, + content_type: str | None = None, + build: str | None = None, + memory_mbytes: int | None = None, + timeout_secs: int | None = None, + webhooks: list[dict] | None = None, + wait_secs: int | None = None, + ) -> dict | None: """Start an actor on the Apify Platform and wait for it to finish before returning. It waits indefinitely, unless the wait_secs argument is provided. 
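For contrast with `Actor.start()`, a sketch of `Actor.call()`, which waits for the other run to finish; the actor ID and input are again hypothetical:

    from apify import Actor

    async def main() -> None:
        async with Actor:
            # Blocks until the called run finishes, or until wait_secs elapses;
            # the SDK types the result as dict | None.
            run = await Actor.call(
                'apify/hello-world',             # hypothetical actor ID
                run_input={'message': 'hello'},
                wait_secs=60,                    # stop waiting after a minute
            )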
@@ -933,24 +953,21 @@ async def call( ) async def _call_internal( - self, + self: Actor, actor_id: str, - run_input: Optional[Any] = None, + run_input: Any = None, *, - token: Optional[str] = None, - content_type: Optional[str] = None, - build: Optional[str] = None, - memory_mbytes: Optional[int] = None, - timeout_secs: Optional[int] = None, - webhooks: Optional[List[Dict]] = None, - wait_secs: Optional[int] = None, - ) -> Optional[Dict]: + token: str | None = None, + content_type: str | None = None, + build: str | None = None, + memory_mbytes: int | None = None, + timeout_secs: int | None = None, + webhooks: list[dict] | None = None, + wait_secs: int | None = None, + ) -> dict | None: self._raise_if_not_initialized() - if token: - client = self.new_client(token=token) - else: - client = self._apify_client + client = self.new_client(token=token) if token else self._apify_client return await client.actor(actor_id).call( run_input=run_input, @@ -964,17 +981,17 @@ async def _call_internal( @classmethod async def call_task( - cls, + cls: type[Actor], task_id: str, - task_input: Optional[Dict[str, Any]] = None, + task_input: dict | None = None, *, - build: Optional[str] = None, - memory_mbytes: Optional[int] = None, - timeout_secs: Optional[int] = None, - webhooks: Optional[List[Dict]] = None, - wait_secs: Optional[int] = None, - token: Optional[str] = None, - ) -> Optional[Dict]: + build: str | None = None, + memory_mbytes: int | None = None, + timeout_secs: int | None = None, + webhooks: list[dict] | None = None, + wait_secs: int | None = None, + token: str | None = None, + ) -> dict | None: """Start an actor task on the Apify Platform and wait for it to finish before returning. It waits indefinitely, unless the wait_secs argument is provided. @@ -1013,23 +1030,20 @@ async def call_task( ) async def _call_task_internal( - self, + self: Actor, task_id: str, - task_input: Optional[Dict[str, Any]] = None, + task_input: dict | None = None, *, - build: Optional[str] = None, - memory_mbytes: Optional[int] = None, - timeout_secs: Optional[int] = None, - webhooks: Optional[List[Dict]] = None, - wait_secs: Optional[int] = None, - token: Optional[str] = None, - ) -> Optional[Dict]: + build: str | None = None, + memory_mbytes: int | None = None, + timeout_secs: int | None = None, + webhooks: list[dict] | None = None, + wait_secs: int | None = None, + token: str | None = None, + ) -> dict | None: self._raise_if_not_initialized() - if token: - client = self.new_client(token=token) - else: - client = self._apify_client + client = self.new_client(token=token) if token else self._apify_client return await client.task(task_id).call( task_input=task_input, @@ -1042,13 +1056,13 @@ async def _call_task_internal( @classmethod async def metamorph( - cls, + cls: type[Actor], target_actor_id: str, - run_input: Optional[Any] = None, + run_input: Any = None, *, - target_actor_build: Optional[str] = None, - content_type: Optional[str] = None, - custom_after_sleep_millis: Optional[int] = None, + target_actor_build: str | None = None, + content_type: str | None = None, + custom_after_sleep_millis: int | None = None, ) -> None: """Transform this actor run to an actor run of a different actor. 
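A sketch of the metamorph flow described above; the target actor and its input are hypothetical, and the operation is only meaningful when running on the Apify platform:

    from apify import Actor

    async def main() -> None:
        async with Actor:
            # Replaces this run with a run of the target actor,
            # which takes over the same run and its default storages.
            await Actor.metamorph(
                'apify/web-scraper',  # hypothetical target actor
                run_input={'startUrls': [{'url': 'https://example.com'}]},
            )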
@@ -1076,13 +1090,13 @@ async def metamorph( ) async def _metamorph_internal( - self, + self: Actor, target_actor_id: str, - run_input: Optional[Any] = None, + run_input: Any = None, *, - target_actor_build: Optional[str] = None, - content_type: Optional[str] = None, - custom_after_sleep_millis: Optional[int] = None, + target_actor_build: str | None = None, + content_type: str | None = None, + custom_after_sleep_millis: int | None = None, ) -> None: self._raise_if_not_initialized() @@ -1094,7 +1108,7 @@ async def _metamorph_internal( custom_after_sleep_millis = self._config.metamorph_after_sleep_millis # If is_at_home() is True, config.actor_run_id is always set - assert self._config.actor_run_id is not None + assert self._config.actor_run_id is not None # noqa: S101 await self._apify_client.run(self._config.actor_run_id).metamorph( target_actor_id=target_actor_id, @@ -1108,10 +1122,10 @@ async def _metamorph_internal( @classmethod async def reboot( - cls, + cls: type[Actor], *, - event_listeners_timeout_secs: Optional[int] = EVENT_LISTENERS_TIMEOUT_SECS, - custom_after_sleep_millis: Optional[int] = None, + event_listeners_timeout_secs: int | None = EVENT_LISTENERS_TIMEOUT_SECS, + custom_after_sleep_millis: int | None = None, ) -> None: """Internally reboot this actor. @@ -1127,10 +1141,10 @@ async def reboot( ) async def _reboot_internal( - self, + self: Actor, *, - event_listeners_timeout_secs: Optional[int] = EVENT_LISTENERS_TIMEOUT_SECS, - custom_after_sleep_millis: Optional[int] = None, + event_listeners_timeout_secs: int | None = EVENT_LISTENERS_TIMEOUT_SECS, + custom_after_sleep_millis: int | None = None, ) -> None: self._raise_if_not_initialized() @@ -1148,7 +1162,7 @@ async def _reboot_internal( await self._event_manager.close(event_listeners_timeout_secs=event_listeners_timeout_secs) - assert self._config.actor_run_id is not None + assert self._config.actor_run_id is not None # noqa: S101 await self._apify_client.run(self._config.actor_run_id).reboot() if custom_after_sleep_millis: @@ -1156,15 +1170,15 @@ async def _reboot_internal( @classmethod async def add_webhook( - cls, + cls: type[Actor], *, - event_types: List[WebhookEventType], + event_types: list[WebhookEventType], request_url: str, - payload_template: Optional[str] = None, - ignore_ssl_errors: Optional[bool] = None, - do_not_retry: Optional[bool] = None, - idempotency_key: Optional[str] = None, - ) -> Dict: + payload_template: str | None = None, + ignore_ssl_errors: bool | None = None, + do_not_retry: bool | None = None, + idempotency_key: str | None = None, + ) -> dict: """Create an ad-hoc webhook for the current actor run. This webhook lets you receive a notification when the actor run finished or failed. 
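A sketch of registering the ad-hoc webhook described above; the endpoint URL is a placeholder, and `WebhookEventType` is the enum imported from `apify_shared.consts` at the top of this module:

    from apify import Actor
    from apify_shared.consts import WebhookEventType

    async def main() -> None:
        async with Actor:
            # Skipped outside the Apify platform (see the is_at_home() guard below).
            await Actor.add_webhook(
                event_types=[WebhookEventType.ACTOR_RUN_SUCCEEDED],
                request_url='https://example.com/run-finished',  # placeholder endpoint
            )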
@@ -1197,15 +1211,15 @@ async def add_webhook( ) async def _add_webhook_internal( - self, + self: Actor, *, - event_types: List[WebhookEventType], + event_types: list[WebhookEventType], request_url: str, - payload_template: Optional[str] = None, - ignore_ssl_errors: Optional[bool] = None, - do_not_retry: Optional[bool] = None, - idempotency_key: Optional[str] = None, - ) -> Optional[Dict]: + payload_template: str | None = None, + ignore_ssl_errors: bool | None = None, + do_not_retry: bool | None = None, + idempotency_key: str | None = None, + ) -> dict | None: self._raise_if_not_initialized() if not self.is_at_home(): @@ -1213,7 +1227,7 @@ async def _add_webhook_internal( return None # If is_at_home() is True, config.actor_run_id is always set - assert self._config.actor_run_id is not None + assert self._config.actor_run_id is not None # noqa: S101 return await self._apify_client.webhooks().create( actor_run_id=self._config.actor_run_id, @@ -1226,7 +1240,12 @@ async def _add_webhook_internal( ) @classmethod - async def set_status_message(cls, status_message: str, *, is_terminal: Optional[bool] = None) -> Optional[Dict]: + async def set_status_message( + cls: type[Actor], + status_message: str, + *, + is_terminal: bool | None = None, + ) -> dict | None: """Set the status message for the current actor run. Args: @@ -1238,7 +1257,12 @@ async def set_status_message(cls, status_message: str, *, is_terminal: Optional[ """ return await cls._get_default_instance().set_status_message(status_message=status_message, is_terminal=is_terminal) - async def _set_status_message_internal(self, status_message: str, *, is_terminal: Optional[bool] = None) -> Optional[Dict]: + async def _set_status_message_internal( + self: Actor, + status_message: str, + *, + is_terminal: bool | None = None, + ) -> dict | None: self._raise_if_not_initialized() if not self.is_at_home(): @@ -1247,21 +1271,21 @@ async def _set_status_message_internal(self, status_message: str, *, is_terminal return None # If is_at_home() is True, config.actor_run_id is always set - assert self._config.actor_run_id is not None + assert self._config.actor_run_id is not None # noqa: S101 return await self._apify_client.run(self._config.actor_run_id).update(status_message=status_message, is_status_message_terminal=is_terminal) @classmethod async def create_proxy_configuration( - cls, + cls: type[Actor], *, - actor_proxy_input: Optional[Dict] = None, # this is the raw proxy input from the actor run input, it is not spread or snake_cased in here - password: Optional[str] = None, - groups: Optional[List[str]] = None, - country_code: Optional[str] = None, - proxy_urls: Optional[List[str]] = None, - new_url_function: Optional[Union[Callable[[Optional[str]], str], Callable[[Optional[str]], Awaitable[str]]]] = None, - ) -> Optional[ProxyConfiguration]: + actor_proxy_input: dict | None = None, # this is the raw proxy input from the actor run input, it is not spread or snake_cased in here + password: str | None = None, + groups: list[str] | None = None, + country_code: str | None = None, + proxy_urls: list[str] | None = None, + new_url_function: Callable[[str | None], str] | Callable[[str | None], Awaitable[str]] | None = None, + ) -> ProxyConfiguration | None: """Create a ProxyConfiguration object with the passed proxy configuration. Configures connection to a proxy server with the provided options. 
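A sketch of building a proxy configuration from explicit options per the signature above; the proxy group is hypothetical, and `new_url()` is assumed from this SDK's `ProxyConfiguration` API:

    from apify import Actor

    async def main() -> None:
        async with Actor:
            proxy_configuration = await Actor.create_proxy_configuration(
                groups=['RESIDENTIAL'],  # hypothetical Apify Proxy group
                country_code='US',
            )
            if proxy_configuration is not None:
                proxy_url = await proxy_configuration.new_url()
                Actor.log.info(f'Using proxy: {proxy_url}')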
@@ -1292,15 +1316,15 @@ async def create_proxy_configuration(
         )
 
     async def _create_proxy_configuration_internal(
-        self,
+        self: Actor,
         *,
-        actor_proxy_input: Optional[Dict] = None,  # this is the raw proxy input from the actor run input, it is not spread or snake_cased in here
-        password: Optional[str] = None,
-        groups: Optional[List[str]] = None,
-        country_code: Optional[str] = None,
-        proxy_urls: Optional[List[str]] = None,
-        new_url_function: Optional[Union[Callable[[Optional[str]], str], Callable[[Optional[str]], Awaitable[str]]]] = None,
-    ) -> Optional[ProxyConfiguration]:
+        actor_proxy_input: dict | None = None,  # this is the raw proxy input from the actor run input, it is not spread or snake_cased in here
+        password: str | None = None,
+        groups: list[str] | None = None,
+        country_code: str | None = None,
+        proxy_urls: list[str] | None = None,
+        new_url_function: Callable[[str | None], str] | Callable[[str | None], Awaitable[str]] | None = None,
+    ) -> ProxyConfiguration | None:
         self._raise_if_not_initialized()
 
         if actor_proxy_input is not None:
diff --git a/src/apify/config.py b/src/apify/config.py
index 9b66a03e..1eeb738a 100644
--- a/src/apify/config.py
+++ b/src/apify/config.py
@@ -1,8 +1,8 @@
-from typing import Optional
+from __future__ import annotations
 
 from apify_shared.consts import ActorEnvVars, ApifyEnvVars
 
-from ._utils import _fetch_and_parse_env_var
+from ._utils import fetch_and_parse_env_var
 
 
 class Configuration:
@@ -12,30 +12,30 @@ class Configuration:
     or it can be specific to each `Actor` instance on the `actor.config` property.
     """
 
-    _default_instance: Optional['Configuration'] = None
+    _default_instance: Configuration | None = None
 
     def __init__(
-        self,
+        self: Configuration,
         *,
-        api_base_url: Optional[str] = None,
-        api_public_base_url: Optional[str] = None,
-        container_port: Optional[int] = None,
-        container_url: Optional[str] = None,
-        default_dataset_id: Optional[str] = None,
-        default_key_value_store_id: Optional[str] = None,
-        default_request_queue_id: Optional[str] = None,
-        input_key: Optional[str] = None,
-        max_used_cpu_ratio: Optional[float] = None,
-        metamorph_after_sleep_millis: Optional[int] = None,
-        persist_state_interval_millis: Optional[int] = None,
-        persist_storage: Optional[bool] = None,
-        proxy_hostname: Optional[str] = None,
-        proxy_password: Optional[str] = None,
-        proxy_port: Optional[int] = None,
-        proxy_status_url: Optional[str] = None,
-        purge_on_start: Optional[bool] = None,
-        token: Optional[str] = None,
-        system_info_interval_millis: Optional[int] = None,
+        api_base_url: str | None = None,
+        api_public_base_url: str | None = None,
+        container_port: int | None = None,
+        container_url: str | None = None,
+        default_dataset_id: str | None = None,
+        default_key_value_store_id: str | None = None,
+        default_request_queue_id: str | None = None,
+        input_key: str | None = None,
+        max_used_cpu_ratio: float | None = None,
+        metamorph_after_sleep_millis: int | None = None,
+        persist_state_interval_millis: int | None = None,
+        persist_storage: bool | None = None,
+        proxy_hostname: str | None = None,
+        proxy_password: str | None = None,
+        proxy_port: int | None = None,
+        proxy_status_url: str | None = None,
+        purge_on_start: bool | None = None,
+        token: str | None = None,
+        system_info_interval_millis: int | None = None,
     ) -> None:
         """Create a `Configuration` instance.
 
@@ -67,55 +67,58 @@ def __init__(
             system_info_interval_millis (int, optional): How often should the actor emit the SYSTEM_INFO event when running locally.
""" # TODO: Document all these members - self.actor_build_id = _fetch_and_parse_env_var(ActorEnvVars.BUILD_ID) - self.actor_build_number = _fetch_and_parse_env_var(ActorEnvVars.BUILD_NUMBER) - self.actor_events_ws_url = _fetch_and_parse_env_var(ActorEnvVars.EVENTS_WEBSOCKET_URL) - self.actor_id = _fetch_and_parse_env_var(ActorEnvVars.ID) - self.actor_run_id = _fetch_and_parse_env_var(ActorEnvVars.RUN_ID) - self.actor_task_id = _fetch_and_parse_env_var(ActorEnvVars.TASK_ID) - self.api_base_url = api_base_url or _fetch_and_parse_env_var(ApifyEnvVars.API_BASE_URL, 'https://api.apify.com') - self.api_public_base_url = api_public_base_url or _fetch_and_parse_env_var(ApifyEnvVars.API_PUBLIC_BASE_URL, 'https://api.apify.com') - self.chrome_executable_path = _fetch_and_parse_env_var(ApifyEnvVars.CHROME_EXECUTABLE_PATH) - self.container_port = container_port or _fetch_and_parse_env_var(ActorEnvVars.WEB_SERVER_PORT, 4321) - self.container_url = container_url or _fetch_and_parse_env_var(ActorEnvVars.WEB_SERVER_URL, 'http://localhost:4321') - self.dedicated_cpus = _fetch_and_parse_env_var(ApifyEnvVars.DEDICATED_CPUS) - self.default_browser_path = _fetch_and_parse_env_var(ApifyEnvVars.DEFAULT_BROWSER_PATH) - self.default_dataset_id = default_dataset_id or _fetch_and_parse_env_var(ActorEnvVars.DEFAULT_DATASET_ID, 'default') - self.default_key_value_store_id = default_key_value_store_id or _fetch_and_parse_env_var(ActorEnvVars.DEFAULT_KEY_VALUE_STORE_ID, 'default') - self.default_request_queue_id = default_request_queue_id or _fetch_and_parse_env_var(ActorEnvVars.DEFAULT_REQUEST_QUEUE_ID, 'default') - self.disable_browser_sandbox = _fetch_and_parse_env_var(ApifyEnvVars.DISABLE_BROWSER_SANDBOX, False) - self.headless = _fetch_and_parse_env_var(ApifyEnvVars.HEADLESS, True) - self.input_key = input_key or _fetch_and_parse_env_var(ActorEnvVars.INPUT_KEY, 'INPUT') - self.input_secrets_private_key_file = _fetch_and_parse_env_var(ApifyEnvVars.INPUT_SECRETS_PRIVATE_KEY_FILE) - self.input_secrets_private_key_passphrase = _fetch_and_parse_env_var(ApifyEnvVars.INPUT_SECRETS_PRIVATE_KEY_PASSPHRASE) - self.is_at_home = _fetch_and_parse_env_var(ApifyEnvVars.IS_AT_HOME, False) - self.max_used_cpu_ratio = max_used_cpu_ratio or _fetch_and_parse_env_var(ApifyEnvVars.MAX_USED_CPU_RATIO, 0.95) - self.memory_mbytes = _fetch_and_parse_env_var(ActorEnvVars.MEMORY_MBYTES) - self.meta_origin = _fetch_and_parse_env_var(ApifyEnvVars.META_ORIGIN) - self.metamorph_after_sleep_millis = metamorph_after_sleep_millis or _fetch_and_parse_env_var(ApifyEnvVars.METAMORPH_AFTER_SLEEP_MILLIS, 300000) # noqa: E501 - self.persist_state_interval_millis = persist_state_interval_millis or _fetch_and_parse_env_var(ApifyEnvVars.PERSIST_STATE_INTERVAL_MILLIS, 60000) # noqa: E501 - self.persist_storage = persist_storage or _fetch_and_parse_env_var(ApifyEnvVars.PERSIST_STORAGE, True) - self.proxy_hostname = proxy_hostname or _fetch_and_parse_env_var(ApifyEnvVars.PROXY_HOSTNAME, 'proxy.apify.com') - self.proxy_password = proxy_password or _fetch_and_parse_env_var(ApifyEnvVars.PROXY_PASSWORD) - self.proxy_port = proxy_port or _fetch_and_parse_env_var(ApifyEnvVars.PROXY_PORT, 8000) - self.proxy_status_url = proxy_status_url or _fetch_and_parse_env_var(ApifyEnvVars.PROXY_STATUS_URL, 'http://proxy.apify.com') - self.purge_on_start = purge_on_start or _fetch_and_parse_env_var(ApifyEnvVars.PURGE_ON_START, False) - self.started_at = _fetch_and_parse_env_var(ActorEnvVars.STARTED_AT) - self.timeout_at = _fetch_and_parse_env_var(ActorEnvVars.TIMEOUT_AT) - 
self.token = token or _fetch_and_parse_env_var(ApifyEnvVars.TOKEN) - self.user_id = _fetch_and_parse_env_var(ApifyEnvVars.USER_ID) - self.xvfb = _fetch_and_parse_env_var(ApifyEnvVars.XVFB, False) - self.system_info_interval_millis = system_info_interval_millis or _fetch_and_parse_env_var(ApifyEnvVars.SYSTEM_INFO_INTERVAL_MILLIS, 60000) + # https://github.com/apify/apify-sdk-python/issues/147 + self.actor_build_id = fetch_and_parse_env_var(ActorEnvVars.BUILD_ID) + self.actor_build_number = fetch_and_parse_env_var(ActorEnvVars.BUILD_NUMBER) + self.actor_events_ws_url = fetch_and_parse_env_var(ActorEnvVars.EVENTS_WEBSOCKET_URL) + self.actor_id = fetch_and_parse_env_var(ActorEnvVars.ID) + self.actor_run_id = fetch_and_parse_env_var(ActorEnvVars.RUN_ID) + self.actor_task_id = fetch_and_parse_env_var(ActorEnvVars.TASK_ID) + self.api_base_url = api_base_url or fetch_and_parse_env_var(ApifyEnvVars.API_BASE_URL, 'https://api.apify.com') + self.api_public_base_url = api_public_base_url or fetch_and_parse_env_var(ApifyEnvVars.API_PUBLIC_BASE_URL, 'https://api.apify.com') + self.chrome_executable_path = fetch_and_parse_env_var(ApifyEnvVars.CHROME_EXECUTABLE_PATH) + self.container_port = container_port or fetch_and_parse_env_var(ActorEnvVars.WEB_SERVER_PORT, 4321) + self.container_url = container_url or fetch_and_parse_env_var(ActorEnvVars.WEB_SERVER_URL, 'http://localhost:4321') + self.dedicated_cpus = fetch_and_parse_env_var(ApifyEnvVars.DEDICATED_CPUS) + self.default_browser_path = fetch_and_parse_env_var(ApifyEnvVars.DEFAULT_BROWSER_PATH) + self.default_dataset_id = default_dataset_id or fetch_and_parse_env_var(ActorEnvVars.DEFAULT_DATASET_ID, 'default') + self.default_key_value_store_id = default_key_value_store_id or fetch_and_parse_env_var(ActorEnvVars.DEFAULT_KEY_VALUE_STORE_ID, 'default') + self.default_request_queue_id = default_request_queue_id or fetch_and_parse_env_var(ActorEnvVars.DEFAULT_REQUEST_QUEUE_ID, 'default') + self.disable_browser_sandbox = fetch_and_parse_env_var(ApifyEnvVars.DISABLE_BROWSER_SANDBOX, default=False) + self.headless = fetch_and_parse_env_var(ApifyEnvVars.HEADLESS, default=True) + self.input_key = input_key or fetch_and_parse_env_var(ActorEnvVars.INPUT_KEY, 'INPUT') + self.input_secrets_private_key_file = fetch_and_parse_env_var(ApifyEnvVars.INPUT_SECRETS_PRIVATE_KEY_FILE) + self.input_secrets_private_key_passphrase = fetch_and_parse_env_var(ApifyEnvVars.INPUT_SECRETS_PRIVATE_KEY_PASSPHRASE) + self.is_at_home = fetch_and_parse_env_var(ApifyEnvVars.IS_AT_HOME, default=False) + self.max_used_cpu_ratio = max_used_cpu_ratio or fetch_and_parse_env_var(ApifyEnvVars.MAX_USED_CPU_RATIO, 0.95) + self.memory_mbytes = fetch_and_parse_env_var(ActorEnvVars.MEMORY_MBYTES) + self.meta_origin = fetch_and_parse_env_var(ApifyEnvVars.META_ORIGIN) + self.metamorph_after_sleep_millis = metamorph_after_sleep_millis or fetch_and_parse_env_var(ApifyEnvVars.METAMORPH_AFTER_SLEEP_MILLIS, 300000) + self.persist_state_interval_millis = persist_state_interval_millis or fetch_and_parse_env_var( + ApifyEnvVars.PERSIST_STATE_INTERVAL_MILLIS, 60000 + ) + self.persist_storage = persist_storage or fetch_and_parse_env_var(ApifyEnvVars.PERSIST_STORAGE, default=True) + self.proxy_hostname = proxy_hostname or fetch_and_parse_env_var(ApifyEnvVars.PROXY_HOSTNAME, 'proxy.apify.com') + self.proxy_password = proxy_password or fetch_and_parse_env_var(ApifyEnvVars.PROXY_PASSWORD) + self.proxy_port = proxy_port or fetch_and_parse_env_var(ApifyEnvVars.PROXY_PORT, 8000) + self.proxy_status_url = proxy_status_url 
or fetch_and_parse_env_var(ApifyEnvVars.PROXY_STATUS_URL, 'http://proxy.apify.com')
+        self.purge_on_start = purge_on_start or fetch_and_parse_env_var(ApifyEnvVars.PURGE_ON_START, default=False)
+        self.started_at = fetch_and_parse_env_var(ActorEnvVars.STARTED_AT)
+        self.timeout_at = fetch_and_parse_env_var(ActorEnvVars.TIMEOUT_AT)
+        self.token = token or fetch_and_parse_env_var(ApifyEnvVars.TOKEN)
+        self.user_id = fetch_and_parse_env_var(ApifyEnvVars.USER_ID)
+        self.xvfb = fetch_and_parse_env_var(ApifyEnvVars.XVFB, default=False)
+        self.system_info_interval_millis = system_info_interval_millis or fetch_and_parse_env_var(ApifyEnvVars.SYSTEM_INFO_INTERVAL_MILLIS, 60000)
 
     @classmethod
-    def _get_default_instance(cls) -> 'Configuration':
+    def _get_default_instance(cls: type[Configuration]) -> Configuration:
         if cls._default_instance is None:
             cls._default_instance = cls()
 
         return cls._default_instance
 
     @classmethod
-    def get_global_configuration(cls) -> 'Configuration':
+    def get_global_configuration(cls: type[Configuration]) -> Configuration:
         """Retrieve the global configuration.
 
         The global configuration applies when you call actor methods via their static versions, e.g. `Actor.init()`.
diff --git a/src/apify/consts.py b/src/apify/consts.py
index 06684fc6..47d2ca7b 100644
--- a/src/apify/consts.py
+++ b/src/apify/consts.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import re
 import warnings
 from enum import Enum
@@ -41,7 +43,7 @@ def __getattr__(name: str) -> Any:
     raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
 
 
-class _StorageTypes(str, Enum):
+class StorageTypes(str, Enum):
     """Possible Apify storage types."""
 
     DATASET = 'Dataset'
diff --git a/src/apify/event_manager.py b/src/apify/event_manager.py
index 0ca94a44..dd6b0193 100644
--- a/src/apify/event_manager.py
+++ b/src/apify/event_manager.py
@@ -1,19 +1,24 @@
+from __future__ import annotations
+
 import asyncio
 import contextlib
 import inspect
 import json
 from collections import defaultdict
-from typing import Any, Callable, Coroutine, Dict, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Any, Callable, Coroutine, Union
 
 import websockets.client
 from pyee.asyncio import AsyncIOEventEmitter
 
-from apify_shared.consts import ActorEventTypes
 from apify_shared.utils import ignore_docs, maybe_extract_enum_member_value, parse_date_fields
 
-from .config import Configuration
 from .log import logger
 
+if TYPE_CHECKING:
+    from apify_shared.consts import ActorEventTypes
+
+    from .config import Configuration
+
 ListenerType = Union[Callable[[], None], Callable[[Any], None], Callable[[], Coroutine[Any, Any, None]], Callable[[Any], Coroutine[Any, Any, None]]]
 
 
@@ -25,15 +30,15 @@ class EventManager:
     but instead use it via the `Actor.on()` and `Actor.off()` methods.
""" - _platform_events_websocket: Optional[websockets.client.WebSocketClientProtocol] = None - _process_platform_messages_task: Optional[asyncio.Task] = None - _send_persist_state_interval_task: Optional[asyncio.Task] = None - _send_system_info_interval_task: Optional[asyncio.Task] = None - _listener_tasks: Set[asyncio.Task] - _listeners_to_wrappers: Dict[ActorEventTypes, Dict[Callable, List[Callable]]] - _connected_to_platform_websocket: Optional[asyncio.Future] = None + _platform_events_websocket: websockets.client.WebSocketClientProtocol | None = None + _process_platform_messages_task: asyncio.Task | None = None + _send_persist_state_interval_task: asyncio.Task | None = None + _send_system_info_interval_task: asyncio.Task | None = None + _listener_tasks: set[asyncio.Task] + _listeners_to_wrappers: dict[ActorEventTypes, dict[Callable, list[Callable]]] + _connected_to_platform_websocket: asyncio.Future | None = None - def __init__(self, config: Configuration) -> None: + def __init__(self: EventManager, config: Configuration) -> None: """Create an instance of the EventManager. Args: @@ -45,7 +50,7 @@ def __init__(self, config: Configuration) -> None: self._listener_tasks = set() self._listeners_to_wrappers = defaultdict(lambda: defaultdict(list)) - async def init(self) -> None: + async def init(self: EventManager) -> None: """Initialize the event manager. When running this on the Apify Platform, this will start processing events @@ -67,7 +72,7 @@ async def init(self) -> None: self._initialized = True - async def close(self, event_listeners_timeout_secs: Optional[float] = None) -> None: + async def close(self: EventManager, event_listeners_timeout_secs: float | None = None) -> None: """Initialize the event manager. This will stop listening for the platform events, @@ -91,7 +96,7 @@ async def close(self, event_listeners_timeout_secs: Optional[float] = None) -> N self._initialized = False - def on(self, event_name: ActorEventTypes, listener: ListenerType) -> Callable: + def on(self: EventManager, event_name: ActorEventTypes, listener: ListenerType) -> Callable: """Add an event listener to the event manager. 
Args: @@ -111,15 +116,15 @@ def on(self, event_name: ActorEventTypes, listener: ListenerType) -> Callable: listener_argument_count = 1 else: try: - dummy_event_data: Dict = {} + dummy_event_data: dict = {} signature.bind(dummy_event_data) listener_argument_count = 1 except TypeError: try: signature.bind() listener_argument_count = 0 - except TypeError: - raise ValueError('The "listener" argument must be a callable which accepts 0 or 1 arguments!') + except TypeError as err: + raise ValueError('The "listener" argument must be a callable which accepts 0 or 1 arguments!') from err event_name = maybe_extract_enum_member_value(event_name) @@ -129,11 +134,10 @@ async def inner_wrapper(event_data: Any) -> None: await listener() else: await listener(event_data) + elif listener_argument_count == 0: + listener() # type: ignore[call-arg] else: - if listener_argument_count == 0: - listener() # type: ignore[call-arg] - else: - listener(event_data) # type: ignore[call-arg] + listener(event_data) # type: ignore[call-arg] async def outer_wrapper(event_data: Any) -> None: listener_task = asyncio.create_task(inner_wrapper(event_data)) @@ -152,7 +156,7 @@ async def outer_wrapper(event_data: Any) -> None: return self._event_emitter.add_listener(event_name, outer_wrapper) - def off(self, event_name: ActorEventTypes, listener: Optional[Callable] = None) -> None: + def off(self: EventManager, event_name: ActorEventTypes, listener: Callable | None = None) -> None: """Remove a listener, or all listeners, from an actor event. Args: @@ -172,7 +176,7 @@ def off(self, event_name: ActorEventTypes, listener: Optional[Callable] = None) self._listeners_to_wrappers[event_name] = defaultdict(list) self._event_emitter.remove_all_listeners(event_name) - def emit(self, event_name: ActorEventTypes, data: Any) -> None: + def emit(self: EventManager, event_name: ActorEventTypes, data: Any) -> None: """Emit an actor event manually. Args: @@ -183,12 +187,13 @@ def emit(self, event_name: ActorEventTypes, data: Any) -> None: self._event_emitter.emit(event_name, data) - async def wait_for_all_listeners_to_complete(self, *, timeout_secs: Optional[float] = None) -> None: + async def wait_for_all_listeners_to_complete(self: EventManager, *, timeout_secs: float | None = None) -> None: """Wait for all event listeners which are currently being executed to complete. Args: timeout_secs (float, optional): Timeout for the wait. If the event listeners don't finish until the timeout, they will be canceled. 
""" + async def _wait_for_listeners() -> None: results = await asyncio.gather(*self._listener_tasks, return_exceptions=True) for result in results: @@ -206,19 +211,19 @@ async def _wait_for_listeners() -> None: else: await _wait_for_listeners() - async def _process_platform_messages(self) -> None: + async def _process_platform_messages(self: EventManager) -> None: # This should be called only on the platform, where we have the ACTOR_EVENTS_WS_URL configured - assert self._config.actor_events_ws_url is not None - assert self._connected_to_platform_websocket is not None + assert self._config.actor_events_ws_url is not None # noqa: S101 + assert self._connected_to_platform_websocket is not None # noqa: S101 try: async with websockets.client.connect(self._config.actor_events_ws_url) as websocket: self._platform_events_websocket = websocket - self._connected_to_platform_websocket.set_result(True) + self._connected_to_platform_websocket.set_result(True) # noqa: FBT003 async for message in websocket: try: parsed_message = json.loads(message) - assert isinstance(parsed_message, dict) + assert isinstance(parsed_message, dict) # noqa: S101 parsed_message = parse_date_fields(parsed_message) event_name = parsed_message['name'] event_data = parsed_message.get('data') # 'data' can be missing @@ -229,4 +234,4 @@ async def _process_platform_messages(self) -> None: logger.exception('Cannot parse actor event', extra={'message': message}) except Exception: logger.exception('Error in websocket connection') - self._connected_to_platform_websocket.set_result(False) + self._connected_to_platform_websocket.set_result(False) # noqa: FBT003 diff --git a/src/apify/log.py b/src/apify/log.py index 308a1f31..986f8c96 100644 --- a/src/apify/log.py +++ b/src/apify/log.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import json import logging import textwrap import traceback -from typing import Any, Dict +from typing import Any from colorama import Fore, Style, just_fix_windows_console @@ -55,17 +57,24 @@ class ActorLogFormatter(logging.Formatter): # and extract all the extra ones not present in the empty log record empty_record = logging.LogRecord('dummy', 0, 'dummy', 0, 'dummy', None, None) - def __init__(self, include_logger_name: bool = False, *args: tuple, **kwargs: dict) -> None: + def __init__( + self: ActorLogFormatter, + include_logger_name: bool = False, # noqa: FBT001, FBT002 + *args: Any, + **kwargs: Any, + ) -> None: """Create an instance of the ActorLogFormatter. Args: include_logger_name: Include logger name at the beginning of the log line. Defaults to False. + args: Arguments passed to the parent class. + kwargs: Keyword arguments passed to the parent class. """ - super().__init__(*args, **kwargs) # type: ignore + super().__init__(*args, **kwargs) self.include_logger_name = include_logger_name - def _get_extra_fields(self, record: logging.LogRecord) -> Dict[str, Any]: - extra_fields: Dict[str, Any] = {} + def _get_extra_fields(self: ActorLogFormatter, record: logging.LogRecord) -> dict[str, Any]: + extra_fields: dict[str, Any] = {} for key, value in record.__dict__.items(): if key not in self.empty_record.__dict__: extra_fields[key] = value @@ -73,7 +82,7 @@ def _get_extra_fields(self, record: logging.LogRecord) -> Dict[str, Any]: return extra_fields @ignore_docs - def format(self, record: logging.LogRecord) -> str: + def format(self: ActorLogFormatter, record: logging.LogRecord) -> str: # noqa: A003 """Format the log record nicely. 
This formats the log record so that it: @@ -112,5 +121,5 @@ def format(self, record: logging.LogRecord) -> str: if self.include_logger_name: # Include logger name at the beginning of the log line return f'{logger_name_string}{level_string}{log_string}{extra_string}{exception_string}' - else: - return f'{level_string}{log_string}{extra_string}{exception_string}' + + return f'{level_string}{log_string}{extra_string}{exception_string}' diff --git a/src/apify/proxy_configuration.py b/src/apify/proxy_configuration.py index 9451d0c0..22e46447 100644 --- a/src/apify/proxy_configuration.py +++ b/src/apify/proxy_configuration.py @@ -1,25 +1,31 @@ +from __future__ import annotations + import inspect import ipaddress import re -from typing import Any, Awaitable, Callable, Dict, List, Optional, Pattern, TypedDict, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Pattern, TypedDict from urllib.parse import urljoin, urlparse import httpx -from typing_extensions import NotRequired -from apify_client import ApifyClientAsync from apify_shared.consts import ApifyEnvVars from apify_shared.utils import ignore_docs from .config import Configuration from .log import logger +if TYPE_CHECKING: + from typing_extensions import NotRequired + + from apify_client import ApifyClientAsync + APIFY_PROXY_VALUE_REGEX = re.compile(r'^[\w._~]+$') COUNTRY_CODE_REGEX = re.compile(r'^[A-Z]{2}$') SESSION_ID_MAX_LENGTH = 50 -def _is_url(url: str) -> bool: +def is_url(url: str) -> bool: + """Check if the given string is a valid URL.""" try: parsed_url = urlparse(urljoin(url, '/')) has_all_parts = all([parsed_url.scheme, parsed_url.netloc, parsed_url.path]) @@ -39,10 +45,10 @@ def _is_url(url: str) -> bool: def _check( value: Any, *, - label: Optional[str], - pattern: Optional[Pattern] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, + label: str | None, + pattern: Pattern | None = None, + min_length: int | None = None, + max_length: int | None = None, ) -> None: error_str = f'Value {value}' if label: @@ -55,7 +61,7 @@ def _check( raise ValueError(f'{error_str} is longer than maximum allowed length {max_length}') if pattern and not re.fullmatch(pattern, value): - raise ValueError(f'{error_str} does not match pattern {repr(pattern.pattern)}') + raise ValueError(f'{error_str} does not match pattern {pattern.pattern!r}') class ProxyInfo(TypedDict): @@ -76,7 +82,7 @@ class ProxyInfo(TypedDict): password: str """The password for the proxy.""" - groups: NotRequired[List[str]] + groups: NotRequired[list[str]] """An array of proxy groups to be used by the [Apify Proxy](https://docs.apify.com/proxy). If not provided, the proxy will select the groups automatically. 
""" @@ -111,30 +117,30 @@ class ProxyConfiguration: is_man_in_the_middle = False _next_custom_url_index = 0 - _proxy_urls: List[str] - _used_proxy_urls: Dict[str, str] - _new_url_function: Optional[Union[Callable[[Optional[str]], str], Callable[[Optional[str]], Awaitable[str]]]] = None - _groups: List[str] - _country_code: Optional[str] = None - _password: Optional[str] = None + _proxy_urls: list[str] + _used_proxy_urls: dict[str, str] + _new_url_function: Callable[[str | None], str] | Callable[[str | None], Awaitable[str]] | None = None + _groups: list[str] + _country_code: str | None = None + _password: str | None = None _hostname: str _port: int - _uses_apify_proxy: Optional[bool] = None + _uses_apify_proxy: bool | None = None _actor_config: Configuration - _apify_client: Optional[ApifyClientAsync] = None + _apify_client: ApifyClientAsync | None = None @ignore_docs def __init__( - self, + self: ProxyConfiguration, *, - password: Optional[str] = None, - groups: Optional[List[str]] = None, - country_code: Optional[str] = None, - proxy_urls: Optional[List[str]] = None, - new_url_function: Optional[Union[Callable[[Optional[str]], str], Callable[[Optional[str]], Awaitable[str]]]] = None, - _actor_config: Optional[Configuration] = None, - _apify_client: Optional[ApifyClientAsync] = None, - ): + password: str | None = None, + groups: list[str] | None = None, + country_code: str | None = None, + proxy_urls: list[str] | None = None, + new_url_function: Callable[[str | None], str] | Callable[[str | None], Awaitable[str]] | None = None, + _actor_config: Configuration | None = None, + _apify_client: ApifyClientAsync | None = None, + ) -> None: """Create a ProxyConfiguration instance. It is highly recommended to use `Actor.create_proxy_configuration()` instead of this. Args: @@ -152,8 +158,8 @@ def __init__( country_code = str(country_code) _check(country_code, label='country_code', pattern=COUNTRY_CODE_REGEX) if proxy_urls: - for (i, url) in enumerate(proxy_urls): - if not _is_url(url): + for i, url in enumerate(proxy_urls): + if not is_url(url): raise ValueError(f'proxy_urls[{i}] ("{url}") is not a valid URL') # Validation @@ -161,14 +167,18 @@ def __init__( raise ValueError('Cannot combine custom proxies in "proxy_urls" with custom generating function in "new_url_function".') if (proxy_urls or new_url_function) and (groups or country_code): - raise ValueError('Cannot combine custom proxies with Apify Proxy!' - ' It is not allowed to set "proxy_urls" or "new_url_function" combined with' - ' "groups" or "country_code".') + raise ValueError( + 'Cannot combine custom proxies with Apify Proxy!' + ' It is not allowed to set "proxy_urls" or "new_url_function" combined with' + ' "groups" or "country_code".' + ) # mypy has a bug with narrowing types for filter (https://github.com/python/mypy/issues/12682) if proxy_urls and next(filter(lambda url: 'apify.com' in url, proxy_urls), None): # type: ignore - logger.warning('Some Apify proxy features may work incorrectly. Please consider setting up Apify properties instead of `proxy_urls`.\n' - 'See https://sdk.apify.com/docs/guides/proxy-management#apify-proxy-configuration') + logger.warning( + 'Some Apify proxy features may work incorrectly. 
Please consider setting up Apify properties instead of `proxy_urls`.\n' + 'See https://sdk.apify.com/docs/guides/proxy-management#apify-proxy-configuration' + ) self._actor_config = _actor_config or Configuration._get_default_instance() self._apify_client = _apify_client @@ -184,7 +194,7 @@ def __init__( self._country_code = country_code self._uses_apify_proxy = not (proxy_urls or new_url_function) - async def initialize(self) -> None: + async def initialize(self: ProxyConfiguration) -> None: """Load the Apify Proxy password if the API token is provided and check access to Apify Proxy and provided proxy groups. Only called if Apify Proxy configuration is used. @@ -197,7 +207,7 @@ async def initialize(self) -> None: await self._maybe_fetch_password() await self._check_access() - async def new_url(self, session_id: Optional[Union[int, str]] = None) -> str: + async def new_url(self: ProxyConfiguration, session_id: int | str | None = None) -> str: """Return a new proxy URL based on provided configuration options and the `sessionId` parameter. Args: @@ -220,27 +230,27 @@ async def new_url(self, session_id: Optional[Union[int, str]] = None) -> str: if inspect.isawaitable(res): res = await res return str(res) - except Exception as e: - raise ValueError('The provided "new_url_function" did not return a valid URL') from e + except Exception as exc: + raise ValueError('The provided "new_url_function" did not return a valid URL') from exc if self._proxy_urls: if not session_id: index = self._next_custom_url_index self._next_custom_url_index = (self._next_custom_url_index + 1) % len(self._proxy_urls) return self._proxy_urls[index] - else: - if session_id not in self._used_proxy_urls: - index = self._next_custom_url_index - self._next_custom_url_index = (self._next_custom_url_index + 1) % len(self._proxy_urls) - self._used_proxy_urls[session_id] = self._proxy_urls[index] - return self._used_proxy_urls[session_id] + if session_id not in self._used_proxy_urls: + index = self._next_custom_url_index + self._next_custom_url_index = (self._next_custom_url_index + 1) % len(self._proxy_urls) + self._used_proxy_urls[session_id] = self._proxy_urls[index] + + return self._used_proxy_urls[session_id] username = self._get_username(session_id) return f'http://{username}:{self._password}@{self._hostname}:{self._port}' - async def new_proxy_info(self, session_id: Optional[Union[int, str]] = None) -> ProxyInfo: + async def new_proxy_info(self: ProxyConfiguration, session_id: int | str | None = None) -> ProxyInfo: """Create a new ProxyInfo object. Use it if you want to work with a rich representation of a proxy URL. 
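        For example, the same session ID keeps mapping to the same proxy
        (a sketch, assuming an initialized `proxy_configuration`):

            url_a = await proxy_configuration.new_url(session_id='my-session')
            url_b = await proxy_configuration.new_url(session_id='my-session')
            # url_a == url_b; omitting the session ID rotates through the proxies

            proxy_info = await proxy_configuration.new_proxy_info(session_id='my-session')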
@@ -275,21 +285,21 @@ async def new_proxy_info(self, session_id: Optional[Union[int, str]] = None) -> if session_id is not None: res['session_id'] = session_id return res - else: - parsed_url = urlparse(url) - assert parsed_url.hostname is not None - assert parsed_url.port is not None - res = { - 'url': url, - 'hostname': parsed_url.hostname, - 'port': parsed_url.port, - 'password': parsed_url.password or '', - } - if parsed_url.username: - res['username'] = parsed_url.username + + parsed_url = urlparse(url) + assert parsed_url.hostname is not None # noqa: S101 + assert parsed_url.port is not None # noqa: S101 + res = { + 'url': url, + 'hostname': parsed_url.hostname, + 'port': parsed_url.port, + 'password': parsed_url.password or '', + } + if parsed_url.username: + res['username'] = parsed_url.username return res - async def _maybe_fetch_password(self) -> None: + async def _maybe_fetch_password(self: ProxyConfiguration) -> None: token = self._actor_config.token if token and self._apify_client: @@ -299,17 +309,21 @@ async def _maybe_fetch_password(self) -> None: if self._password: if self._password != password: - logger.warning('The Apify Proxy password you provided belongs to' - ' a different user than the Apify token you are using. Are you sure this is correct?') + logger.warning( + 'The Apify Proxy password you provided belongs to' + ' a different user than the Apify token you are using. Are you sure this is correct?' + ) else: self._password = password if not self._password: - raise ValueError('Apify Proxy password must be provided using the "password" constructor argument' - f' or the "{ApifyEnvVars.PROXY_PASSWORD}" environment variable.' - f' If you add the "{ApifyEnvVars.TOKEN}" environment variable, the password will be automatically inferred.') + raise ValueError( + 'Apify Proxy password must be provided using the "password" constructor argument' + f' or the "{ApifyEnvVars.PROXY_PASSWORD}" environment variable.' + f' If you add the "{ApifyEnvVars.TOKEN}" environment variable, the password will be automatically inferred.' + ) - async def _check_access(self) -> None: + async def _check_access(self: ProxyConfiguration) -> None: proxy_status_url = f'{self._actor_config.proxy_status_url}/?format=json' status = None @@ -319,7 +333,7 @@ async def _check_access(self) -> None: response = await client.get(proxy_status_url) status = response.json() break - except Exception: + except Exception: # noqa: S110 # retry on connection errors pass @@ -329,20 +343,22 @@ async def _check_access(self) -> None: self.is_man_in_the_middle = status['isManInTheMiddle'] else: - logger.warning('Apify Proxy access check timed out. Watch out for errors with status code 407. ' - "If you see some, it most likely means you don't have access to either all or some of the proxies you're trying to use.") + logger.warning( + 'Apify Proxy access check timed out. Watch out for errors with status code 407. ' + "If you see some, it most likely means you don't have access to either all or some of the proxies you're trying to use." 
+ ) - def _get_username(self, session_id: Optional[Union[int, str]] = None) -> str: + def _get_username(self: ProxyConfiguration, session_id: int | str | None = None) -> str: if session_id is not None: session_id = f'{session_id}' - parts: List[str] = [] + parts: list[str] = [] if self._groups: parts.append(f'groups-{"+".join(self._groups)}') if session_id is not None: parts.append(f'session-{session_id}') - if (self._country_code): + if self._country_code: parts.append(f'country-{self._country_code}') if not parts: diff --git a/src/apify/scrapy/middlewares.py b/src/apify/scrapy/middlewares.py index ef648d1d..a568eca9 100644 --- a/src/apify/scrapy/middlewares.py +++ b/src/apify/scrapy/middlewares.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import traceback -from typing import Union +from typing import TYPE_CHECKING, Any try: - from scrapy import Spider + from scrapy import Spider # noqa: TCH002 from scrapy.downloadermiddlewares.retry import RetryMiddleware from scrapy.exceptions import IgnoreRequest - from scrapy.http import Request, Response + from scrapy.http import Request, Response # noqa: TCH002 from scrapy.utils.response import response_status_message except ImportError as exc: raise ImportError( @@ -13,27 +15,35 @@ ) from exc from ..actor import Actor -from ..storages import RequestQueue from .utils import nested_event_loop, open_queue_with_custom_client, to_apify_request +if TYPE_CHECKING: + from ..storages import RequestQueue + class ApifyRetryMiddleware(RetryMiddleware): """The default Scrapy retry middleware enriched with Apify's Request Queue interaction.""" - def __init__(self, *args: list, **kwargs: dict) -> None: + def __init__(self: ApifyRetryMiddleware, *args: Any, **kwargs: Any) -> None: """Create a new instance.""" super().__init__(*args, **kwargs) try: self._rq: RequestQueue = nested_event_loop.run_until_complete(open_queue_with_custom_client()) except BaseException: traceback.print_exc() + raise - def __del__(self) -> None: + def __del__(self: ApifyRetryMiddleware) -> None: """Before deleting the instance, close the nested event loop.""" nested_event_loop.stop() nested_event_loop.close() - def process_response(self, request: Request, response: Response, spider: Spider) -> Union[Request, Response]: + def process_response( + self: ApifyRetryMiddleware, + request: Request, + response: Response, + spider: Spider, + ) -> Request | Response | None: """Process the response and decide whether the request should be retried. Args: @@ -46,23 +56,22 @@ def process_response(self, request: Request, response: Response, spider: Spider) """ # Robots requests are bypassed directly, they don't go through a Scrapy Scheduler, and also through our # Request Queue. Check the scrapy.downloadermiddlewares.robotstxt.RobotsTxtMiddleware for details. 
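        # (Because such requests never enter the Request Queue, the response is returned
        # below as-is and nothing is marked as handled in the queue for it.)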
- assert isinstance(request.url, str) + assert isinstance(request.url, str) # noqa: S101 if request.url.endswith('robots.txt'): return response try: - returned = nested_event_loop.run_until_complete(self._handle_retry_logic(request, response, spider)) + return nested_event_loop.run_until_complete(self._handle_retry_logic(request, response, spider)) except BaseException: traceback.print_exc() - - return returned + raise def process_exception( - self, + self: ApifyRetryMiddleware, request: Request, exception: BaseException, spider: Spider, - ) -> Union[Request, Response, None]: + ) -> Request | Response | None: """Handle the exception and decide whether the request should be retried.""" Actor.log.debug(f'ApifyRetryMiddleware.process_exception was called (scrapy_request={request})...') apify_request = to_apify_request(request, spider=spider) @@ -72,17 +81,18 @@ def process_exception( nested_event_loop.run_until_complete(self._rq.mark_request_as_handled(apify_request)) except BaseException: traceback.print_exc() + raise else: nested_event_loop.run_until_complete(self._rq.reclaim_request(apify_request)) return super().process_exception(request, exception, spider) async def _handle_retry_logic( - self, + self: ApifyRetryMiddleware, request: Request, response: Response, spider: Spider, - ) -> Union[Request, Response]: + ) -> Request | Response | None: """Handle the retry logic of the request.""" Actor.log.debug(f'ApifyRetryMiddleware.handle_retry_logic was called (scrapy_request={request})...') apify_request = to_apify_request(request, spider=spider) diff --git a/src/apify/scrapy/pipelines.py b/src/apify/scrapy/pipelines.py index fe44e4b8..3c682c17 100644 --- a/src/apify/scrapy/pipelines.py +++ b/src/apify/scrapy/pipelines.py @@ -1,7 +1,9 @@ -from itemadapter import ItemAdapter +from __future__ import annotations + +from itemadapter.adapter import ItemAdapter try: - from scrapy import Item, Spider + from scrapy import Item, Spider # noqa: TCH002 except ImportError as exc: raise ImportError( 'To use this module, you need to install the "scrapy" extra. Run "pip install apify[scrapy]".', @@ -16,7 +18,11 @@ class ActorDatasetPushPipeline: This pipeline is designed to be enabled only when the Scrapy project is run as an Actor. """ - async def process_item(self, item: Item, spider: Spider) -> Item: + async def process_item( + self: ActorDatasetPushPipeline, + item: Item, + spider: Spider, + ) -> Item: """Pushes the provided Scrapy item to the Actor's default dataset.""" item_dict = ItemAdapter(item).asdict() Actor.log.debug(f'Pushing item={item_dict} produced by spider={spider} to the dataset.') diff --git a/src/apify/scrapy/scheduler.py b/src/apify/scrapy/scheduler.py index c68864e9..139aa056 100644 --- a/src/apify/scrapy/scheduler.py +++ b/src/apify/scrapy/scheduler.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import traceback -from typing import Optional try: from scrapy import Spider from scrapy.core.scheduler import BaseScheduler - from scrapy.http.request import Request + from scrapy.http.request import Request # noqa: TCH002 from scrapy.utils.reactor import is_asyncio_reactor_installed except ImportError as exc: raise ImportError( @@ -23,7 +24,7 @@ class ApifyScheduler(BaseScheduler): This scheduler requires the asyncio Twisted reactor to be installed. 
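    For example, a project would typically enable it in its Scrapy settings
    (a sketch using the standard setting names):

        TWISTED_REACTOR = 'twisted.internet.asyncioreactor.AsyncioSelectorReactor'
        SCHEDULER = 'apify.scrapy.scheduler.ApifyScheduler'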
""" - def __init__(self) -> None: + def __init__(self: ApifyScheduler) -> None: """Create a new instance.""" if not is_asyncio_reactor_installed(): raise ValueError( @@ -31,10 +32,10 @@ def __init__(self) -> None: 'Make sure you have it configured in the TWISTED_REACTOR setting. See the asyncio ' 'documentation of Scrapy for more information.', ) - self._rq: Optional[RequestQueue] = None - self.spider: Optional[Spider] = None + self._rq: RequestQueue | None = None + self.spider: Spider | None = None - def open(self, spider: Spider) -> None: + def open(self: ApifyScheduler, spider: Spider) -> None: # noqa: A003 # this has to be named "open" """Open the scheduler. Args: @@ -46,8 +47,9 @@ def open(self, spider: Spider) -> None: self._rq = nested_event_loop.run_until_complete(open_queue_with_custom_client()) except BaseException: traceback.print_exc() + raise - def close(self, reason: str) -> None: + def close(self: ApifyScheduler, reason: str) -> None: """Close the scheduler. Args: @@ -57,21 +59,22 @@ def close(self, reason: str) -> None: nested_event_loop.stop() nested_event_loop.close() - def has_pending_requests(self) -> bool: + def has_pending_requests(self: ApifyScheduler) -> bool: """Check if the scheduler has any pending requests. Returns: True if the scheduler has any pending requests, False otherwise. """ - assert isinstance(self._rq, RequestQueue) + assert isinstance(self._rq, RequestQueue) # noqa: S101 try: is_finished = nested_event_loop.run_until_complete(self._rq.is_finished()) except BaseException: traceback.print_exc() + raise return not is_finished - def enqueue_request(self, request: Request) -> bool: + def enqueue_request(self: ApifyScheduler, request: Request) -> bool: """Add a request to the scheduler. Args: @@ -83,19 +86,21 @@ def enqueue_request(self, request: Request) -> bool: call_id = crypto_random_object_id(8) Actor.log.debug(f'[{call_id}]: ApifyScheduler.enqueue_request was called (scrapy_request={request})...') + assert isinstance(self.spider, Spider) # noqa: S101 apify_request = to_apify_request(request, spider=self.spider) Actor.log.debug(f'[{call_id}]: scrapy_request was transformed to apify_request (apify_request={apify_request})') - assert isinstance(self._rq, RequestQueue) + assert isinstance(self._rq, RequestQueue) # noqa: S101 try: result = nested_event_loop.run_until_complete(self._rq.add_request(apify_request)) except BaseException: traceback.print_exc() + raise Actor.log.debug(f'[{call_id}]: apify_request was added to the RQ (apify_request={apify_request})') return bool(result['wasAlreadyPresent']) - def next_request(self) -> Optional[Request]: + def next_request(self: ApifyScheduler) -> Request | None: """Fetch the next request from the scheduler. 
Returns: @@ -103,18 +108,20 @@ def next_request(self) -> Optional[Request]: """ call_id = crypto_random_object_id(8) Actor.log.debug(f'[{call_id}]: ApifyScheduler.next_request was called...') - assert isinstance(self._rq, RequestQueue) + assert isinstance(self._rq, RequestQueue) # noqa: S101 try: apify_request = nested_event_loop.run_until_complete(self._rq.fetch_next_request()) except BaseException: traceback.print_exc() + raise Actor.log.debug(f'[{call_id}]: a new apify_request from the scheduler was fetched (apify_request={apify_request})') if apify_request is None: return None + assert isinstance(self.spider, Spider) # noqa: S101 scrapy_request = to_scrapy_request(apify_request, spider=self.spider) Actor.log.debug( f'[{call_id}]: apify_request was transformed to the scrapy_request which is gonna be returned (scrapy_request={scrapy_request})', diff --git a/src/apify/scrapy/utils.py b/src/apify/scrapy/utils.py index 561e77bc..6ccd32bf 100644 --- a/src/apify/scrapy/utils.py +++ b/src/apify/scrapy/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import codecs import pickle @@ -33,11 +35,12 @@ def to_apify_request(scrapy_request: Request, spider: Spider) -> dict: Args: scrapy_request: The Scrapy request to be converted. + spider: The Scrapy spider that the request is associated with. Returns: The converted Apify request. """ - assert isinstance(scrapy_request, Request) + assert isinstance(scrapy_request, Request) # noqa: S101 call_id = crypto_random_object_id(8) Actor.log.debug(f'[{call_id}]: to_apify_request was called (scrapy_request={scrapy_request})...') @@ -72,15 +75,16 @@ def to_scrapy_request(apify_request: dict, spider: Spider) -> Request: Args: apify_request: The Apify request to be converted. + spider: The Scrapy spider that the request is associated with. Returns: The converted Scrapy request. """ - assert isinstance(apify_request, dict) - assert 'url' in apify_request - assert 'method' in apify_request - assert 'id' in apify_request - assert 'uniqueKey' in apify_request + assert isinstance(apify_request, dict) # noqa: S101 + assert 'url' in apify_request # noqa: S101 + assert 'method' in apify_request # noqa: S101 + assert 'id' in apify_request # noqa: S101 + assert 'uniqueKey' in apify_request # noqa: S101 call_id = crypto_random_object_id(8) Actor.log.debug(f'[{call_id}]: to_scrapy_request was called (apify_request={apify_request})...') @@ -92,13 +96,13 @@ def to_scrapy_request(apify_request: dict, spider: Spider) -> Request: # the Scrapy Request object from its dictionary representation. 
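            # (This reverses the encoding presumably done in to_apify_request: the request
            # dict is pickled, base64-encoded, and stored under userData['scrapy_request'].)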
Actor.log.debug(f'[{call_id}]: Restoring the Scrapy Request from the apify_request...') scrapy_request_dict_encoded = apify_request['userData']['scrapy_request'] - assert isinstance(scrapy_request_dict_encoded, str) + assert isinstance(scrapy_request_dict_encoded, str) # noqa: S101 scrapy_request_dict = pickle.loads(codecs.decode(scrapy_request_dict_encoded.encode(), 'base64')) - assert isinstance(scrapy_request_dict, dict) + assert isinstance(scrapy_request_dict, dict) # noqa: S101 scrapy_request = request_from_dict(scrapy_request_dict, spider=spider) - assert isinstance(scrapy_request, Request) + assert isinstance(scrapy_request, Request) # noqa: S101 Actor.log.debug(f'[{call_id}]: Scrapy Request successfully reconstructed (scrapy_request={scrapy_request})...') # Update the meta field with the meta field from the apify_request diff --git a/src/apify/storages/__init__.py b/src/apify/storages/__init__.py index f0c9dc38..e954ef20 100644 --- a/src/apify/storages/__init__.py +++ b/src/apify/storages/__init__.py @@ -3,4 +3,9 @@ from .request_queue import RequestQueue from .storage_client_manager import StorageClientManager -__all__ = ['Dataset', 'KeyValueStore', 'RequestQueue', 'StorageClientManager'] +__all__ = [ + 'Dataset', + 'KeyValueStore', + 'RequestQueue', + 'StorageClientManager', +] diff --git a/src/apify/storages/base_storage.py b/src/apify/storages/base_storage.py index c78e63d7..0daea3ed 100644 --- a/src/apify/storages/base_storage.py +++ b/src/apify/storages/base_storage.py @@ -1,10 +1,9 @@ +from __future__ import annotations + import asyncio from abc import ABC, abstractmethod -from typing import Dict, Generic, Optional, TypeVar, Union, cast - -from typing_extensions import Self +from typing import TYPE_CHECKING, Generic, TypeVar, cast -from apify_client import ApifyClientAsync from apify_shared.utils import ignore_docs from .._memory_storage import MemoryStorageClient @@ -12,6 +11,9 @@ from ..config import Configuration from .storage_client_manager import StorageClientManager +if TYPE_CHECKING: + from apify_client import ApifyClientAsync + BaseResourceClientType = TypeVar('BaseResourceClientType', bound=BaseResourceClient) BaseResourceCollectionClientType = TypeVar('BaseResourceCollectionClientType', bound=BaseResourceCollectionClient) @@ -21,15 +23,21 @@ class BaseStorage(ABC, Generic[BaseResourceClientType, BaseResourceCollectionCli """A class for managing storages.""" _id: str - _name: Optional[str] - _storage_client: Union[ApifyClientAsync, MemoryStorageClient] + _name: str | None + _storage_client: ApifyClientAsync | MemoryStorageClient _config: Configuration - _cache_by_id: Optional[Dict[str, Self]] = None - _cache_by_name: Optional[Dict[str, Self]] = None - _storage_creating_lock: Optional[asyncio.Lock] = None - - def __init__(self, id: str, name: Optional[str], client: Union[ApifyClientAsync, MemoryStorageClient], config: Configuration): + _cache_by_id: dict | None = None + _cache_by_name: dict | None = None + _storage_creating_lock: asyncio.Lock | None = None + + def __init__( + self: BaseStorage, + id: str, # noqa: A002 + name: str | None, + client: ApifyClientAsync | MemoryStorageClient, + config: Configuration, + ) -> None: """Initialize the storage. Do not use this method directly, but use `Actor.open_()` instead. 
@@ -47,26 +55,33 @@ def __init__(self, id: str, name: Optional[str], client: Union[ApifyClientAsync, @classmethod @abstractmethod - def _get_human_friendly_label(cls) -> str: + def _get_human_friendly_label(cls: type[BaseStorage]) -> str: raise NotImplementedError('You must override this method in the subclass!') @classmethod @abstractmethod - def _get_default_id(cls, config: Configuration) -> str: + def _get_default_id(cls: type[BaseStorage], config: Configuration) -> str: raise NotImplementedError('You must override this method in the subclass!') @classmethod @abstractmethod - def _get_single_storage_client(cls, id: str, client: Union[ApifyClientAsync, MemoryStorageClient]) -> BaseResourceClientType: + def _get_single_storage_client( + cls: type[BaseStorage], + id: str, # noqa: A002 + client: ApifyClientAsync | MemoryStorageClient, + ) -> BaseResourceClientType: raise NotImplementedError('You must override this method in the subclass!') @classmethod @abstractmethod - def _get_storage_collection_client(cls, client: Union[ApifyClientAsync, MemoryStorageClient]) -> BaseResourceCollectionClientType: + def _get_storage_collection_client( + cls: type[BaseStorage], + client: ApifyClientAsync | MemoryStorageClient, + ) -> BaseResourceCollectionClientType: raise NotImplementedError('You must override this method in the subclass!') @classmethod - def _ensure_class_initialized(cls) -> None: + def _ensure_class_initialized(cls: type[BaseStorage]) -> None: if cls._cache_by_id is None: cls._cache_by_id = {} if cls._cache_by_name is None: @@ -76,14 +91,14 @@ def _ensure_class_initialized(cls) -> None: @classmethod @abstractmethod - async def open( - cls, + async def open( # noqa: A003 + cls: type[BaseStorage], *, - id: Optional[str] = None, - name: Optional[str] = None, + id: str | None = None, # noqa: A002 + name: str | None = None, force_cloud: bool = False, - config: Optional[Configuration] = None, - ) -> Self: + config: Configuration | None = None, + ) -> BaseStorage: """Open a storage, or return a cached storage object if it was opened before. Opens a storage with the given ID or name. @@ -104,10 +119,9 @@ async def open( An instance of the storage. 
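        Because opened storages are cached, repeated calls with the same ID or name
        return the same instance (a sketch; `SomeStorage` stands in for any subclass):

            storage_a = await SomeStorage.open(name='my-storage')
            storage_b = await SomeStorage.open(name='my-storage')
            # storage_a is storage_b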
""" cls._ensure_class_initialized() - assert cls._cache_by_id is not None - assert cls._cache_by_name is not None - - assert not (id and name) + assert cls._cache_by_id is not None # noqa: S101 + assert cls._cache_by_name is not None # noqa: S101 + assert not (id and name) # noqa: S101 used_config = config or Configuration.get_global_configuration() used_client = StorageClientManager.get_storage_client(force_cloud=force_cloud) @@ -117,7 +131,7 @@ async def open( if not id and not name: if isinstance(used_client, MemoryStorageClient): is_default_storage_on_local = True - id = cls._get_default_id(used_config) + id = cls._get_default_id(used_config) # noqa: A001 # Try to get the storage instance from cache cached_storage = None @@ -128,13 +142,13 @@ async def open( if cached_storage is not None: # This cast is needed since MyPy doesn't understand very well that Self and Storage are the same - return cast(Self, cached_storage) + return cast(BaseStorage, cached_storage) # Purge default storages if configured if used_config.purge_on_start and isinstance(used_client, MemoryStorageClient): await used_client._purge_on_start() - assert cls._storage_creating_lock is not None + assert cls._storage_creating_lock is not None # noqa: S101 async with cls._storage_creating_lock: # Create the storage if id and not is_default_storage_on_local: @@ -159,7 +173,7 @@ async def open( return storage - def _remove_from_cache(self) -> None: + def _remove_from_cache(self: BaseStorage) -> None: if self.__class__._cache_by_id is not None: del self.__class__._cache_by_id[self._id] diff --git a/src/apify/storages/dataset.py b/src/apify/storages/dataset.py index b165d7b7..8b2a5405 100644 --- a/src/apify/storages/dataset.py +++ b/src/apify/storages/dataset.py @@ -1,35 +1,40 @@ +from __future__ import annotations + import csv import io import math -from typing import AsyncIterator, Dict, Iterable, Iterator, List, Optional, Union +from typing import TYPE_CHECKING, AsyncIterator, Iterable, Iterator -from apify_client import ApifyClientAsync -from apify_client.clients import DatasetClientAsync, DatasetCollectionClientAsync -from apify_shared.models import ListPage -from apify_shared.types import JSONSerializable from apify_shared.utils import ignore_docs, json_dumps -from .._memory_storage import MemoryStorageClient -from .._memory_storage.resource_clients import DatasetClient, DatasetCollectionClient -from .._utils import _wrap_internal -from ..config import Configuration +from .._utils import wrap_internal from ..consts import MAX_PAYLOAD_SIZE_BYTES from .base_storage import BaseStorage from .key_value_store import KeyValueStore +if TYPE_CHECKING: + from apify_client import ApifyClientAsync + from apify_client.clients import DatasetClientAsync, DatasetCollectionClientAsync + from apify_shared.models import ListPage + from apify_shared.types import JSONSerializable + + from .._memory_storage import MemoryStorageClient + from .._memory_storage.resource_clients import DatasetClient, DatasetCollectionClient + from ..config import Configuration + # 0.01% SAFETY_BUFFER_PERCENT = 0.01 / 100 EFFECTIVE_LIMIT_BYTES = MAX_PAYLOAD_SIZE_BYTES - math.ceil(MAX_PAYLOAD_SIZE_BYTES * SAFETY_BUFFER_PERCENT) -def _check_and_serialize(item: JSONSerializable, index: Optional[int] = None) -> str: +def _check_and_serialize(item: JSONSerializable, index: int | None = None) -> str: """Accept a JSON serializable object as an input, validate its serializability and its serialized size against `EFFECTIVE_LIMIT_BYTES`.""" s = ' ' if index is None else f' 
at index {index} ' try: payload = json_dumps(item) - except Exception as e: - raise ValueError(f'Data item{s}is not serializable to JSON.') from e + except Exception as exc: + raise ValueError(f'Data item{s}is not serializable to JSON.') from exc length_bytes = len(payload.encode('utf-8')) if length_bytes > EFFECTIVE_LIMIT_BYTES: @@ -91,11 +96,17 @@ class Dataset(BaseStorage): """ _id: str - _name: Optional[str] - _dataset_client: Union[DatasetClientAsync, DatasetClient] + _name: str | None + _dataset_client: DatasetClientAsync | DatasetClient @ignore_docs - def __init__(self, id: str, name: Optional[str], client: Union[ApifyClientAsync, MemoryStorageClient], config: Configuration) -> None: + def __init__( + self: Dataset, + id: str, # noqa: A002 + name: str | None, + client: ApifyClientAsync | MemoryStorageClient, + config: Configuration, + ) -> None: """Create a `Dataset` instance. Do not use the constructor directly, use the `Actor.open_dataset()` function instead. @@ -108,34 +119,38 @@ def __init__(self, id: str, name: Optional[str], client: Union[ApifyClientAsync, """ super().__init__(id=id, name=name, client=client, config=config) - self.get_data = _wrap_internal(self._get_data_internal, self.get_data) # type: ignore - self.push_data = _wrap_internal(self._push_data_internal, self.push_data) # type: ignore - self.export_to_json = _wrap_internal(self._export_to_json_internal, self.export_to_json) # type: ignore - self.export_to_csv = _wrap_internal(self._export_to_csv_internal, self.export_to_csv) # type: ignore + self.get_data = wrap_internal(self._get_data_internal, self.get_data) # type: ignore + self.push_data = wrap_internal(self._push_data_internal, self.push_data) # type: ignore + self.export_to_json = wrap_internal(self._export_to_json_internal, self.export_to_json) # type: ignore + self.export_to_csv = wrap_internal(self._export_to_csv_internal, self.export_to_csv) # type: ignore self._dataset_client = client.dataset(self._id) @classmethod - def _get_human_friendly_label(cls) -> str: + def _get_human_friendly_label(cls: type[Dataset]) -> str: return 'Dataset' @classmethod - def _get_default_id(cls, config: Configuration) -> str: + def _get_default_id(cls: type[Dataset], config: Configuration) -> str: return config.default_dataset_id @classmethod - def _get_single_storage_client(cls, id: str, client: Union[ApifyClientAsync, MemoryStorageClient]) -> Union[DatasetClientAsync, DatasetClient]: + def _get_single_storage_client( + cls: type[Dataset], + id: str, # noqa: A002 + client: ApifyClientAsync | MemoryStorageClient, + ) -> DatasetClientAsync | DatasetClient: return client.dataset(id) @classmethod def _get_storage_collection_client( - cls, - client: Union[ApifyClientAsync, MemoryStorageClient], - ) -> Union[DatasetCollectionClientAsync, DatasetCollectionClient]: + cls: type[Dataset], + client: ApifyClientAsync | MemoryStorageClient, + ) -> DatasetCollectionClientAsync | DatasetCollectionClient: return client.datasets() @classmethod - async def push_data(cls, data: JSONSerializable) -> None: + async def push_data(cls: type[Dataset], data: JSONSerializable) -> None: """Store an object or an array of objects to the dataset. 
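        For example (a sketch, assuming the SDK is initialized):

            await Dataset.push_data({'url': 'https://example.com', 'status': 200})
            await Dataset.push_data([{'rank': i} for i in range(100)])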
The size of the data is limited by the receiving API and therefore `push_data()` will only @@ -149,7 +164,7 @@ async def push_data(cls, data: JSONSerializable) -> None: dataset = await cls.open() return await dataset.push_data(data) - async def _push_data_internal(self, data: JSONSerializable) -> None: + async def _push_data_internal(self: Dataset, data: JSONSerializable) -> None: # Handle singular items if not isinstance(data, list): payload = _check_and_serialize(data) @@ -161,22 +176,23 @@ async def _push_data_internal(self, data: JSONSerializable) -> None: # Invoke client in series to preserve the order of data for chunk in _chunk_by_size(payloads_generator): await self._dataset_client.push_items(chunk) + return None @classmethod async def get_data( - cls, + cls: type[Dataset], *, - offset: Optional[int] = None, - limit: Optional[int] = None, - clean: Optional[bool] = None, - desc: Optional[bool] = None, - fields: Optional[List[str]] = None, - omit: Optional[List[str]] = None, - unwind: Optional[str] = None, - skip_empty: Optional[bool] = None, - skip_hidden: Optional[bool] = None, - flatten: Optional[List[str]] = None, - view: Optional[str] = None, + offset: int | None = None, + limit: int | None = None, + clean: bool | None = None, + desc: bool | None = None, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool | None = None, + skip_hidden: bool | None = None, + flatten: list[str] | None = None, + view: str | None = None, ) -> ListPage: """Get items from the dataset. @@ -223,30 +239,22 @@ async def get_data( ) async def _get_data_internal( - self, + self: Dataset, *, - offset: Optional[int] = None, - limit: Optional[int] = None, - clean: Optional[bool] = None, - desc: Optional[bool] = None, - fields: Optional[List[str]] = None, - omit: Optional[List[str]] = None, - unwind: Optional[str] = None, - skip_empty: Optional[bool] = None, - skip_hidden: Optional[bool] = None, - flatten: Optional[List[str]] = None, - view: Optional[str] = None, + offset: int | None = None, + limit: int | None = None, + clean: bool | None = None, + desc: bool | None = None, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool | None = None, + skip_hidden: bool | None = None, + flatten: list[str] | None = None, + view: str | None = None, ) -> ListPage: - # try { - # return await this.client.listItems(options); - # } catch (e) { - # const error = e as Error; - # if (error.message.includes('Cannot create a string longer than')) { - # throw new Error('dataset.getData(): The response is too large for parsing. You can fix this by lowering the "limit" option.'); - # } - # throw e; - # } - # TODO: Simulate the above error in Python and handle accordingly... + # TODO: Improve error handling here + # https://github.com/apify/apify-sdk-python/issues/140 return await self._dataset_client.list_items( offset=offset, limit=limit, @@ -262,12 +270,12 @@ async def _get_data_internal( ) async def export_to( - self, + self: Dataset, key: str, *, - to_key_value_store_id: Optional[str] = None, - to_key_value_store_name: Optional[str] = None, - content_type: Optional[str] = None, + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, + content_type: str | None = None, ) -> None: """Save the entirety of the dataset's contents into one file within a key-value store. @@ -280,7 +288,7 @@ async def export_to( content_type (str, optional): Either 'text/csv' or 'application/json'. 
Defaults to JSON. """ key_value_store = await KeyValueStore.open(id=to_key_value_store_id, name=to_key_value_store_name) - items: List[Dict] = [] + items: list[dict] = [] limit = 1000 offset = 0 while True: @@ -307,13 +315,13 @@ async def export_to( @classmethod async def export_to_json( - cls, + cls: type[Dataset], key: str, *, - from_dataset_id: Optional[str] = None, - from_dataset_name: Optional[str] = None, - to_key_value_store_id: Optional[str] = None, - to_key_value_store_name: Optional[str] = None, + from_dataset_id: str | None = None, + from_dataset_name: str | None = None, + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, ) -> None: """Save the entirety of the dataset's contents into one JSON file within a key-value store. @@ -332,13 +340,13 @@ async def export_to_json( await dataset.export_to_json(key, to_key_value_store_id=to_key_value_store_id, to_key_value_store_name=to_key_value_store_name) async def _export_to_json_internal( - self, + self: Dataset, key: str, *, - from_dataset_id: Optional[str] = None, # noqa: U100 - from_dataset_name: Optional[str] = None, # noqa: U100 - to_key_value_store_id: Optional[str] = None, - to_key_value_store_name: Optional[str] = None, + from_dataset_id: str | None = None, # noqa: ARG002 + from_dataset_name: str | None = None, # noqa: ARG002 + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, ) -> None: await self.export_to( key, @@ -349,13 +357,13 @@ async def _export_to_json_internal( @classmethod async def export_to_csv( - cls, + cls: type[Dataset], key: str, *, - from_dataset_id: Optional[str] = None, - from_dataset_name: Optional[str] = None, - to_key_value_store_id: Optional[str] = None, - to_key_value_store_name: Optional[str] = None, + from_dataset_id: str | None = None, + from_dataset_name: str | None = None, + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, ) -> None: """Save the entirety of the dataset's contents into one CSV file within a key-value store. @@ -374,13 +382,13 @@ async def export_to_csv( await dataset.export_to_csv(key, to_key_value_store_id=to_key_value_store_id, to_key_value_store_name=to_key_value_store_name) async def _export_to_csv_internal( - self, + self: Dataset, key: str, *, - from_dataset_id: Optional[str] = None, # noqa: U100 - from_dataset_name: Optional[str] = None, # noqa: U100 - to_key_value_store_id: Optional[str] = None, - to_key_value_store_name: Optional[str] = None, + from_dataset_id: str | None = None, # noqa: ARG002 + from_dataset_name: str | None = None, # noqa: ARG002 + to_key_value_store_id: str | None = None, + to_key_value_store_name: str | None = None, ) -> None: await self.export_to( key, @@ -389,7 +397,7 @@ async def _export_to_csv_internal( content_type='text/csv', ) - async def get_info(self) -> Optional[Dict]: + async def get_info(self: Dataset) -> dict | None: """Get an object containing general information about the dataset. 
Returns: @@ -398,18 +406,18 @@ async def get_info(self) -> Optional[Dict]: return await self._dataset_client.get() def iterate_items( - self, + self: Dataset, *, offset: int = 0, - limit: Optional[int] = None, - clean: Optional[bool] = None, - desc: Optional[bool] = None, - fields: Optional[List[str]] = None, - omit: Optional[List[str]] = None, - unwind: Optional[str] = None, - skip_empty: Optional[bool] = None, - skip_hidden: Optional[bool] = None, - ) -> AsyncIterator[Dict]: + limit: int | None = None, + clean: bool | None = None, + desc: bool | None = None, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool | None = None, + skip_hidden: bool | None = None, + ) -> AsyncIterator[dict]: """Iterate over the items in the dataset. Args: @@ -449,20 +457,20 @@ def iterate_items( skip_hidden=skip_hidden, ) - async def drop(self) -> None: + async def drop(self: Dataset) -> None: """Remove the dataset either from the Apify cloud storage or from the local directory.""" await self._dataset_client.delete() self._remove_from_cache() @classmethod - async def open( - cls, + async def open( # noqa: A003 + cls: type[Dataset], *, - id: Optional[str] = None, - name: Optional[str] = None, + id: str | None = None, # noqa: A002 + name: str | None = None, force_cloud: bool = False, - config: Optional[Configuration] = None, - ) -> 'Dataset': + config: Configuration | None = None, + ) -> Dataset: """Open a dataset. Datasets are used to store structured data where each object stored has the same attributes, @@ -483,4 +491,4 @@ async def open( Returns: Dataset: An instance of the `Dataset` class for the given ID or name. """ - return await super().open(id=id, name=name, force_cloud=force_cloud, config=config) + return await super().open(id=id, name=name, force_cloud=force_cloud, config=config) # type: ignore diff --git a/src/apify/storages/key_value_store.py b/src/apify/storages/key_value_store.py index aa2bd972..0960eb3f 100644 --- a/src/apify/storages/key_value_store.py +++ b/src/apify/storages/key_value_store.py @@ -1,18 +1,35 @@ -from typing import Any, AsyncIterator, NamedTuple, Optional, TypedDict, TypeVar, Union, overload +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, AsyncIterator, NamedTuple, TypedDict, TypeVar, overload -from apify_client import ApifyClientAsync from apify_client.clients import KeyValueStoreClientAsync, KeyValueStoreCollectionClientAsync from apify_shared.utils import ignore_docs -from .._memory_storage import MemoryStorageClient -from .._memory_storage.resource_clients import KeyValueStoreClient, KeyValueStoreCollectionClient -from .._utils import _wrap_internal -from ..config import Configuration +from .._utils import wrap_internal from .base_storage import BaseStorage +if TYPE_CHECKING: + from apify_client import ApifyClientAsync + + from .._memory_storage import MemoryStorageClient + from .._memory_storage.resource_clients import KeyValueStoreClient, KeyValueStoreCollectionClient + from ..config import Configuration + + T = TypeVar('T') -IterateKeysInfo = TypedDict('IterateKeysInfo', {'size': int}) -IterateKeysTuple = NamedTuple('IterateKeysTuple', [('key', str), ('info', IterateKeysInfo)]) + + +class IterateKeysInfo(TypedDict): + """Contains information about a key-value store record.""" + + size: int + + +class IterateKeysTuple(NamedTuple): + """A tuple representing a key-value store record.""" + + key: str + info: IterateKeysInfo class KeyValueStore(BaseStorage): @@ -48,11 +65,17 @@ 
class KeyValueStore(BaseStorage): """ _id: str - _name: Optional[str] - _key_value_store_client: Union[KeyValueStoreClientAsync, KeyValueStoreClient] + _name: str | None + _key_value_store_client: KeyValueStoreClientAsync | KeyValueStoreClient @ignore_docs - def __init__(self, id: str, name: Optional[str], client: Union[ApifyClientAsync, MemoryStorageClient], config: Configuration) -> None: + def __init__( + self: KeyValueStore, + id: str, # noqa: A002 + name: str | None, + client: ApifyClientAsync | MemoryStorageClient, + config: Configuration, + ) -> None: """Create a `KeyValueStore` instance. Do not use the constructor directly, use the `Actor.open_key_value_store()` function instead. @@ -65,53 +88,53 @@ def __init__(self, id: str, name: Optional[str], client: Union[ApifyClientAsync, """ super().__init__(id=id, name=name, client=client, config=config) - self.get_value = _wrap_internal(self._get_value_internal, self.get_value) # type: ignore - self.set_value = _wrap_internal(self._set_value_internal, self.set_value) # type: ignore - self.get_public_url = _wrap_internal(self._get_public_url_internal, self.get_public_url) # type: ignore + self.get_value = wrap_internal(self._get_value_internal, self.get_value) # type: ignore + self.set_value = wrap_internal(self._set_value_internal, self.set_value) # type: ignore + self.get_public_url = wrap_internal(self._get_public_url_internal, self.get_public_url) # type: ignore self._id = id self._name = name self._key_value_store_client = client.key_value_store(self._id) @classmethod - def _get_human_friendly_label(cls) -> str: + def _get_human_friendly_label(cls: type[KeyValueStore]) -> str: return 'Key-value store' @classmethod - def _get_default_id(cls, config: Configuration) -> str: + def _get_default_id(cls: type[KeyValueStore], config: Configuration) -> str: return config.default_key_value_store_id @classmethod def _get_single_storage_client( - cls, - id: str, - client: Union[ApifyClientAsync, MemoryStorageClient], - ) -> Union[KeyValueStoreClientAsync, KeyValueStoreClient]: + cls: type[KeyValueStore], + id: str, # noqa: A002 + client: ApifyClientAsync | MemoryStorageClient, + ) -> KeyValueStoreClientAsync | KeyValueStoreClient: return client.key_value_store(id) @classmethod def _get_storage_collection_client( - cls, - client: Union[ApifyClientAsync, MemoryStorageClient], - ) -> Union[KeyValueStoreCollectionClientAsync, KeyValueStoreCollectionClient]: + cls: type[KeyValueStore], + client: ApifyClientAsync | MemoryStorageClient, + ) -> KeyValueStoreCollectionClientAsync | KeyValueStoreCollectionClient: return client.key_value_stores() @overload @classmethod - async def get_value(cls, key: str) -> Any: + async def get_value(cls: type[KeyValueStore], key: str) -> Any: ... @overload @classmethod - async def get_value(cls, key: str, default_value: T) -> T: + async def get_value(cls: type[KeyValueStore], key: str, default_value: T) -> T: ... @overload @classmethod - async def get_value(cls, key: str, default_value: Optional[T] = None) -> Optional[T]: + async def get_value(cls: type[KeyValueStore], key: str, default_value: T | None = None) -> T | None: ... @classmethod - async def get_value(cls, key: str, default_value: Optional[T] = None) -> Optional[T]: + async def get_value(cls: type[KeyValueStore], key: str, default_value: T | None = None) -> T | None: """Get a value from the key-value store. 
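        For example (a sketch; 'INPUT' is the conventional key of the Actor input,
        while 'STATE' is just an illustrative key):

            actor_input = await KeyValueStore.get_value('INPUT')
            state = await KeyValueStore.get_value('STATE', default_value={})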
Args: @@ -124,11 +147,14 @@ async def get_value(cls, key: str, default_value: Optional[T] = None) -> Optiona store = await cls.open() return await store.get_value(key, default_value) - async def _get_value_internal(self, key: str, default_value: Optional[T] = None) -> Optional[T]: + async def _get_value_internal(self: KeyValueStore, key: str, default_value: T | None = None) -> T | None: record = await self._key_value_store_client.get_record(key) return record['value'] if record else default_value - async def iterate_keys(self, exclusive_start_key: Optional[str] = None) -> AsyncIterator[IterateKeysTuple]: + async def iterate_keys( + self: KeyValueStore, + exclusive_start_key: str | None = None, + ) -> AsyncIterator[IterateKeysTuple]: """Iterate over the keys in the key-value store. Args: @@ -149,25 +175,35 @@ async def iterate_keys(self, exclusive_start_key: Optional[str] = None) -> Async exclusive_start_key = list_keys['nextExclusiveStartKey'] @classmethod - async def set_value(cls, key: str, value: Optional[T], content_type: Optional[str] = None) -> None: + async def set_value( + cls: type[KeyValueStore], + key: str, + value: Any, + content_type: str | None = None, + ) -> None: """Set or delete a value in the key-value store. Args: key (str): The key under which the value should be saved. - value (Any, optional): The value to save. If the value is `None`, the corresponding key-value pair will be deleted. + value (Any): The value to save. If the value is `None`, the corresponding key-value pair will be deleted. content_type (str, optional): The content type of the saved value. """ store = await cls.open() return await store.set_value(key, value, content_type) - async def _set_value_internal(self, key: str, value: Optional[T], content_type: Optional[str] = None) -> None: + async def _set_value_internal( + self: KeyValueStore, + key: str, + value: Any, + content_type: str | None = None, + ) -> None: if value is None: return await self._key_value_store_client.delete_record(key) return await self._key_value_store_client.set_record(key, value, content_type) @classmethod - async def get_public_url(cls, key: str) -> str: + async def get_public_url(cls: type[KeyValueStore], key: str) -> str: """Get a URL for the given key that may be used to publicly access the value in the remote key-value store. 
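        For example (a sketch; the URL shape follows the pattern constructed below):

            url = await KeyValueStore.get_public_url('OUTPUT')
            # e.g. https://api.apify.com/v2/key-value-stores/<store-id>/records/OUTPUT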
Args: @@ -176,28 +212,28 @@ async def get_public_url(cls, key: str) -> str: store = await cls.open() return await store.get_public_url(key) - async def _get_public_url_internal(self, key: str) -> str: + async def _get_public_url_internal(self: KeyValueStore, key: str) -> str: if not isinstance(self._key_value_store_client, KeyValueStoreClientAsync): - raise RuntimeError('Cannot generate a public URL for this key-value store as it is not on the Apify Platform!') + raise RuntimeError('Cannot generate a public URL for this key-value store as it is not on the Apify Platform!') # noqa: TRY004 public_api_url = self._config.api_public_base_url return f'{public_api_url}/v2/key-value-stores/{self._id}/records/{key}' - async def drop(self) -> None: + async def drop(self: KeyValueStore) -> None: """Remove the key-value store either from the Apify cloud storage or from the local directory.""" await self._key_value_store_client.delete() self._remove_from_cache() @classmethod - async def open( - cls, + async def open( # noqa: A003 + cls: type[KeyValueStore], *, - id: Optional[str] = None, - name: Optional[str] = None, + id: str | None = None, # noqa: A002 + name: str | None = None, force_cloud: bool = False, - config: Optional[Configuration] = None, - ) -> 'KeyValueStore': + config: Configuration | None = None, + ) -> KeyValueStore: """Open a key-value store. Key-value stores are used to store records or files, along with their MIME content type. @@ -218,4 +254,4 @@ async def open( Returns: KeyValueStore: An instance of the `KeyValueStore` class for the given ID or name. """ - return await super().open(id=id, name=name, force_cloud=force_cloud, config=config) + return await super().open(id=id, name=name, force_cloud=force_cloud, config=config) # type: ignore diff --git a/src/apify/storages/request_queue.py b/src/apify/storages/request_queue.py index 8cdf8712..81fec196 100644 --- a/src/apify/storages/request_queue.py +++ b/src/apify/storages/request_queue.py @@ -1,23 +1,28 @@ +from __future__ import annotations + import asyncio from collections import OrderedDict from datetime import datetime, timezone -from typing import Dict, Optional +from typing import TYPE_CHECKING from typing import OrderedDict as OrderedDictType -from typing import Set, Union -from apify_client import ApifyClientAsync -from apify_client.clients import RequestQueueClientAsync, RequestQueueCollectionClientAsync from apify_shared.utils import ignore_docs from .._crypto import crypto_random_object_id -from .._memory_storage import MemoryStorageClient -from .._memory_storage.resource_clients import RequestQueueClient, RequestQueueCollectionClient -from .._utils import LRUCache, _budget_ow, _unique_key_to_request_id -from ..config import Configuration +from .._utils import LRUCache, budget_ow, unique_key_to_request_id from ..consts import REQUEST_QUEUE_HEAD_MAX_LIMIT from ..log import logger from .base_storage import BaseStorage +if TYPE_CHECKING: + from apify_client import ApifyClientAsync + from apify_client.clients import RequestQueueClientAsync, RequestQueueCollectionClientAsync + + from .._memory_storage import MemoryStorageClient + from .._memory_storage.resource_clients import RequestQueueClient, RequestQueueCollectionClient + from ..config import Configuration + + MAX_CACHED_REQUESTS = 1_000_000 # When requesting queue head we always fetch requestsInProgressCount * QUERY_HEAD_BUFFER number of requests. @@ -72,20 +77,26 @@ class RequestQueue(BaseStorage): cloud storage. 
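    A typical processing loop (a sketch, assuming the SDK is initialized):

        queue = await RequestQueue.open()
        await queue.add_request({'url': 'https://example.com'})
        while request := await queue.fetch_next_request():
            ...  # process the request here
            await queue.mark_request_as_handled(request)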
""" - _request_queue_client: Union[RequestQueueClientAsync, RequestQueueClient] + _request_queue_client: RequestQueueClientAsync | RequestQueueClient _client_key = crypto_random_object_id() _queue_head_dict: OrderedDictType[str, str] - _query_queue_head_task: Optional[asyncio.Task] - _in_progress: Set[str] + _query_queue_head_task: asyncio.Task | None + _in_progress: set[str] _last_activity: datetime _internal_timeout_seconds = 5 * 60 _recently_handled: LRUCache[bool] _assumed_total_count = 0 _assumed_handled_count = 0 - _requests_cache: LRUCache[Dict] + _requests_cache: LRUCache[dict] @ignore_docs - def __init__(self, id: str, name: Optional[str], client: Union[ApifyClientAsync, MemoryStorageClient], config: Configuration) -> None: + def __init__( + self: RequestQueue, + id: str, # noqa: A002 + name: str | None, + client: ApifyClientAsync | MemoryStorageClient, + config: Configuration, + ) -> None: """Create a `RequestQueue` instance. Do not use the constructor directly, use the `Actor.open_request_queue()` function instead. @@ -107,29 +118,29 @@ def __init__(self, id: str, name: Optional[str], client: Union[ApifyClientAsync, self._requests_cache = LRUCache(max_length=MAX_CACHED_REQUESTS) @classmethod - def _get_human_friendly_label(cls) -> str: + def _get_human_friendly_label(cls: type[RequestQueue]) -> str: return 'Request queue' @classmethod - def _get_default_id(cls, config: Configuration) -> str: + def _get_default_id(cls: type[RequestQueue], config: Configuration) -> str: return config.default_request_queue_id @classmethod def _get_single_storage_client( - cls, - id: str, - client: Union[ApifyClientAsync, MemoryStorageClient], - ) -> Union[RequestQueueClientAsync, RequestQueueClient]: + cls: type[RequestQueue], + id: str, # noqa: A002 + client: ApifyClientAsync | MemoryStorageClient, + ) -> RequestQueueClientAsync | RequestQueueClient: return client.request_queue(id) @classmethod def _get_storage_collection_client( - cls, - client: Union[ApifyClientAsync, MemoryStorageClient], - ) -> Union[RequestQueueCollectionClientAsync, RequestQueueCollectionClient]: + cls: type[RequestQueue], + client: ApifyClientAsync | MemoryStorageClient, + ) -> RequestQueueCollectionClientAsync | RequestQueueCollectionClient: return client.request_queues() - async def add_request(self, request: Dict, *, forefront: bool = False) -> Dict: + async def add_request(self: RequestQueue, request: dict, *, forefront: bool = False) -> dict: """Add a request to the queue. Args: @@ -139,15 +150,20 @@ async def add_request(self, request: Dict, *, forefront: bool = False) -> Dict: Returns: dict: Information about the queue operation with keys `requestId`, `uniqueKey`, `wasAlreadyPresent`, `wasAlreadyHandled`. """ - _budget_ow(request, { - 'url': (str, True), - }) + budget_ow( + request, + { + 'url': (str, True), + }, + ) self._last_activity = datetime.now(timezone.utc) if request.get('uniqueKey') is None: - request['uniqueKey'] = request['url'] # TODO: Check Request class in crawlee and replicate uniqueKey generation logic... + # TODO: Check Request class in crawlee and replicate uniqueKey generation logic... 
+ # https://github.com/apify/apify-sdk-python/issues/141 + request['uniqueKey'] = request['url'] - cache_key = _unique_key_to_request_id(request['uniqueKey']) + cache_key = unique_key_to_request_id(request['uniqueKey']) cached_info = self._requests_cache.get(cache_key) if cached_info: @@ -175,7 +191,7 @@ async def add_request(self, request: Dict, *, forefront: bool = False) -> Dict: return queue_operation_info - async def get_request(self, request_id: str) -> Optional[Dict]: + async def get_request(self: RequestQueue, request_id: str) -> dict | None: """Retrieve a request from the queue. Args: @@ -184,10 +200,10 @@ async def get_request(self, request_id: str) -> Optional[Dict]: Returns: dict, optional: The retrieved request, or `None`, if it does not exist. """ - _budget_ow(request_id, (str, True), 'request_id') + budget_ow(request_id, (str, True), 'request_id') return await self._request_queue_client.get_request(request_id) - async def fetch_next_request(self) -> Optional[Dict]: + async def fetch_next_request(self: RequestQueue) -> dict | None: """Return the next request in the queue to be processed. Once you successfully finish processing of the request, you need to call @@ -211,21 +227,24 @@ async def fetch_next_request(self) -> Optional[Dict]: # This should never happen, but... if next_request_id in self._in_progress or self._recently_handled.get(next_request_id): - logger.warning('Queue head returned a request that is already in progress?!', extra={ - 'nextRequestId': next_request_id, - 'inProgress': next_request_id in self._in_progress, - 'recentlyHandled': next_request_id in self._recently_handled, - }) + logger.warning( + 'Queue head returned a request that is already in progress?!', + extra={ + 'nextRequestId': next_request_id, + 'inProgress': next_request_id in self._in_progress, + 'recentlyHandled': next_request_id in self._recently_handled, + }, + ) return None self._in_progress.add(next_request_id) self._last_activity = datetime.now(timezone.utc) try: request = await self.get_request(next_request_id) - except Exception as e: + except Exception: # On error, remove the request from in progress, otherwise it would be there forever self._in_progress.remove(next_request_id) - raise e + raise # NOTE: It can happen that the queue head index is inconsistent with the main queue table. This can occur in two situations: @@ -252,7 +271,7 @@ async def fetch_next_request(self) -> Optional[Dict]: return request - async def mark_request_as_handled(self, request: Dict) -> Optional[Dict]: + async def mark_request_as_handled(self: RequestQueue, request: dict) -> dict | None: """Mark a request as handled after successful processing. Handled requests will never again be returned by the `RequestQueue.fetch_next_request` method. @@ -264,11 +283,14 @@ async def mark_request_as_handled(self, request: Dict) -> Optional[Dict]: dict, optional: Information about the queue operation with keys `requestId`, `uniqueKey`, `wasAlreadyPresent`, `wasAlreadyHandled`. `None` if the given request was not in progress. 
""" - _budget_ow(request, { - 'id': (str, True), - 'uniqueKey': (str, True), - 'handledAt': (datetime, False), - }) + budget_ow( + request, + { + 'id': (str, True), + 'uniqueKey': (str, True), + 'handledAt': (datetime, False), + }, + ) self._last_activity = datetime.now(timezone.utc) if request['id'] not in self._in_progress: logger.debug('Cannot mark request as handled, because it is not in progress!', extra={'requestId': request['id']}) @@ -284,11 +306,15 @@ async def mark_request_as_handled(self, request: Dict) -> Optional[Dict]: if not queue_operation_info['wasAlreadyHandled']: self._assumed_handled_count += 1 - self._cache_request(_unique_key_to_request_id(request['uniqueKey']), queue_operation_info) + self._cache_request(unique_key_to_request_id(request['uniqueKey']), queue_operation_info) return queue_operation_info - async def reclaim_request(self, request: Dict, forefront: bool = False) -> Optional[Dict]: + async def reclaim_request( + self: RequestQueue, + request: dict, + forefront: bool = False, # noqa: FBT001, FBT002 + ) -> dict | None: """Reclaim a failed request back to the queue. The request will be returned for processing later again @@ -301,21 +327,25 @@ async def reclaim_request(self, request: Dict, forefront: bool = False) -> Optio dict, optional: Information about the queue operation with keys `requestId`, `uniqueKey`, `wasAlreadyPresent`, `wasAlreadyHandled`. `None` if the given request was not in progress. """ - _budget_ow(request, { - 'id': (str, True), - 'uniqueKey': (str, True), - }) + budget_ow( + request, + { + 'id': (str, True), + 'uniqueKey': (str, True), + }, + ) self._last_activity = datetime.now(timezone.utc) if request['id'] not in self._in_progress: logger.debug('Cannot reclaim request, because it is not in progress!', extra={'requestId': request['id']}) return None - # TODO: If request hasn't been changed since the last getRequest(), - # we don't need to call updateRequest() and thus improve performance. + # TODO: If request hasn't been changed since the last getRequest(), we don't need to call updateRequest() + # and thus improve performance. + # https://github.com/apify/apify-sdk-python/issues/143 queue_operation_info = await self._request_queue_client.update_request(request, forefront=forefront) queue_operation_info['uniqueKey'] = request['uniqueKey'] - self._cache_request(_unique_key_to_request_id(request['uniqueKey']), queue_operation_info) + self._cache_request(unique_key_to_request_id(request['uniqueKey']), queue_operation_info) # Wait a little to increase a chance that the next call to fetchNextRequest() will return the request with updated data. # This is to compensate for the limitation of DynamoDB, where writes might not be immediately visible to subsequent reads. @@ -333,10 +363,10 @@ def callback() -> None: return queue_operation_info - def _in_progress_count(self) -> int: + def _in_progress_count(self: RequestQueue) -> int: return len(self._in_progress) - async def is_empty(self) -> bool: + async def is_empty(self: RequestQueue) -> bool: """Check whether the queue is empty. Returns: @@ -345,7 +375,7 @@ async def is_empty(self) -> bool: await self._ensure_head_is_non_empty() return len(self._queue_head_dict) == 0 - async def is_finished(self) -> bool: + async def is_finished(self: RequestQueue) -> bool: """Check whether the queue is finished. 
Due to the nature of distributed storage used by the queue, @@ -361,13 +391,13 @@ async def is_finished(self) -> bool: logger.warning(message) self._reset() - if (len(self._queue_head_dict) > 0 or self._in_progress_count() > 0): + if len(self._queue_head_dict) > 0 or self._in_progress_count() > 0: return False - is_head_consistent = await self._ensure_head_is_non_empty(True) + is_head_consistent = await self._ensure_head_is_non_empty(ensure_consistency=True) return is_head_consistent and len(self._queue_head_dict) == 0 and self._in_progress_count() == 0 - def _reset(self) -> None: + def _reset(self: RequestQueue) -> None: self._queue_head_dict.clear() self._query_queue_head_task = None self._in_progress.clear() @@ -377,7 +407,7 @@ def _reset(self) -> None: self._requests_cache.clear() self._last_activity = datetime.now(timezone.utc) - def _cache_request(self, cache_key: str, queue_operation_info: Dict) -> None: + def _cache_request(self: RequestQueue, cache_key: str, queue_operation_info: dict) -> None: self._requests_cache[cache_key] = { 'id': queue_operation_info['requestId'], 'isHandled': queue_operation_info['wasAlreadyHandled'], @@ -385,7 +415,7 @@ def _cache_request(self, cache_key: str, queue_operation_info: Dict) -> None: 'wasAlreadyHandled': queue_operation_info['wasAlreadyHandled'], } - async def _queue_query_head(self, limit: int) -> Dict: + async def _queue_query_head(self: RequestQueue, limit: int) -> dict: query_started_at = datetime.now(timezone.utc) list_head = await self._request_queue_client.list_head(limit=limit) @@ -394,12 +424,15 @@ async def _queue_query_head(self, limit: int) -> Dict: if not request['id'] or not request['uniqueKey'] or request['id'] in self._in_progress or self._recently_handled.get(request['id']): continue self._queue_head_dict[request['id']] = request['id'] - self._cache_request(_unique_key_to_request_id(request['uniqueKey']), { - 'requestId': request['id'], - 'wasAlreadyHandled': False, - 'wasAlreadyPresent': True, - 'uniqueKey': request['uniqueKey'], - }) + self._cache_request( + unique_key_to_request_id(request['uniqueKey']), + { + 'requestId': request['id'], + 'wasAlreadyHandled': False, + 'wasAlreadyPresent': True, + 'uniqueKey': request['uniqueKey'], + }, + ) # This is needed so that the next call to _ensureHeadIsNonEmpty() will fetch the queue head again. self._query_queue_head_task = None @@ -412,7 +445,12 @@ async def _queue_query_head(self, limit: int) -> Dict: 'hadMultipleClients': list_head['hadMultipleClients'], } - async def _ensure_head_is_non_empty(self, ensure_consistency: bool = False, limit: Optional[int] = None, iteration: int = 0) -> bool: + async def _ensure_head_is_non_empty( + self: RequestQueue, + ensure_consistency: bool = False, # noqa: FBT001, FBT002 + limit: int | None = None, + iteration: int = 0, + ) -> bool: # If is nonempty resolve immediately. if len(self._queue_head_dict) > 0: return True @@ -426,26 +464,28 @@ async def _ensure_head_is_non_empty(self, ensure_consistency: bool = False, limi queue_head = await self._query_queue_head_task # TODO: I feel this code below can be greatly simplified... 
(comes from TS implementation *wink*) + # https://github.com/apify/apify-sdk-python/issues/142 - """ If queue is still empty then one of the following holds: - - the other calls waiting for this task already consumed all the returned requests - - the limit was too low and contained only requests in progress - - the writes from other clients were not propagated yet - - the whole queue was processed and we are done - """ + # If queue is still empty then one of the following holds: + # - the other calls waiting for this task already consumed all the returned requests + # - the limit was too low and contained only requests in progress + # - the writes from other clients were not propagated yet + # - the whole queue was processed and we are done # If limit was not reached in the call then there are no more requests to be returned. - if (queue_head['prevLimit'] >= REQUEST_QUEUE_HEAD_MAX_LIMIT): + if queue_head['prevLimit'] >= REQUEST_QUEUE_HEAD_MAX_LIMIT: logger.warning('Reached the maximum number of requests in progress', extra={'limit': REQUEST_QUEUE_HEAD_MAX_LIMIT}) - should_repeat_with_higher_limit = len( - self._queue_head_dict) == 0 and queue_head['wasLimitReached'] and queue_head['prevLimit'] < REQUEST_QUEUE_HEAD_MAX_LIMIT + should_repeat_with_higher_limit = ( + len(self._queue_head_dict) == 0 and queue_head['wasLimitReached'] and queue_head['prevLimit'] < REQUEST_QUEUE_HEAD_MAX_LIMIT + ) # If ensureConsistency=true then we must ensure that either: # - queueModifiedAt is older than queryStartedAt by at least API_PROCESSED_REQUESTS_DELAY_MILLIS # - hadMultipleClients=false and this.assumedTotalCount<=this.assumedHandledCount - is_database_consistent = (queue_head['queryStartedAt'] - queue_head['queueModifiedAt'].replace(tzinfo=timezone.utc) - ).seconds >= (API_PROCESSED_REQUESTS_DELAY_MILLIS // 1000) + is_database_consistent = (queue_head['queryStartedAt'] - queue_head['queueModifiedAt'].replace(tzinfo=timezone.utc)).seconds >= ( + API_PROCESSED_REQUESTS_DELAY_MILLIS // 1000 + ) is_locally_consistent = not queue_head['hadMultipleClients'] and self._assumed_total_count <= self._assumed_handled_count # Consistent information from one source is enough to consider request queue finished. should_repeat_for_consistency = ensure_consistency and not is_database_consistent and not is_locally_consistent @@ -463,14 +503,17 @@ async def _ensure_head_is_non_empty(self, ensure_consistency: bool = False, limi # If we are repeating for consistency then wait required time. if should_repeat_for_consistency: - delay_seconds = (API_PROCESSED_REQUESTS_DELAY_MILLIS // 1000) - \ - (datetime.now(timezone.utc) - queue_head['queueModifiedAt']).seconds + delay_seconds = (API_PROCESSED_REQUESTS_DELAY_MILLIS // 1000) - (datetime.now(timezone.utc) - queue_head['queueModifiedAt']).seconds logger.info(f'Waiting for {delay_seconds}s before considering the queue as finished to ensure that the data is consistent.') await asyncio.sleep(delay_seconds) return await self._ensure_head_is_non_empty(ensure_consistency, next_limit, iteration + 1) - def _maybe_add_request_to_queue_head(self, request_id: str, forefront: bool) -> None: + def _maybe_add_request_to_queue_head( + self: RequestQueue, + request_id: str, + forefront: bool, # noqa: FBT001 + ) -> None: if forefront: self._queue_head_dict[request_id] = request_id # Move to start, i.e. 
forefront of the queue @@ -479,12 +522,12 @@ def _maybe_add_request_to_queue_head(self, request_id: str, forefront: bool) -> # OrderedDict puts the item to the end of the queue by default self._queue_head_dict[request_id] = request_id - async def drop(self) -> None: + async def drop(self: RequestQueue) -> None: """Remove the request queue either from the Apify cloud storage or from the local directory.""" await self._request_queue_client.delete() self._remove_from_cache() - async def get_info(self) -> Optional[Dict]: + async def get_info(self: RequestQueue) -> dict | None: """Get an object containing general information about the request queue. Returns: @@ -493,14 +536,14 @@ async def get_info(self) -> Optional[Dict]: return await self._request_queue_client.get() @classmethod - async def open( - cls, + async def open( # noqa: A003 + cls: type[RequestQueue], *, - id: Optional[str] = None, - name: Optional[str] = None, + id: str | None = None, # noqa: A002 + name: str | None = None, force_cloud: bool = False, - config: Optional[Configuration] = None, - ) -> 'RequestQueue': + config: Configuration | None = None, + ) -> RequestQueue: """Open a request queue. Request queue represents a queue of URLs to crawl, which is stored either on local filesystem or in the Apify cloud. @@ -523,5 +566,5 @@ async def open( RequestQueue: An instance of the `RequestQueue` class for the given ID or name. """ queue = await super().open(id=id, name=name, force_cloud=force_cloud, config=config) - await queue._ensure_head_is_non_empty() - return queue + await queue._ensure_head_is_non_empty() # type: ignore + return queue # type: ignore diff --git a/src/apify/storages/storage_client_manager.py b/src/apify/storages/storage_client_manager.py index 0e02c3e8..bee2c781 100644 --- a/src/apify/storages/storage_client_manager.py +++ b/src/apify/storages/storage_client_manager.py @@ -1,11 +1,15 @@ -from typing import Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING -from apify_client import ApifyClientAsync from apify_shared.utils import ignore_docs from .._memory_storage import MemoryStorageClient from ..config import Configuration +if TYPE_CHECKING: + from apify_client import ApifyClientAsync + @ignore_docs class StorageClientManager: @@ -13,17 +17,17 @@ class StorageClientManager: _config: Configuration - _local_client: Optional[MemoryStorageClient] = None - _cloud_client: Optional[ApifyClientAsync] = None + _local_client: MemoryStorageClient | None = None + _cloud_client: ApifyClientAsync | None = None - _default_instance: Optional['StorageClientManager'] = None + _default_instance: StorageClientManager | None = None - def __init__(self) -> None: + def __init__(self: StorageClientManager) -> None: """Create a `StorageClientManager` instance.""" self._config = Configuration.get_global_configuration() @classmethod - def set_config(cls, config: Configuration) -> None: + def set_config(cls: type[StorageClientManager], config: Configuration) -> None: """Set the config for the StorageClientManager. Args: @@ -32,7 +36,10 @@ def set_config(cls, config: Configuration) -> None: cls._get_default_instance()._config = config @classmethod - def get_storage_client(cls, force_cloud: bool = False) -> Union[ApifyClientAsync, MemoryStorageClient]: + def get_storage_client( + cls: type[StorageClientManager], + force_cloud: bool = False, # noqa: FBT001, FBT002 + ) -> ApifyClientAsync | MemoryStorageClient: """Get the current storage client instance. 
Returns: @@ -43,13 +50,13 @@ def get_storage_client(cls, force_cloud: bool = False) -> Union[ApifyClientAsync default_instance._local_client = MemoryStorageClient(persist_storage=default_instance._config.persist_storage, write_metadata=True) if default_instance._config.is_at_home or force_cloud: - assert default_instance._cloud_client is not None + assert default_instance._cloud_client is not None # noqa: S101 return default_instance._cloud_client return default_instance._local_client @classmethod - def set_cloud_client(cls, client: ApifyClientAsync) -> None: + def set_cloud_client(cls: type[StorageClientManager], client: ApifyClientAsync) -> None: """Set the storage client. Args: @@ -58,7 +65,7 @@ def set_cloud_client(cls, client: ApifyClientAsync) -> None: cls._get_default_instance()._cloud_client = client @classmethod - def _get_default_instance(cls) -> 'StorageClientManager': + def _get_default_instance(cls: type[StorageClientManager]) -> StorageClientManager: if cls._default_instance is None: cls._default_instance = cls() diff --git a/tests/integration/_utils.py b/tests/integration/_utils.py index e6e89052..b69d6d58 100644 --- a/tests/integration/_utils.py +++ b/tests/integration/_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from apify._crypto import crypto_random_object_id diff --git a/tests/integration/actor_source_base/src/__main__.py b/tests/integration/actor_source_base/src/__main__.py index 6aafff27..643eb63c 100644 --- a/tests/integration/actor_source_base/src/__main__.py +++ b/tests/integration/actor_source_base/src/__main__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import logging diff --git a/tests/integration/actor_source_base/src/main.py b/tests/integration/actor_source_base/src/main.py index 210b9204..78c03a48 100644 --- a/tests/integration/actor_source_base/src/main.py +++ b/tests/integration/actor_source_base/src/main.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from apify import Actor diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index ccdbd195..93c46710 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import inspect import os @@ -5,7 +7,7 @@ import sys import textwrap from pathlib import Path -from typing import AsyncIterator, Awaitable, Callable, Dict, List, Mapping, Optional, Protocol, Union +from typing import TYPE_CHECKING, AsyncIterator, Awaitable, Callable, Mapping, Protocol import pytest from filelock import FileLock @@ -14,11 +16,13 @@ from apify.config import Configuration from apify.storages import Dataset, KeyValueStore, RequestQueue, StorageClientManager from apify_client import ApifyClientAsync -from apify_client.clients.resource_clients import ActorClientAsync from apify_shared.consts import ActorJobStatus, ActorSourceType from ._utils import generate_unique_resource_name +if TYPE_CHECKING: + from apify_client.clients.resource_clients import ActorClientAsync + TOKEN_ENV_VAR = 'APIFY_TEST_USER_API_TOKEN' API_URL_ENV_VAR = 'APIFY_INTEGRATION_TESTS_API_URL' SDK_ROOT_PATH = Path(__file__).parent.parent.parent.resolve() @@ -44,7 +48,7 @@ def _reset_and_patch_default_instances(monkeypatch: pytest.MonkeyPatch) -> None: # because `httpx.AsyncClient` in `ApifyClientAsync` tries to reuse the same event loop across requests, # but `pytest-asyncio` closes the event loop after each test, # and uses a new one for the next test. 
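The conftest comment above is the key constraint on these fixtures: `httpx.AsyncClient` (which backs `ApifyClientAsync`) keeps its connection pool tied to the event loop it first runs on, while `pytest-asyncio` creates a fresh loop for every test, so the client fixture must stay function-scoped. A minimal, self-contained sketch of that pattern, assuming `httpx.MockTransport` so no network is needed; the fixture and test names here are illustrative only and not part of this patch:

import httpx
import pytest


@pytest.fixture()
def http_client() -> httpx.AsyncClient:
    # Function scope: each test gets a client created on (and torn down with) its own
    # event loop. A session-scoped client would carry pooled connections from an already
    # closed loop into the next test and fail with `RuntimeError: Event loop is closed`.
    transport = httpx.MockTransport(lambda request: httpx.Response(200))
    return httpx.AsyncClient(transport=transport)


async def test_request_succeeds(http_client: httpx.AsyncClient) -> None:
    async with http_client as client:
        response = await client.get('https://example.com')
        assert response.status_code == 200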
-@pytest.fixture +@pytest.fixture() def apify_client_async() -> ApifyClientAsync: api_token = os.getenv(TOKEN_ENV_VAR) api_url = os.getenv(API_URL_ENV_VAR) @@ -65,7 +69,7 @@ def sdk_wheel_path(tmp_path_factory: pytest.TempPathFactory, testrun_uid: str) - # through an indicator file saying that the wheel was already built was_wheel_built_this_test_run_file = tmp_path_factory.getbasetemp() / f'wheel_was_built_in_run_{testrun_uid}' if not was_wheel_built_this_test_run_file.exists(): - subprocess.run('python -m build', cwd=SDK_ROOT_PATH, shell=True, check=True, capture_output=True) + subprocess.run('python -m build', cwd=SDK_ROOT_PATH, shell=True, check=True, capture_output=True) # noqa: S602, S607 was_wheel_built_this_test_run_file.touch() # Read the current package version, necessary for getting the right wheel filename @@ -87,14 +91,14 @@ def sdk_wheel_path(tmp_path_factory: pytest.TempPathFactory, testrun_uid: str) - @pytest.fixture(scope='session') -def actor_base_source_files(sdk_wheel_path: Path) -> Dict[str, Union[str, bytes]]: +def actor_base_source_files(sdk_wheel_path: Path) -> dict[str, str | bytes]: """Create a dictionary of the base source files for a testing actor. It takes the files from `tests/integration/actor_source_base`, builds the Apify SDK wheel from the current codebase, and adds them all together in a dictionary. """ - source_files: Dict[str, Union[str, bytes]] = {} + source_files: dict[str, str | bytes] = {} # First read the actor_source_base files sdk_root_path = Path(__file__).parent.parent.parent.resolve() @@ -124,27 +128,27 @@ def actor_base_source_files(sdk_wheel_path: Path) -> Dict[str, Union[str, bytes] # Just a type for the make_actor result, so that we can import it in tests class ActorFactory(Protocol): def __call__( - self, + self: ActorFactory, actor_label: str, *, - main_func: Optional[Callable] = None, - main_py: Optional[str] = None, - source_files: Optional[Mapping[str, Union[str, bytes]]] = None, + main_func: Callable | None = None, + main_py: str | None = None, + source_files: Mapping[str, str | bytes] | None = None, ) -> Awaitable[ActorClientAsync]: ... -@pytest.fixture -async def make_actor(actor_base_source_files: Dict[str, Union[str, bytes]], apify_client_async: ApifyClientAsync) -> AsyncIterator[ActorFactory]: +@pytest.fixture() +async def make_actor(actor_base_source_files: dict[str, str | bytes], apify_client_async: ApifyClientAsync) -> AsyncIterator[ActorFactory]: """A fixture for returning a temporary actor factory.""" - actor_clients_for_cleanup: List[ActorClientAsync] = [] + actor_clients_for_cleanup: list[ActorClientAsync] = [] async def _make_actor( actor_label: str, *, - main_func: Optional[Callable] = None, - main_py: Optional[str] = None, - source_files: Optional[Mapping[str, Union[str, bytes]]] = None, + main_func: Callable | None = None, + main_py: str | None = None, + source_files: Mapping[str, str | bytes] | None = None, ) -> ActorClientAsync: """Create a temporary actor from the given main function or source file(s). 
@@ -190,16 +194,18 @@ async def _make_actor( if isinstance(file_contents, str): file_format = 'TEXT' if file_name.endswith('.py'): - file_contents = textwrap.dedent(file_contents).lstrip() + file_contents = textwrap.dedent(file_contents).lstrip() # noqa: PLW2901 else: file_format = 'BASE64' - file_contents = base64.b64encode(file_contents).decode('utf-8') + file_contents = base64.b64encode(file_contents).decode('utf-8') # noqa: PLW2901 - source_files_for_api.append({ - 'name': file_name, - 'format': file_format, - 'content': file_contents, - }) + source_files_for_api.append( + { + 'name': file_name, + 'format': file_format, + 'content': file_contents, + } + ) print(f'Creating actor {actor_name}...') created_actor = await apify_client_async.actors().create( @@ -207,12 +213,14 @@ async def _make_actor( default_run_build='latest', default_run_memory_mbytes=256, default_run_timeout_secs=300, - versions=[{ - 'versionNumber': '0.0', - 'buildTag': 'latest', - 'sourceType': ActorSourceType.SOURCE_FILES, - 'sourceFiles': source_files_for_api, - }], + versions=[ + { + 'versionNumber': '0.0', + 'buildTag': 'latest', + 'sourceType': ActorSourceType.SOURCE_FILES, + 'sourceFiles': source_files_for_api, + } + ], ) actor_client = apify_client_async.actor(created_actor['id']) diff --git a/tests/integration/test_actor_api_helpers.py b/tests/integration/test_actor_api_helpers.py index 9b7e3132..025d9bee 100644 --- a/tests/integration/test_actor_api_helpers.py +++ b/tests/integration/test_actor_api_helpers.py @@ -1,16 +1,22 @@ +from __future__ import annotations + import asyncio import json +from typing import TYPE_CHECKING from apify import Actor from apify._crypto import crypto_random_object_id -from apify_client import ApifyClientAsync from ._utils import generate_unique_resource_name -from .conftest import ActorFactory + +if TYPE_CHECKING: + from apify_client import ApifyClientAsync + + from .conftest import ActorFactory class TestActorIsAtHome: - async def test_actor_is_at_home(self, make_actor: ActorFactory) -> None: + async def test_actor_is_at_home(self: TestActorIsAtHome, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: assert Actor.is_at_home() is True @@ -24,7 +30,7 @@ async def main() -> None: class TestActorGetEnv: - async def test_actor_get_env(self, make_actor: ActorFactory) -> None: + async def test_actor_get_env(self: TestActorGetEnv, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: env_dict = Actor.get_env() @@ -50,7 +56,7 @@ async def main() -> None: class TestActorNewClient: - async def test_actor_new_client(self, make_actor: ActorFactory) -> None: + async def test_actor_new_client(self: TestActorNewClient, make_actor: ActorFactory) -> None: async def main() -> None: import os @@ -78,11 +84,11 @@ async def main() -> None: class TestActorSetStatusMessage: - async def test_actor_set_status_message(self, make_actor: ActorFactory) -> None: + async def test_actor_set_status_message(self: TestActorSetStatusMessage, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: - input = await Actor.get_input() or {} - await Actor.set_status_message('testing-status-message', **input) + actor_input = await Actor.get_input() or {} + await Actor.set_status_message('testing-status-message', **actor_input) actor = await make_actor('set-status-message', main_func=main) @@ -102,19 +108,19 @@ async def main() -> None: class TestActorStart: - async def test_actor_start(self, make_actor: ActorFactory) -> None: + async def 
test_actor_start(self: TestActorStart, make_actor: ActorFactory) -> None: async def main_inner() -> None: async with Actor: await asyncio.sleep(5) - input = await Actor.get_input() or {} - test_value = input.get('test_value') + actor_input = await Actor.get_input() or {} + test_value = actor_input.get('test_value') await Actor.set_value('OUTPUT', f'{test_value}_XXX_{test_value}') async def main_outer() -> None: async with Actor: - input = await Actor.get_input() or {} - inner_actor_id = input.get('inner_actor_id') - test_value = input.get('test_value') + actor_input = await Actor.get_input() or {} + inner_actor_id = actor_input.get('inner_actor_id') + test_value = actor_input.get('test_value') assert inner_actor_id is not None @@ -143,19 +149,19 @@ async def main_outer() -> None: class TestActorCall: - async def test_actor_call(self, make_actor: ActorFactory) -> None: + async def test_actor_call(self: TestActorCall, make_actor: ActorFactory) -> None: async def main_inner() -> None: async with Actor: await asyncio.sleep(5) - input = await Actor.get_input() or {} - test_value = input.get('test_value') + actor_input = await Actor.get_input() or {} + test_value = actor_input.get('test_value') await Actor.set_value('OUTPUT', f'{test_value}_XXX_{test_value}') async def main_outer() -> None: async with Actor: - input = await Actor.get_input() or {} - inner_actor_id = input.get('inner_actor_id') - test_value = input.get('test_value') + actor_input = await Actor.get_input() or {} + inner_actor_id = actor_input.get('inner_actor_id') + test_value = actor_input.get('test_value') assert inner_actor_id is not None @@ -184,18 +190,22 @@ async def main_outer() -> None: class TestActorCallTask: - async def test_actor_call_task(self, make_actor: ActorFactory, apify_client_async: ApifyClientAsync) -> None: + async def test_actor_call_task( + self: TestActorCallTask, + make_actor: ActorFactory, + apify_client_async: ApifyClientAsync, + ) -> None: async def main_inner() -> None: async with Actor: await asyncio.sleep(5) - input = await Actor.get_input() or {} - test_value = input.get('test_value') + actor_input = await Actor.get_input() or {} + test_value = actor_input.get('test_value') await Actor.set_value('OUTPUT', f'{test_value}_XXX_{test_value}') async def main_outer() -> None: async with Actor: - input = await Actor.get_input() or {} - inner_task_id = input.get('inner_task_id') + actor_input = await Actor.get_input() or {} + inner_task_id = actor_input.get('inner_task_id') assert inner_task_id is not None @@ -232,7 +242,7 @@ async def main_outer() -> None: class TestActorAbort: - async def test_actor_abort(self, make_actor: ActorFactory) -> None: + async def test_actor_abort(self: TestActorAbort, make_actor: ActorFactory) -> None: async def main_inner() -> None: async with Actor: await asyncio.sleep(180) @@ -241,8 +251,8 @@ async def main_inner() -> None: async def main_outer() -> None: async with Actor: - input = await Actor.get_input() or {} - inner_run_id = input.get('inner_run_id') + actor_input = await Actor.get_input() or {} + inner_run_id = actor_input.get('inner_run_id') assert inner_run_id is not None @@ -268,7 +278,7 @@ async def main_outer() -> None: class TestActorMetamorph: - async def test_actor_metamorph(self, make_actor: ActorFactory) -> None: + async def test_actor_metamorph(self: TestActorMetamorph, make_actor: ActorFactory) -> None: async def main_inner() -> None: import os @@ -277,9 +287,9 @@ async def main_inner() -> None: async with Actor: assert os.getenv(ActorEnvVars.INPUT_KEY) is 
not None assert os.getenv(ActorEnvVars.INPUT_KEY) != 'INPUT' - input = await Actor.get_input() or {} + actor_input = await Actor.get_input() or {} - test_value = input.get('test_value', '') + test_value = actor_input.get('test_value', '') assert test_value.endswith('_BEFORE_METAMORPH') output = test_value.replace('_BEFORE_METAMORPH', '_AFTER_METAMORPH') @@ -287,9 +297,9 @@ async def main_inner() -> None: async def main_outer() -> None: async with Actor: - input = await Actor.get_input() or {} - inner_actor_id = input.get('inner_actor_id') - test_value = input.get('test_value') + actor_input = await Actor.get_input() or {} + inner_actor_id = actor_input.get('inner_actor_id') + test_value = actor_input.get('test_value') new_test_value = f'{test_value}_BEFORE_METAMORPH' assert inner_actor_id is not None @@ -324,7 +334,7 @@ async def main_outer() -> None: class TestActorReboot: - async def test_actor_reboot(self, make_actor: ActorFactory) -> None: + async def test_actor_reboot(self: TestActorReboot, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: print('Starting...') @@ -353,7 +363,7 @@ async def main() -> None: class TestActorAddWebhook: - async def test_actor_add_webhook(self, make_actor: ActorFactory) -> None: + async def test_actor_add_webhook(self: TestActorAddWebhook, make_actor: ActorFactory) -> None: async def main_server() -> None: import os from http.server import BaseHTTPRequestHandler, HTTPServer @@ -363,13 +373,14 @@ async def main_server() -> None: webhook_body = '' async with Actor: + class WebhookHandler(BaseHTTPRequestHandler): - def do_GET(self) -> None: # noqa: N802 + def do_GET(self) -> None: # noqa: N802, ANN101 self.send_response(200) self.end_headers() self.wfile.write(bytes('Hello, world!', encoding='utf-8')) - def do_POST(self) -> None: # noqa: N802 + def do_POST(self) -> None: # noqa: N802, ANN101 nonlocal webhook_body content_length = self.headers.get('content-length') length = int(content_length) if content_length else 0 @@ -382,7 +393,7 @@ def do_POST(self) -> None: # noqa: N802 container_port = int(os.getenv(ActorEnvVars.WEB_SERVER_PORT, '')) with HTTPServer(('', container_port), WebhookHandler) as server: - await Actor.set_value('INITIALIZED', True) + await Actor.set_value('INITIALIZED', value=True) while not webhook_body: server.handle_request() @@ -390,9 +401,10 @@ def do_POST(self) -> None: # noqa: N802 async def main_client() -> None: from apify_shared.consts import WebhookEventType + async with Actor: - input = await Actor.get_input() or {} - server_actor_container_url = str(input.get('server_actor_container_url')) + actor_input = await Actor.get_input() or {} + server_actor_container_url = str(actor_input.get('server_actor_container_url')) await Actor.add_webhook( event_types=[WebhookEventType.ACTOR_RUN_SUCCEEDED], diff --git a/tests/integration/test_actor_create_proxy_configuration.py b/tests/integration/test_actor_create_proxy_configuration.py index e555d750..50c5f78b 100644 --- a/tests/integration/test_actor_create_proxy_configuration.py +++ b/tests/integration/test_actor_create_proxy_configuration.py @@ -1,10 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from apify import Actor -from .conftest import ActorFactory +if TYPE_CHECKING: + from .conftest import ActorFactory class TestActorCreateProxyConfiguration: - async def test_create_proxy_configuration_basic(self, make_actor: ActorFactory) -> None: + async def test_create_proxy_configuration_basic( + self: 
TestActorCreateProxyConfiguration, + make_actor: ActorFactory, + ) -> None: async def main() -> None: groups = ['SHADER'] country_code = 'US' @@ -26,25 +34,32 @@ async def main() -> None: assert run_result is not None assert run_result['status'] == 'SUCCEEDED' - async def test_create_proxy_configuration_complex(self, make_actor: ActorFactory) -> None: + async def test_create_proxy_configuration_complex( + self: TestActorCreateProxyConfiguration, + make_actor: ActorFactory, + ) -> None: async def main() -> None: await Actor.init() proxy_url_suffix = f'{Actor.config.proxy_password}@{Actor.config.proxy_hostname}:{Actor.config.proxy_port}' - proxy_configuration = await Actor.create_proxy_configuration(actor_proxy_input={ - 'useApifyProxy': True, - }) + proxy_configuration = await Actor.create_proxy_configuration( + actor_proxy_input={ + 'useApifyProxy': True, + } + ) assert proxy_configuration is not None assert await proxy_configuration.new_url() == f'http://auto:{proxy_url_suffix}' groups = ['SHADER', 'BUYPROXIES94952'] country_code = 'US' - proxy_configuration = await Actor.create_proxy_configuration(actor_proxy_input={ - 'useApifyProxy': True, - 'apifyProxyGroups': groups, - 'apifyProxyCountry': country_code, - }) + proxy_configuration = await Actor.create_proxy_configuration( + actor_proxy_input={ + 'useApifyProxy': True, + 'apifyProxyGroups': groups, + 'apifyProxyCountry': country_code, + } + ) assert proxy_configuration is not None assert await proxy_configuration.new_url() == f'http://groups-{"+".join(groups)},country-{country_code}:{proxy_url_suffix}' diff --git a/tests/integration/test_actor_dataset.py b/tests/integration/test_actor_dataset.py index a845b25c..1486dbca 100644 --- a/tests/integration/test_actor_dataset.py +++ b/tests/integration/test_actor_dataset.py @@ -1,15 +1,22 @@ -import pytest +from __future__ import annotations + +from typing import TYPE_CHECKING from apify import Actor -from apify_client import ApifyClientAsync from apify_shared.consts import ApifyEnvVars from ._utils import generate_unique_resource_name -from .conftest import ActorFactory + +if TYPE_CHECKING: + import pytest + + from apify_client import ApifyClientAsync + + from .conftest import ActorFactory class TestActorPushData: - async def test_push_data(self, make_actor: ActorFactory) -> None: + async def test_push_data(self: TestActorPushData, make_actor: ActorFactory) -> None: desired_item_count = 100 # Also change inside main() if you're changing this async def main() -> None: @@ -28,7 +35,7 @@ async def main() -> None: assert list_page.items[-1]['id'] == desired_item_count - 1 assert len(list_page.items) == list_page.count == desired_item_count - async def test_push_data_over_9mb(self, make_actor: ActorFactory) -> None: + async def test_push_data_over_9mb(self: TestActorPushData, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: await Actor.push_data([{'str': 'x' * 10000} for _ in range(5000)]) # ~50MB @@ -44,7 +51,7 @@ async def main() -> None: class TestActorOpenDataset: - async def test_same_references_default(self, make_actor: ActorFactory) -> None: + async def test_same_references_default(self: TestActorOpenDataset, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: dataset1 = await Actor.open_dataset() @@ -57,7 +64,7 @@ async def main() -> None: assert run_result is not None assert run_result['status'] == 'SUCCEEDED' - async def test_same_references_named(self, make_actor: ActorFactory) -> None: + async def 
test_same_references_named(self: TestActorOpenDataset, make_actor: ActorFactory) -> None: dataset_name = generate_unique_resource_name('dataset') async def main() -> None: @@ -81,7 +88,11 @@ async def main() -> None: assert run_result is not None assert run_result['status'] == 'SUCCEEDED' - async def test_force_cloud(self, apify_client_async: ApifyClientAsync, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_force_cloud( + self: TestActorOpenDataset, + apify_client_async: ApifyClientAsync, + monkeypatch: pytest.MonkeyPatch, + ) -> None: assert apify_client_async.token is not None monkeypatch.setenv(ApifyEnvVars.TOKEN, apify_client_async.token) diff --git a/tests/integration/test_actor_events.py b/tests/integration/test_actor_events.py index da896c85..9e63c908 100644 --- a/tests/integration/test_actor_events.py +++ b/tests/integration/test_actor_events.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import asyncio +from typing import TYPE_CHECKING from apify import Actor from apify_shared.consts import ActorEventTypes -from .conftest import ActorFactory +if TYPE_CHECKING: + from .conftest import ActorFactory class TestActorEvents: - async def test_interval_events(self, make_actor: ActorFactory) -> None: + async def test_interval_events(self: TestActorEvents, make_actor: ActorFactory) -> None: async def main() -> None: import os from datetime import datetime @@ -59,7 +63,7 @@ async def log_event(data: Any) -> None: assert len(persist_state_events) > 2 assert len(system_info_events) > 0 - async def test_off_event(self, make_actor: ActorFactory) -> None: + async def test_off_event(self: TestActorEvents, make_actor: ActorFactory) -> None: async def main() -> None: import os @@ -69,7 +73,7 @@ async def main() -> None: counter = 0 - def count_event(data): # type: ignore + def count_event(data): # type: ignore # noqa: ANN202, ANN001 nonlocal counter print(data) counter += 1 diff --git a/tests/integration/test_actor_key_value_store.py b/tests/integration/test_actor_key_value_store.py index 2242ed04..e1edb958 100644 --- a/tests/integration/test_actor_key_value_store.py +++ b/tests/integration/test_actor_key_value_store.py @@ -1,15 +1,22 @@ -import pytest +from __future__ import annotations + +from typing import TYPE_CHECKING from apify import Actor -from apify_client import ApifyClientAsync from apify_shared.consts import ApifyEnvVars from ._utils import generate_unique_resource_name -from .conftest import ActorFactory + +if TYPE_CHECKING: + import pytest + + from apify_client import ApifyClientAsync + + from .conftest import ActorFactory class TestActorOpenKeyValueStore: - async def test_same_references_default(self, make_actor: ActorFactory) -> None: + async def test_same_references_default(self: TestActorOpenKeyValueStore, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: kvs1 = await Actor.open_key_value_store() @@ -22,7 +29,7 @@ async def main() -> None: assert run_result is not None assert run_result['status'] == 'SUCCEEDED' - async def test_same_references_named(self, make_actor: ActorFactory) -> None: + async def test_same_references_named(self: TestActorOpenKeyValueStore, make_actor: ActorFactory) -> None: kvs_name = generate_unique_resource_name('key-value-store') async def main() -> None: @@ -46,7 +53,11 @@ async def main() -> None: assert run_result is not None assert run_result['status'] == 'SUCCEEDED' - async def test_force_cloud(self, apify_client_async: ApifyClientAsync, monkeypatch: pytest.MonkeyPatch) -> None: + async def 
test_force_cloud( + self: TestActorOpenKeyValueStore, + apify_client_async: ApifyClientAsync, + monkeypatch: pytest.MonkeyPatch, + ) -> None: assert apify_client_async.token is not None monkeypatch.setenv(ApifyEnvVars.TOKEN, apify_client_async.token) @@ -73,7 +84,7 @@ async def test_force_cloud(self, apify_client_async: ApifyClientAsync, monkeypat class TestActorGetSetValue: - async def test_actor_get_set_value_simple(self, make_actor: ActorFactory) -> None: + async def test_actor_get_set_value_simple(self: TestActorGetSetValue, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: await Actor.set_value('test', {'number': 123, 'string': 'a string', 'nested': {'test': 1}}) @@ -88,7 +99,7 @@ async def main() -> None: assert run_result is not None assert run_result['status'] == 'SUCCEEDED' - async def test_actor_get_set_value_complex(self, make_actor: ActorFactory) -> None: + async def test_actor_get_set_value_complex(self: TestActorGetSetValue, make_actor: ActorFactory) -> None: async def main_set() -> None: async with Actor: await Actor.set_value('test', {'number': 123, 'string': 'a string', 'nested': {'test': 1}}) @@ -126,7 +137,7 @@ async def main_get() -> None: class TestActorGetInput: - async def test_actor_get_input(self, make_actor: ActorFactory) -> None: + async def test_actor_get_input(self: TestActorGetInput, make_actor: ActorFactory) -> None: actor_source_files = { 'INPUT_SCHEMA.json': """ { @@ -161,18 +172,20 @@ async def main(): } actor = await make_actor('actor-get-input', source_files=actor_source_files) - run_result = await actor.call(run_input={ - 'number': 123, - 'string': 'a string', - 'nested': {'test': 1}, - 'password': 'very secret', - }) + run_result = await actor.call( + run_input={ + 'number': 123, + 'string': 'a string', + 'nested': {'test': 1}, + 'password': 'very secret', + } + ) assert run_result is not None assert run_result['status'] == 'SUCCEEDED' class TestGetPublicUrl: - async def test_get_public_url(self, make_actor: ActorFactory) -> None: + async def test_get_public_url(self: TestGetPublicUrl, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: public_api_url = Actor.config.api_public_base_url diff --git a/tests/integration/test_actor_lifecycle.py b/tests/integration/test_actor_lifecycle.py index f2374e12..ae517f90 100644 --- a/tests/integration/test_actor_lifecycle.py +++ b/tests/integration/test_actor_lifecycle.py @@ -1,10 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from apify import Actor -from .conftest import ActorFactory +if TYPE_CHECKING: + from .conftest import ActorFactory class TestActorInit: - async def test_actor_init(self, make_actor: ActorFactory) -> None: + async def test_actor_init(self: TestActorInit, make_actor: ActorFactory) -> None: async def main() -> None: my_actor = Actor() await my_actor.init() @@ -14,16 +19,16 @@ async def main() -> None: await my_actor.init() double_init = True except RuntimeError as err: - assert str(err) == 'The actor was already initialized!' - except Exception as err: - raise err + assert str(err) == 'The actor was already initialized!' # noqa: PT017 + except Exception: + raise try: await Actor.init() double_init = True except RuntimeError as err: - assert str(err) == 'The actor was already initialized!' - except Exception as err: - raise err + assert str(err) == 'The actor was already initialized!' 
# noqa: PT017 + except Exception: + raise await my_actor.exit() assert double_init is False assert my_actor._is_initialized is False @@ -35,7 +40,7 @@ async def main() -> None: assert run_result is not None assert run_result['status'] == 'SUCCEEDED' - async def test_async_with_actor_properly_initialize(self, make_actor: ActorFactory) -> None: + async def test_async_with_actor_properly_initialize(self: TestActorInit, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: assert Actor._get_default_instance()._is_initialized @@ -50,10 +55,10 @@ async def main() -> None: class TestActorExit: - async def test_actor_exit_code(self, make_actor: ActorFactory) -> None: + async def test_actor_exit_code(self: TestActorExit, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: - input = await Actor.get_input() + input = await Actor.get_input() # noqa: A001 await Actor.exit(**input) actor = await make_actor('actor-exit', main_func=main) @@ -66,10 +71,10 @@ async def main() -> None: class TestActorFail: - async def test_fail_exit_code(self, make_actor: ActorFactory) -> None: + async def test_fail_exit_code(self: TestActorFail, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: - input = await Actor.get_input() + input = await Actor.get_input() # noqa: A001 await Actor.fail(**input) if input else await Actor.fail() actor = await make_actor('actor-fail', main_func=main) @@ -91,10 +96,10 @@ async def main() -> None: assert run_result['status'] == 'FAILED' assert run_result.get('statusMessage') == 'This is a test message' - async def test_with_actor_fail_correctly(self, make_actor: ActorFactory) -> None: + async def test_with_actor_fail_correctly(self: TestActorFail, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: - raise Exception('This is a test exception') + raise Exception('This is a test exception') # noqa: TRY002 actor = await make_actor('with-actor-fail', main_func=main) run_result = await actor.call() @@ -104,13 +109,13 @@ async def main() -> None: class TestActorMain: - async def test_actor_main(self, make_actor: ActorFactory) -> None: + async def test_actor_main(self: TestActorMain, make_actor: ActorFactory) -> None: async def main() -> None: async def actor_function() -> None: - input = await Actor.get_input() + input = await Actor.get_input() # noqa: A001 if input.get('raise_exception'): - raise Exception(input.get('raise_exception')) - elif input.get('exit_code'): + raise Exception(input.get('raise_exception')) # noqa: TRY002 + if input.get('exit_code'): await Actor.exit(exit_code=input.get('exit_code')) elif input.get('fail'): await Actor.fail() diff --git a/tests/integration/test_actor_log.py b/tests/integration/test_actor_log.py index 3abc9c2e..f07598e6 100644 --- a/tests/integration/test_actor_log.py +++ b/tests/integration/test_actor_log.py @@ -1,10 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from apify import Actor, __version__ -from .conftest import ActorFactory +if TYPE_CHECKING: + from .conftest import ActorFactory class TestActorLog: - async def test_actor_log(self, make_actor: ActorFactory) -> None: + async def test_actor_log(self: TestActorLog, make_actor: ActorFactory) -> None: async def main() -> None: import logging diff --git a/tests/integration/test_actor_request_queue.py b/tests/integration/test_actor_request_queue.py index 771b08b7..d559fc73 100644 --- a/tests/integration/test_actor_request_queue.py +++ 
b/tests/integration/test_actor_request_queue.py @@ -1,15 +1,22 @@ -import pytest +from __future__ import annotations + +from typing import TYPE_CHECKING from apify import Actor -from apify_client import ApifyClientAsync from apify_shared.consts import ApifyEnvVars from ._utils import generate_unique_resource_name -from .conftest import ActorFactory + +if TYPE_CHECKING: + import pytest + + from apify_client import ApifyClientAsync + + from .conftest import ActorFactory class TestActorOpenRequestQueue: - async def test_same_references_default(self, make_actor: ActorFactory) -> None: + async def test_same_references_default(self: TestActorOpenRequestQueue, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: rq1 = await Actor.open_request_queue() @@ -22,7 +29,7 @@ async def main() -> None: assert run_result is not None assert run_result['status'] == 'SUCCEEDED' - async def test_same_references_named(self, make_actor: ActorFactory) -> None: + async def test_same_references_named(self: TestActorOpenRequestQueue, make_actor: ActorFactory) -> None: rq_name = generate_unique_resource_name('request-queue') async def main() -> None: @@ -46,7 +53,11 @@ async def main() -> None: assert run_result is not None assert run_result['status'] == 'SUCCEEDED' - async def test_force_cloud(self, apify_client_async: ApifyClientAsync, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_force_cloud( + self: TestActorOpenRequestQueue, + apify_client_async: ApifyClientAsync, + monkeypatch: pytest.MonkeyPatch, + ) -> None: assert apify_client_async.token is not None monkeypatch.setenv(ApifyEnvVars.TOKEN, apify_client_async.token) diff --git a/tests/integration/test_fixtures.py b/tests/integration/test_fixtures.py index a26e3018..c5c67a4d 100644 --- a/tests/integration/test_fixtures.py +++ b/tests/integration/test_fixtures.py @@ -1,14 +1,19 @@ +from __future__ import annotations + from datetime import datetime, timezone +from typing import TYPE_CHECKING from apify import Actor from apify._crypto import crypto_random_object_id -from apify_client import ApifyClientAsync -from .conftest import ActorFactory +if TYPE_CHECKING: + from apify_client import ApifyClientAsync + + from .conftest import ActorFactory class TestMakeActorFixture: - async def test_main_func(self, make_actor: ActorFactory) -> None: + async def test_main_func(self: TestMakeActorFixture, make_actor: ActorFactory) -> None: async def main() -> None: import os @@ -28,7 +33,7 @@ async def main() -> None: assert output_record is not None assert run_result['actId'] == output_record['value'] - async def test_main_py(self, make_actor: ActorFactory) -> None: + async def test_main_py(self: TestMakeActorFixture, make_actor: ActorFactory) -> None: expected_output = f'ACTOR_OUTPUT_{crypto_random_object_id(5)}' main_py_source = f""" import asyncio @@ -48,7 +53,7 @@ async def main(): assert output_record is not None assert output_record['value'] == expected_output - async def test_source_files(self, make_actor: ActorFactory) -> None: + async def test_source_files(self: TestMakeActorFixture, make_actor: ActorFactory) -> None: test_started_at = datetime.now(timezone.utc) actor_source_files = { 'src/utils.py': """ @@ -83,5 +88,8 @@ async def main(): class TestApifyClientAsyncFixture: - async def test_apify_client_async_works(self, apify_client_async: ApifyClientAsync) -> None: + async def test_apify_client_async_works( + self: TestApifyClientAsyncFixture, + apify_client_async: ApifyClientAsync, + ) -> None: assert await 
apify_client_async.user('me').get() is not None diff --git a/tests/integration/test_request_queue.py b/tests/integration/test_request_queue.py index 2fbd9c63..9e81aa43 100644 --- a/tests/integration/test_request_queue.py +++ b/tests/integration/test_request_queue.py @@ -1,10 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from apify import Actor -from .conftest import ActorFactory +if TYPE_CHECKING: + from .conftest import ActorFactory class TestRequestQueue: - async def test_simple(self, make_actor: ActorFactory) -> None: + async def test_simple(self: TestRequestQueue, make_actor: ActorFactory) -> None: async def main() -> None: async with Actor: desired_request_count = 100 @@ -14,9 +19,7 @@ async def main() -> None: # Add some requests for i in range(desired_request_count): print(f'Adding request {i}...') - await rq.add_request({ - 'url': f'https://example.com/{i}', - }) + await rq.add_request({'url': f'https://example.com/{i}'}) handled_request_count = 0 while next_request := await rq.fetch_next_request(): diff --git a/tests/unit/actor/test_actor_create_proxy_configuration.py b/tests/unit/actor/test_actor_create_proxy_configuration.py index 32cea4b0..29c6e928 100644 --- a/tests/unit/actor/test_actor_create_proxy_configuration.py +++ b/tests/unit/actor/test_actor_create_proxy_configuration.py @@ -1,30 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import httpx import pytest -from respx import MockRouter from apify import Actor from apify_client import ApifyClientAsync from apify_shared.consts import ApifyEnvVars -from ..conftest import ApifyClientAsyncPatcher +if TYPE_CHECKING: + from respx import MockRouter + + from ..conftest import ApifyClientAsyncPatcher DUMMY_PASSWORD = 'DUMMY_PASSWORD' -@pytest.fixture +@pytest.fixture() def patched_apify_client(apify_client_async_patcher: ApifyClientAsyncPatcher) -> ApifyClientAsync: - apify_client_async_patcher.patch('user', 'get', return_value={ - 'proxy': { - 'password': DUMMY_PASSWORD, - }, - }) - + apify_client_async_patcher.patch('user', 'get', return_value={'proxy': {'password': DUMMY_PASSWORD}}) return ApifyClientAsync() class TestActorCreateProxyConfiguration: async def test_create_proxy_configuration_basic( - self, + self: TestActorCreateProxyConfiguration, monkeypatch: pytest.MonkeyPatch, respx_mock: MockRouter, patched_apify_client: ApifyClientAsync, @@ -34,21 +35,23 @@ async def test_create_proxy_configuration_basic( monkeypatch.setenv(ApifyEnvVars.PROXY_STATUS_URL.value, dummy_proxy_status_url) route = respx_mock.get(dummy_proxy_status_url) - route.mock(httpx.Response(200, json={ - 'connected': True, - 'connectionError': None, - 'isManInTheMiddle': True, - })) + route.mock( + httpx.Response( + 200, + json={ + 'connected': True, + 'connectionError': None, + 'isManInTheMiddle': True, + }, + ) + ) groups = ['GROUP1', 'GROUP2'] country_code = 'US' await Actor.init() - proxy_configuration = await Actor.create_proxy_configuration( - groups=groups, - country_code=country_code, - ) + proxy_configuration = await Actor.create_proxy_configuration(groups=groups, country_code=country_code) assert proxy_configuration is not None assert proxy_configuration._groups == groups @@ -61,7 +64,7 @@ async def test_create_proxy_configuration_basic( await Actor.exit() async def test_create_proxy_configuration_actor_proxy_input( - self, + self: TestActorCreateProxyConfiguration, monkeypatch: pytest.MonkeyPatch, respx_mock: MockRouter, patched_apify_client: ApifyClientAsync, @@ -73,47 
+76,62 @@ async def test_create_proxy_configuration_actor_proxy_input( monkeypatch.setenv(ApifyEnvVars.PROXY_STATUS_URL.value, dummy_proxy_status_url) route = respx_mock.get(dummy_proxy_status_url) - route.mock(httpx.Response(200, json={ - 'connected': True, - 'connectionError': None, - 'isManInTheMiddle': True, - })) + route.mock( + httpx.Response( + 200, + json={ + 'connected': True, + 'connectionError': None, + 'isManInTheMiddle': True, + }, + ) + ) await Actor.init() proxy_configuration = await Actor.create_proxy_configuration(actor_proxy_input={}) assert proxy_configuration is None - proxy_configuration = await Actor.create_proxy_configuration(actor_proxy_input={ - 'useApifyProxy': False, - }) + proxy_configuration = await Actor.create_proxy_configuration( + actor_proxy_input={ + 'useApifyProxy': False, + } + ) assert proxy_configuration is None - proxy_configuration = await Actor.create_proxy_configuration(actor_proxy_input={ - 'proxyUrls': [], - }) + proxy_configuration = await Actor.create_proxy_configuration( + actor_proxy_input={ + 'proxyUrls': [], + } + ) assert proxy_configuration is None - proxy_configuration = await Actor.create_proxy_configuration(actor_proxy_input={ - 'useApifyProxy': False, - 'proxyUrls': [dummy_proxy_url], - }) + proxy_configuration = await Actor.create_proxy_configuration( + actor_proxy_input={ + 'useApifyProxy': False, + 'proxyUrls': [dummy_proxy_url], + } + ) assert proxy_configuration is not None assert await proxy_configuration.new_url() == dummy_proxy_url - proxy_configuration = await Actor.create_proxy_configuration(actor_proxy_input={ - 'useApifyProxy': True, - }) + proxy_configuration = await Actor.create_proxy_configuration( + actor_proxy_input={ + 'useApifyProxy': True, + } + ) assert proxy_configuration is not None assert await proxy_configuration.new_url() == f'http://auto:{DUMMY_PASSWORD}@proxy.apify.com:8000' groups = ['GROUP1', 'GROUP2'] country_code = 'US' - proxy_configuration = await Actor.create_proxy_configuration(actor_proxy_input={ - 'useApifyProxy': True, - 'apifyProxyGroups': groups, - 'apifyProxyCountry': country_code, - }) + proxy_configuration = await Actor.create_proxy_configuration( + actor_proxy_input={ + 'useApifyProxy': True, + 'apifyProxyGroups': groups, + 'apifyProxyCountry': country_code, + } + ) assert proxy_configuration is not None assert await proxy_configuration.new_url() == f'http://groups-{"+".join(groups)},country-{country_code}:{DUMMY_PASSWORD}@proxy.apify.com:8000' diff --git a/tests/unit/actor/test_actor_dataset.py b/tests/unit/actor/test_actor_dataset.py index 3643eb00..beb294b5 100644 --- a/tests/unit/actor/test_actor_dataset.py +++ b/tests/unit/actor/test_actor_dataset.py @@ -1,19 +1,25 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING import pytest from apify import Actor -from apify._memory_storage import MemoryStorageClient from apify_shared.consts import ActorEnvVars -# NOTE: We only test the dataset methond available on Actor class/instance. Actual tests for the implementations are in storages/. +if TYPE_CHECKING: + from apify._memory_storage import MemoryStorageClient + +# NOTE: We only test the dataset methods available on Actor class/instance. +# Actual tests for the implementations are in storages/. 
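The Actor-level dataset surface covered by the tests below reduces to opening a store (default, named, or by ID) and pushing items into it. A minimal usage sketch, restricted to the calls the tests themselves exercise (Actor.open_dataset(), Dataset.push_data(), and the Actor.push_data() shortcut):

import asyncio

from apify import Actor


async def main() -> None:
    async with Actor:
        # Repeated opens return the same cached instance, which is what
        # test_same_references asserts for default, named and by-id opens.
        dataset = await Actor.open_dataset()
        await dataset.push_data([{'i': i} for i in range(10)])

        # Shortcut that writes to the default dataset.
        await Actor.push_data({'i': 10})


asyncio.run(main())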
class TestActorOpenDataset: - async def test_throws_without_init(self) -> None: + async def test_throws_without_init(self: TestActorOpenDataset) -> None: with pytest.raises(RuntimeError): await Actor.open_dataset() - async def test_same_references(self) -> None: + async def test_same_references(self: TestActorOpenDataset) -> None: async with Actor: dataset1 = await Actor.open_dataset() dataset2 = await Actor.open_dataset() @@ -31,7 +37,7 @@ async def test_same_references(self) -> None: assert dataset_by_id_2 is dataset_by_id_1 async def test_open_datatset_based_env_var( - self, + self: TestActorOpenDataset, monkeypatch: pytest.MonkeyPatch, memory_storage_client: MemoryStorageClient, ) -> None: @@ -45,7 +51,7 @@ async def test_open_datatset_based_env_var( class TestActorPushData: - async def test_push_data(self) -> None: + async def test_push_data(self: TestActorPushData) -> None: async with Actor() as my_actor: dataset = await my_actor.open_dataset() desired_item_count = 100 diff --git a/tests/unit/actor/test_actor_env_helpers.py b/tests/unit/actor/test_actor_env_helpers.py index 3dc90091..f95f7c94 100644 --- a/tests/unit/actor/test_actor_env_helpers.py +++ b/tests/unit/actor/test_actor_env_helpers.py @@ -1,21 +1,24 @@ +from __future__ import annotations + import random import string from datetime import datetime, timezone -from typing import Any, Dict - -import pytest +from typing import TYPE_CHECKING, Any from apify import Actor from apify_shared.consts import BOOL_ENV_VARS, DATETIME_ENV_VARS, FLOAT_ENV_VARS, INTEGER_ENV_VARS, STRING_ENV_VARS, ActorEnvVars, ApifyEnvVars +if TYPE_CHECKING: + import pytest + class TestIsAtHome: - async def test_is_at_home_local(self) -> None: + async def test_is_at_home_local(self: TestIsAtHome) -> None: async with Actor as actor: is_at_home = actor.is_at_home() assert is_at_home is False - async def test_is_at_home_on_apify(self, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_is_at_home_on_apify(self: TestIsAtHome, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv(ApifyEnvVars.IS_AT_HOME, 'true') async with Actor as actor: is_at_home = actor.is_at_home() @@ -23,9 +26,9 @@ async def test_is_at_home_on_apify(self, monkeypatch: pytest.MonkeyPatch) -> Non class TestGetEnv: - async def test_get_env_use_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_get_env_use_env_vars(self: TestGetEnv, monkeypatch: pytest.MonkeyPatch) -> None: # Set up random env vars - expected_get_env: Dict[str, Any] = {} + expected_get_env: dict[str, Any] = {} for int_env_var in INTEGER_ENV_VARS: int_get_env_var = int_env_var.name.lower() expected_get_env[int_get_env_var] = random.randint(1, 99999) diff --git a/tests/unit/actor/test_actor_helpers.py b/tests/unit/actor/test_actor_helpers.py index 3be55215..daaed48b 100644 --- a/tests/unit/actor/test_actor_helpers.py +++ b/tests/unit/actor/test_actor_helpers.py @@ -1,14 +1,19 @@ -import pytest +from __future__ import annotations + +from typing import TYPE_CHECKING from apify import Actor from apify_client import ApifyClientAsync from apify_shared.consts import ApifyEnvVars, WebhookEventType -from ..conftest import ApifyClientAsyncPatcher +if TYPE_CHECKING: + import pytest + + from ..conftest import ApifyClientAsyncPatcher class TestActorNewClient: - async def test_actor_new_client_config(self, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_actor_new_client_config(self: TestActorNewClient, monkeypatch: pytest.MonkeyPatch) -> None: token = 'my-token' 
monkeypatch.setenv(ApifyEnvVars.TOKEN, token) my_actor = Actor() @@ -27,7 +32,10 @@ async def test_actor_new_client_config(self, monkeypatch: pytest.MonkeyPatch) -> class TestActorCallStartAbortActor: - async def test_actor_call(self, apify_client_async_patcher: ApifyClientAsyncPatcher) -> None: + async def test_actor_call( + self: TestActorCallStartAbortActor, + apify_client_async_patcher: ApifyClientAsyncPatcher, + ) -> None: apify_client_async_patcher.patch('actor', 'call', return_value=None) actor_id = 'some-actor-id' my_actor = Actor() @@ -40,7 +48,10 @@ async def test_actor_call(self, apify_client_async_patcher: ApifyClientAsyncPatc await my_actor.exit() - async def test_actor_call_task(self, apify_client_async_patcher: ApifyClientAsyncPatcher) -> None: + async def test_actor_call_task( + self: TestActorCallStartAbortActor, + apify_client_async_patcher: ApifyClientAsyncPatcher, + ) -> None: apify_client_async_patcher.patch('task', 'call', return_value=None) task_id = 'some-task-id' my_actor = Actor() @@ -52,7 +63,10 @@ async def test_actor_call_task(self, apify_client_async_patcher: ApifyClientAsyn await my_actor.exit() - async def test_actor_start(self, apify_client_async_patcher: ApifyClientAsyncPatcher) -> None: + async def test_actor_start( + self: TestActorCallStartAbortActor, + apify_client_async_patcher: ApifyClientAsyncPatcher, + ) -> None: apify_client_async_patcher.patch('actor', 'start', return_value=None) actor_id = 'some-id' my_actor = Actor() @@ -64,7 +78,10 @@ async def test_actor_start(self, apify_client_async_patcher: ApifyClientAsyncPat await my_actor.exit() - async def test_actor_abort(self, apify_client_async_patcher: ApifyClientAsyncPatcher) -> None: + async def test_actor_abort( + self: TestActorCallStartAbortActor, + apify_client_async_patcher: ApifyClientAsyncPatcher, + ) -> None: apify_client_async_patcher.patch('run', 'abort', return_value=None) run_id = 'some-run-id' my_actor = Actor() @@ -80,28 +97,40 @@ async def test_actor_abort(self, apify_client_async_patcher: ApifyClientAsyncPat class TestActorMethodsWorksOnlyOnPlatform: # NOTE: These medhods will be tested properly using integrations tests. - async def test_actor_metamorpth_not_work_locally(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_actor_metamorpth_not_work_locally( + self: TestActorMethodsWorksOnlyOnPlatform, + caplog: pytest.LogCaptureFixture, + ) -> None: async with Actor() as my_actor: await my_actor.metamorph('random-id') assert len(caplog.records) == 1 assert caplog.records[0].levelname == 'ERROR' assert 'Actor.metamorph() is only supported when running on the Apify platform.' in caplog.records[0].message - async def test_actor_reboot_not_work_locally(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_actor_reboot_not_work_locally( + self: TestActorMethodsWorksOnlyOnPlatform, + caplog: pytest.LogCaptureFixture, + ) -> None: async with Actor() as my_actor: await my_actor.reboot() assert len(caplog.records) == 1 assert caplog.records[0].levelname == 'ERROR' assert 'Actor.reboot() is only supported when running on the Apify platform.' 
in caplog.records[0].message - async def test_actor_add_webhook_not_work_locally(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_actor_add_webhook_not_work_locally( + self: TestActorMethodsWorksOnlyOnPlatform, + caplog: pytest.LogCaptureFixture, + ) -> None: async with Actor() as my_actor: await my_actor.add_webhook(event_types=[WebhookEventType.ACTOR_BUILD_ABORTED], request_url='https://example.com') assert len(caplog.records) == 1 assert caplog.records[0].levelname == 'ERROR' assert 'Actor.add_webhook() is only supported when running on the Apify platform.' in caplog.records[0].message - async def test_actor_set_status_message_mock_locally(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_actor_set_status_message_mock_locally( + self: TestActorMethodsWorksOnlyOnPlatform, + caplog: pytest.LogCaptureFixture, + ) -> None: caplog.set_level('INFO') async with Actor() as my_actor: await my_actor.set_status_message('test-status-message') @@ -110,7 +139,10 @@ async def test_actor_set_status_message_mock_locally(self, caplog: pytest.LogCap assert matching_records[0].levelname == 'INFO' assert '[Status message]: test-status-message' in matching_records[0].message - async def test_actor_set_status_message_terminal_mock_locally(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_actor_set_status_message_terminal_mock_locally( + self: TestActorMethodsWorksOnlyOnPlatform, + caplog: pytest.LogCaptureFixture, + ) -> None: caplog.set_level('INFO') async with Actor() as my_actor: await my_actor.fail(status_message='test-terminal-message') diff --git a/tests/unit/actor/test_actor_key_value_store.py b/tests/unit/actor/test_actor_key_value_store.py index eb6a0e19..bb9453ae 100644 --- a/tests/unit/actor/test_actor_key_value_store.py +++ b/tests/unit/actor/test_actor_key_value_store.py @@ -1,18 +1,25 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import pytest from apify import Actor from apify._crypto import public_encrypt -from apify._memory_storage import MemoryStorageClient from apify.consts import ENCRYPTED_INPUT_VALUE_PREFIX from apify_shared.consts import ApifyEnvVars from apify_shared.utils import json_dumps from ..test_crypto import PRIVATE_KEY_PASSWORD, PRIVATE_KEY_PEM_BASE64, PUBLIC_KEY +if TYPE_CHECKING: + from apify._memory_storage import MemoryStorageClient + -# NOTE: We only test the key-value store methond available on Actor class/instance. Actual tests for the implementations are in storages/. +# NOTE: We only test the key-value store methods available on Actor class/instance. +# Actual tests for the implementations are in storages/. 
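The same scoping applies here: the tests below only exercise the Actor-level shortcuts, not the storage implementation. A sketch of the round trip that test_get_set_value and test_get_input perform, using only the public calls visible in this file:

import asyncio

from apify import Actor


async def main() -> None:
    async with Actor:
        # set_value/get_value proxy to the default key-value store.
        await Actor.set_value(key='greeting', value='hello', content_type='text/plain')
        assert await Actor.get_value(key='greeting') == 'hello'

        # get_input() is just the 'INPUT' record of the default store,
        # with encrypted fields decrypted when the secret env vars are set.
        actor_input = await Actor.get_input()
        print(actor_input)


asyncio.run(main())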
class TestOpenKeyValueStore: - async def test_same_references(self) -> None: + async def test_same_references(self: TestOpenKeyValueStore) -> None: async with Actor: kvs1 = await Actor.open_key_value_store() kvs2 = await Actor.open_key_value_store() @@ -30,11 +37,11 @@ async def test_same_references(self) -> None: class TestKeyValueStoreOnActor: - async def test_throws_without_init(self) -> None: + async def test_throws_without_init(self: TestKeyValueStoreOnActor) -> None: with pytest.raises(RuntimeError): await Actor.open_key_value_store() - async def test_get_set_value(self) -> None: + async def test_get_set_value(self: TestKeyValueStoreOnActor) -> None: test_key = 'test_key' test_value = 'test_value' test_content_type = 'text/plain' @@ -43,7 +50,7 @@ async def test_get_set_value(self) -> None: value = await my_actor.get_value(key=test_key) assert value == test_value - async def test_get_input(self, memory_storage_client: MemoryStorageClient) -> None: + async def test_get_input(self: TestKeyValueStoreOnActor, memory_storage_client: MemoryStorageClient) -> None: input_key = 'INPUT' test_input = {'foo': 'bar'} @@ -55,10 +62,14 @@ async def test_get_input(self, memory_storage_client: MemoryStorageClient) -> No ) async with Actor() as my_actor: - input = await my_actor.get_input() + input = await my_actor.get_input() # noqa: A001 assert input['foo'] == test_input['foo'] - async def test_get_input_with_secrets(self, monkeypatch: pytest.MonkeyPatch, memory_storage_client: MemoryStorageClient) -> None: + async def test_get_input_with_secrets( + self: TestKeyValueStoreOnActor, + monkeypatch: pytest.MonkeyPatch, + memory_storage_client: MemoryStorageClient, + ) -> None: monkeypatch.setenv(ApifyEnvVars.INPUT_SECRETS_PRIVATE_KEY_FILE, PRIVATE_KEY_PEM_BASE64) monkeypatch.setenv(ApifyEnvVars.INPUT_SECRETS_PRIVATE_KEY_PASSPHRASE, PRIVATE_KEY_PASSWORD) @@ -78,6 +89,6 @@ async def test_get_input_with_secrets(self, monkeypatch: pytest.MonkeyPatch, mem ) async with Actor() as my_actor: - input = await my_actor.get_input() + input = await my_actor.get_input() # noqa: A001 assert input['foo'] == input_with_secret['foo'] assert input['secret'] == secret_string diff --git a/tests/unit/actor/test_actor_lifecycle.py b/tests/unit/actor/test_actor_lifecycle.py index 5d4d32de..0dfbf018 100644 --- a/tests/unit/actor/test_actor_lifecycle.py +++ b/tests/unit/actor/test_actor_lifecycle.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import contextlib from datetime import datetime @@ -11,12 +13,12 @@ class TestActorInit: - async def test_async_with_actor_properly_initialize(self) -> None: + async def test_async_with_actor_properly_initialize(self: TestActorInit) -> None: async with Actor: assert Actor._get_default_instance()._is_initialized assert Actor._get_default_instance()._is_initialized is False - async def test_actor_init(self) -> None: + async def test_actor_init(self: TestActorInit) -> None: my_actor = Actor() await my_actor.init() @@ -25,7 +27,7 @@ async def test_actor_init(self) -> None: await my_actor.exit() assert my_actor._is_initialized is False - async def test_double_init(self) -> None: + async def test_double_init(self: TestActorInit) -> None: my_actor = Actor() await my_actor.init() @@ -40,7 +42,7 @@ async def test_double_init(self) -> None: class TestActorExit: - async def test_with_actor_exit(self, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_with_actor_exit(self: TestActorExit, monkeypatch: pytest.MonkeyPatch) -> None: 
monkeypatch.setenv(ApifyEnvVars.SYSTEM_INFO_INTERVAL_MILLIS, '100') monkeypatch.setenv(ApifyEnvVars.PERSIST_STATE_INTERVAL_MILLIS, '100') on_persist = [] @@ -51,7 +53,7 @@ def on_event(event_type: ActorEventTypes) -> Callable: nonlocal on_system_info if event_type == ActorEventTypes.PERSIST_STATE: return lambda data: on_persist.append(data) - elif event_type == ActorEventTypes.SYSTEM_INFO: + if event_type == ActorEventTypes.SYSTEM_INFO: return lambda data: on_system_info.append(data) return lambda data: print(data) @@ -74,36 +76,36 @@ def on_event(event_type: ActorEventTypes) -> Callable: # Check `createdAt` is a datetime (so it's the same locally and on platform) assert isinstance(on_system_info[0]['createdAt'], datetime) - async def test_raise_on_exit_witout_init(self) -> None: + async def test_raise_on_exit_witout_init(self: TestActorExit) -> None: with pytest.raises(RuntimeError): await Actor.exit() class TestActorFail: - async def test_with_actor_fail(self) -> None: + async def test_with_actor_fail(self: TestActorFail) -> None: async with Actor() as my_actor: assert my_actor._is_initialized await my_actor.fail() assert my_actor._is_initialized is False - async def test_with_actor_failed(self) -> None: + async def test_with_actor_failed(self: TestActorFail) -> None: with contextlib.suppress(Exception): async with Actor() as my_actor: assert my_actor._is_initialized - raise Exception('Failed') + raise Exception('Failed') # noqa: TRY002 assert my_actor._is_initialized is False - async def test_raise_on_fail_without_init(self) -> None: + async def test_raise_on_fail_without_init(self: TestActorFail) -> None: with pytest.raises(RuntimeError): await Actor.fail() - async def test_actor_reboot_not_work_locally(self) -> None: + async def test_actor_reboot_not_work_locally(self: TestActorFail) -> None: with pytest.raises(RuntimeError): await Actor.reboot() class TestActorMainMethod: - async def test_actor_main_method(self) -> None: + async def test_actor_main_method(self: TestActorMainMethod) -> None: my_actor = Actor() main_was_called = False @@ -111,11 +113,12 @@ async def actor_function() -> None: nonlocal main_was_called main_was_called = True assert my_actor._is_initialized + await my_actor.main(actor_function) assert my_actor._is_initialized is False assert main_was_called - async def test_actor_main_method_throw_exception(self) -> None: + async def test_actor_main_method_throw_exception(self: TestActorMainMethod) -> None: my_actor = Actor() err = Exception('Failed') my_actor.fail = AsyncMock() # type: ignore @@ -131,7 +134,7 @@ async def actor_function() -> None: # This is necessary to stop the event emitting intervals await my_actor.exit() - async def test_actor_main_method_raise_return_value(self) -> None: + async def test_actor_main_method_raise_return_value(self: TestActorMainMethod) -> None: my_actor = Actor() expected_string = 'Hello world' @@ -144,7 +147,7 @@ async def actor_function() -> str: class TestMigratingEvent: - async def test_migrating_event(self, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_migrating_event(self: TestMigratingEvent, monkeypatch: pytest.MonkeyPatch) -> None: # This should test whether when you get a MIGRATING event, # the actor automatically emits the PERSIST_STATE event with data `{'isMigrating': True}` monkeypatch.setenv(ApifyEnvVars.PERSIST_STATE_INTERVAL_MILLIS, '500') diff --git a/tests/unit/actor/test_actor_log.py b/tests/unit/actor/test_actor_log.py index 8aef3ff1..38a7281c 100644 --- a/tests/unit/actor/test_actor_log.py +++ 
b/tests/unit/actor/test_actor_log.py @@ -1,16 +1,20 @@ +from __future__ import annotations + import contextlib import logging import sys - -import pytest +from typing import TYPE_CHECKING from apify import Actor, __version__ from apify.log import logger from apify_client import __version__ as apify_client_version +if TYPE_CHECKING: + import pytest + class TestActorLog: - async def test_actor_log(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_actor_log(self: TestActorLog, caplog: pytest.LogCaptureFixture) -> None: caplog.set_level(logging.DEBUG, logger='apify') with contextlib.suppress(RuntimeError): async with Actor: diff --git a/tests/unit/actor/test_actor_memory_storage_e2e.py b/tests/unit/actor/test_actor_memory_storage_e2e.py index c1f7e1e0..0721685d 100644 --- a/tests/unit/actor/test_actor_memory_storage_e2e.py +++ b/tests/unit/actor/test_actor_memory_storage_e2e.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone from typing import Callable @@ -11,7 +13,7 @@ @pytest.mark.parametrize('purge_on_start', [True, False]) async def test_actor_memory_storage_client_key_value_store_e2e( monkeypatch: pytest.MonkeyPatch, - purge_on_start: bool, + purge_on_start: bool, # noqa: FBT001 reset_default_instances: Callable[[], None], ) -> None: """This test simulates two clean runs using memory storage. @@ -51,7 +53,7 @@ async def test_actor_memory_storage_client_key_value_store_e2e( @pytest.mark.parametrize('purge_on_start', [True, False]) async def test_actor_memory_storage_client_request_queue_e2e( monkeypatch: pytest.MonkeyPatch, - purge_on_start: bool, + purge_on_start: bool, # noqa: FBT001 reset_default_instances: Callable[[], None], ) -> None: """This test simulates two clean runs using memory storage. @@ -66,11 +68,14 @@ async def test_actor_memory_storage_client_request_queue_e2e( request_url = f'http://example.com/{i}' forefront = i % 3 == 1 was_handled = i % 3 == 2 - await default_queue.add_request({ - 'uniqueKey': str(i), - 'url': request_url, - 'handledAt': datetime.now(timezone.utc) if was_handled else None, - }, forefront=forefront) + await default_queue.add_request( + { + 'uniqueKey': str(i), + 'url': request_url, + 'handledAt': datetime.now(timezone.utc) if was_handled else None, + }, + forefront=forefront, + ) # We simulate another clean run, we expect the memory storage to read from the local data directory # Default storages are purged based on purge_on_start parameter. 
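The two e2e tests in this file run the same scenario twice, varying only whether the second run purges the default storages. The purge toggle is an environment variable read by the memory storage client; a sketch of the setup, with the caveat that ApifyEnvVars.PURGE_ON_START and the 'true'/'false' encoding are assumptions based on the rest of the suite, not shown in this hunk:

import pytest

from apify_shared.consts import ApifyEnvVars


@pytest.mark.parametrize('purge_on_start', [True, False])
def test_purge_toggle(monkeypatch: pytest.MonkeyPatch, purge_on_start: bool) -> None:  # noqa: FBT001
    # 'true': the second clean run starts with empty default storages.
    # 'false': the second run reloads data persisted by the first one.
    monkeypatch.setenv(ApifyEnvVars.PURGE_ON_START, 'true' if purge_on_start else 'false')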
@@ -83,11 +88,14 @@ async def test_actor_memory_storage_client_request_queue_e2e( request_url = f'http://example.com/{i}' forefront = i % 3 == 1 was_handled = i % 3 == 2 - await default_queue.add_request({ - 'uniqueKey': str(i), - 'url': request_url, - 'handledAt': datetime.now(timezone.utc) if was_handled else None, - }, forefront=forefront) + await default_queue.add_request( + { + 'uniqueKey': str(i), + 'url': request_url, + 'handledAt': datetime.now(timezone.utc) if was_handled else None, + }, + forefront=forefront, + ) queue_info = await default_queue.get_info() assert queue_info is not None diff --git a/tests/unit/actor/test_actor_request_queue.py b/tests/unit/actor/test_actor_request_queue.py index b7f672d8..de58d26c 100644 --- a/tests/unit/actor/test_actor_request_queue.py +++ b/tests/unit/actor/test_actor_request_queue.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from apify import Actor @@ -6,11 +8,11 @@ class TestActorOpenRequestQueue: - async def test_throws_without_init(self) -> None: + async def test_throws_without_init(self: TestActorOpenRequestQueue) -> None: with pytest.raises(RuntimeError): await Actor.open_request_queue() - async def test_same_references(self) -> None: + async def test_same_references(self: TestActorOpenRequestQueue) -> None: async with Actor: rq1 = await Actor.open_request_queue() rq2 = await Actor.open_request_queue() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index cff65294..eac8ae8f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import asyncio import inspect from collections import defaultdict -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, get_type_hints +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Callable, get_type_hints import pytest @@ -13,8 +15,11 @@ from apify_client.client import ApifyClientAsync from apify_shared.consts import ApifyEnvVars +if TYPE_CHECKING: + from pathlib import Path + -@pytest.fixture +@pytest.fixture() def reset_default_instances(monkeypatch: pytest.MonkeyPatch) -> Callable[[], None]: def reset() -> None: monkeypatch.setattr(Actor, '_default_instance', None) @@ -42,18 +47,18 @@ def _reset_and_patch_default_instances(monkeypatch: pytest.MonkeyPatch, tmp_path # This class is used to patch the ApifyClientAsync methods to return a fixed value or be replaced with another method. class ApifyClientAsyncPatcher: - def __init__(self, monkeypatch: pytest.MonkeyPatch) -> None: + def __init__(self: ApifyClientAsyncPatcher, monkeypatch: pytest.MonkeyPatch) -> None: self.monkeypatch = monkeypatch - self.calls: Dict[str, Dict[str, List[Tuple[Any, Any]]]] = defaultdict(lambda: defaultdict(list)) + self.calls: dict[str, dict[str, list[tuple[Any, Any]]]] = defaultdict(lambda: defaultdict(list)) def patch( - self, + self: ApifyClientAsyncPatcher, method: str, submethod: str, *, - return_value: Optional[Any] = None, - replacement_method: Optional[Callable] = None, - is_async: Optional[bool] = None, + return_value: Any = None, + replacement_method: Callable | None = None, + is_async: bool | None = None, ) -> None: """ Patch a method in ApifyClientAsync. 
@@ -78,7 +83,26 @@ def patch( if not client_method: raise ValueError(f'ApifyClientAsync does not contain method "{method}"!') - client_method_return_type = get_type_hints(client_method)['return'] + try: + # Try to get the return type of the client method using `typing.get_type_hints()` + client_method_return_type = get_type_hints(client_method)['return'] + except TypeError: + # There is a known issue with `typing.get_type_hints()` on Python 3.8 and 3.9. It raises a `TypeError` + # when `|` (Union) is used in the type hint, even with `from __future__ import annotations`. Since we + # only need the return type, we attempt the following workaround. + + # 1. Create a deep copy of the client method object + client_method_copied = deepcopy(client_method) + + # 2. Restrict the annotations to only include the return type + client_method_copied.__annotations__ = {'return': client_method.__annotations__['return']} + + # 3. Try to get the return type again using `typing.get_type_hints()` + client_method_return_type = get_type_hints(client_method_copied)['return'] + + # TODO: Remove this fallback once we drop support for Python 3.8 and 3.9 + # https://github.com/apify/apify-sdk-python/issues/151 + original_submethod = getattr(client_method_return_type, submethod, None) if not original_submethod: @@ -100,7 +124,8 @@ async def replacement_method(*args: Any, **kwargs: Any) -> Any: return_value.set_result(original_return_value) if not replacement_method: - def replacement_method(*_args: Any, **_kwargs: Any) -> Optional[Any]: + + def replacement_method(*_args: Any, **_kwargs: Any) -> Any: return return_value def wrapper(*args: Any, **kwargs: Any) -> Any: @@ -125,11 +150,11 @@ def getattr_override(apify_client_instance: Any, attr_name: str) -> Any: self.monkeypatch.setattr(ApifyClientAsync, '__getattr__', getattr_override, raising=False) -@pytest.fixture +@pytest.fixture() def apify_client_async_patcher(monkeypatch: pytest.MonkeyPatch) -> ApifyClientAsyncPatcher: return ApifyClientAsyncPatcher(monkeypatch) -@pytest.fixture +@pytest.fixture() def memory_storage_client() -> MemoryStorageClient: return MemoryStorageClient(write_metadata=True, persist_storage=True) diff --git a/tests/unit/memory_storage/resource_clients/test_dataset.py b/tests/unit/memory_storage/resource_clients/test_dataset.py index 478dfdd5..6c5aaecf 100644 --- a/tests/unit/memory_storage/resource_clients/test_dataset.py +++ b/tests/unit/memory_storage/resource_clients/test_dataset.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import asyncio import os +from typing import TYPE_CHECKING import pytest -from apify._memory_storage import MemoryStorageClient -from apify._memory_storage.resource_clients import DatasetClient +if TYPE_CHECKING: + from apify._memory_storage import MemoryStorageClient + from apify._memory_storage.resource_clients import DatasetClient -@pytest.fixture +@pytest.fixture() async def dataset_client(memory_storage_client: MemoryStorageClient) -> DatasetClient: datasets_client = memory_storage_client.datasets() dataset_info = await datasets_client.get_or_create(name='test') diff --git a/tests/unit/memory_storage/resource_clients/test_dataset_collection.py b/tests/unit/memory_storage/resource_clients/test_dataset_collection.py index a08066be..89b79228 100644 --- a/tests/unit/memory_storage/resource_clients/test_dataset_collection.py +++ b/tests/unit/memory_storage/resource_clients/test_dataset_collection.py @@ -1,11 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING import pytest 
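The get_type_hints() fallback added to conftest.py above works around a real limitation of Python 3.8 and 3.9: typing.get_type_hints() evaluates every annotation on the function, so a PEP 604 union such as int | None raises TypeError at runtime even though `from __future__ import annotations` keeps the annotation a string. A standalone repro with a simplified recovery that evaluates only the return annotation (the conftest above instead copies the method object and restricts its __annotations__):

from __future__ import annotations

from typing import Optional, get_type_hints


def fetch(*, timeout: int | None = None) -> Optional[dict]:
    """A parameter annotation uses a PEP 604 union; the return type does not."""


try:
    return_type = get_type_hints(fetch)['return']
except TypeError:
    # On 3.8/3.9, evaluating 'int | None' fails, so evaluate just the
    # return annotation, which sticks to typing.Optional.
    return_type = eval(fetch.__annotations__['return'])  # noqa: S307

print(return_type)  # typing.Optional[dict]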
-from apify._memory_storage import MemoryStorageClient -from apify._memory_storage.resource_clients import DatasetCollectionClient +if TYPE_CHECKING: + from apify._memory_storage import MemoryStorageClient + from apify._memory_storage.resource_clients import DatasetCollectionClient -@pytest.fixture +@pytest.fixture() def datasets_client(memory_storage_client: MemoryStorageClient) -> DatasetCollectionClient: return memory_storage_client.datasets() diff --git a/tests/unit/memory_storage/resource_clients/test_key_value_store.py b/tests/unit/memory_storage/resource_clients/test_key_value_store.py index 9dcaf1f2..209599e8 100644 --- a/tests/unit/memory_storage/resource_clients/test_key_value_store.py +++ b/tests/unit/memory_storage/resource_clients/test_key_value_store.py @@ -1,26 +1,31 @@ +from __future__ import annotations + import asyncio import base64 import json import os from datetime import datetime, timezone -from pathlib import Path -from typing import Dict +from typing import TYPE_CHECKING import pytest from apify._crypto import crypto_random_object_id -from apify._memory_storage import MemoryStorageClient -from apify._memory_storage.resource_clients import KeyValueStoreClient -from apify._utils import _maybe_parse_body +from apify._utils import maybe_parse_body from apify_shared.utils import json_dumps +if TYPE_CHECKING: + from pathlib import Path + + from apify._memory_storage import MemoryStorageClient + from apify._memory_storage.resource_clients import KeyValueStoreClient + TINY_PNG = base64.b64decode('iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVQYV2NgYAAAAAMAAWgmWQ0AAAAASUVORK5CYII=') TINY_BYTES = b'\x12\x34\x56\x78\x90\xAB\xCD\xEF' TINY_DATA = {'a': 'b'} TINY_TEXT = 'abcd' -@pytest.fixture +@pytest.fixture() async def key_value_store_client(memory_storage_client: MemoryStorageClient) -> KeyValueStoreClient: key_value_stores_client = memory_storage_client.key_value_stores() kvs_info = await key_value_stores_client.get_or_create(name='test') @@ -164,7 +169,7 @@ async def test_get_and_set_record(tmp_path: Path, key_value_store_client: KeyVal # Test using bytes bytes_key = 'test-json' - bytes_value = 'testing bytes set_record'.encode('utf-8') + bytes_value = b'testing bytes set_record' await key_value_store_client.set_record(bytes_key, bytes_value, 'unknown') bytes_record_info = await key_value_store_client.get_record(bytes_key) assert bytes_record_info is not None @@ -173,7 +178,7 @@ async def test_get_and_set_record(tmp_path: Path, key_value_store_client: KeyVal assert bytes_record_info['value'].decode('utf-8') == bytes_value.decode('utf-8') # Test using file descriptor - with open(os.path.join(tmp_path, 'test.json'), 'w+', encoding='utf-8') as f: + with open(os.path.join(tmp_path, 'test.json'), 'w+', encoding='utf-8') as f: # noqa: ASYNC101 f.write('Test') with pytest.raises(NotImplementedError, match='File-like values are not supported in local memory storage'): await key_value_store_client.set_record('file', f) @@ -196,35 +201,60 @@ async def test_delete_record(key_value_store_client: KeyValueStoreClient) -> Non await key_value_store_client.delete_record(record_key) -@pytest.mark.parametrize('test_case', [ - {'input': {'key': 'image', 'value': TINY_PNG, 'contentType': None}, - 'expectedOutput': {'filename': 'image', 'key': 'image', 'contentType': 'application/octet-stream'}}, - {'input': {'key': 'image', 'value': TINY_PNG, 'contentType': 'image/png'}, - 'expectedOutput': {'filename': 'image.png', 'key': 'image', 'contentType': 'image/png'}}, - {'input': {'key': 
'image.png', 'value': TINY_PNG, 'contentType': None}, - 'expectedOutput': {'filename': 'image.png', 'key': 'image.png', 'contentType': 'application/octet-stream'}}, - {'input': {'key': 'image.png', 'value': TINY_PNG, 'contentType': 'image/png'}, - 'expectedOutput': {'filename': 'image.png', 'key': 'image.png', 'contentType': 'image/png'}}, - - {'input': {'key': 'data', 'value': TINY_DATA, 'contentType': None}, - 'expectedOutput': {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}}, - {'input': {'key': 'data', 'value': TINY_DATA, 'contentType': 'application/json'}, - 'expectedOutput': {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}}, - {'input': {'key': 'data.json', 'value': TINY_DATA, 'contentType': None}, - 'expectedOutput': {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}}, - {'input': {'key': 'data.json', 'value': TINY_DATA, 'contentType': 'application/json'}, - 'expectedOutput': {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}}, - - {'input': {'key': 'text', 'value': TINY_TEXT, 'contentType': None}, - 'expectedOutput': {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}}, - {'input': {'key': 'text', 'value': TINY_TEXT, 'contentType': 'text/plain'}, - 'expectedOutput': {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}}, - {'input': {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': None}, - 'expectedOutput': {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}}, - {'input': {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': 'text/plain'}, - 'expectedOutput': {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}}, -]) -async def test_writes_correct_metadata(memory_storage_client: MemoryStorageClient, test_case: Dict) -> None: +@pytest.mark.parametrize( + 'test_case', + [ + { + 'input': {'key': 'image', 'value': TINY_PNG, 'contentType': None}, + 'expectedOutput': {'filename': 'image', 'key': 'image', 'contentType': 'application/octet-stream'}, + }, + { + 'input': {'key': 'image', 'value': TINY_PNG, 'contentType': 'image/png'}, + 'expectedOutput': {'filename': 'image.png', 'key': 'image', 'contentType': 'image/png'}, + }, + { + 'input': {'key': 'image.png', 'value': TINY_PNG, 'contentType': None}, + 'expectedOutput': {'filename': 'image.png', 'key': 'image.png', 'contentType': 'application/octet-stream'}, + }, + { + 'input': {'key': 'image.png', 'value': TINY_PNG, 'contentType': 'image/png'}, + 'expectedOutput': {'filename': 'image.png', 'key': 'image.png', 'contentType': 'image/png'}, + }, + { + 'input': {'key': 'data', 'value': TINY_DATA, 'contentType': None}, + 'expectedOutput': {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}, + }, + { + 'input': {'key': 'data', 'value': TINY_DATA, 'contentType': 'application/json'}, + 'expectedOutput': {'filename': 'data.json', 'key': 'data', 'contentType': 'application/json'}, + }, + { + 'input': {'key': 'data.json', 'value': TINY_DATA, 'contentType': None}, + 'expectedOutput': {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}, + }, + { + 'input': {'key': 'data.json', 'value': TINY_DATA, 'contentType': 'application/json'}, + 'expectedOutput': {'filename': 'data.json', 'key': 'data.json', 'contentType': 'application/json'}, + }, + { + 'input': {'key': 'text', 'value': TINY_TEXT, 'contentType': None}, + 'expectedOutput': {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}, + }, + { 
+ 'input': {'key': 'text', 'value': TINY_TEXT, 'contentType': 'text/plain'}, + 'expectedOutput': {'filename': 'text.txt', 'key': 'text', 'contentType': 'text/plain'}, + }, + { + 'input': {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': None}, + 'expectedOutput': {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}, + }, + { + 'input': {'key': 'text.txt', 'value': TINY_TEXT, 'contentType': 'text/plain'}, + 'expectedOutput': {'filename': 'text.txt', 'key': 'text.txt', 'contentType': 'text/plain'}, + }, + ], +) +async def test_writes_correct_metadata(memory_storage_client: MemoryStorageClient, test_case: dict) -> None: test_input = test_case['input'] expected_output = test_case['expectedOutput'] key_value_store_name = crypto_random_object_id() @@ -242,49 +272,79 @@ async def test_writes_correct_metadata(memory_storage_client: MemoryStorageClien assert os.path.exists(item_path) assert os.path.exists(metadata_path) - with open(item_path, 'rb') as item_file: - actual_value = _maybe_parse_body(item_file.read(), expected_output['contentType']) + with open(item_path, 'rb') as item_file: # noqa: ASYNC101 + actual_value = maybe_parse_body(item_file.read(), expected_output['contentType']) assert actual_value == test_input['value'] - with open(metadata_path, 'r', encoding='utf-8') as metadata_file: + with open(metadata_path, encoding='utf-8') as metadata_file: # noqa: ASYNC101 metadata = json.load(metadata_file) assert metadata['key'] == expected_output['key'] assert expected_output['contentType'] in metadata['contentType'] -@pytest.mark.parametrize('test_case', [ - {'input': {'filename': 'image', 'value': TINY_PNG, 'metadata': None}, - 'expectedOutput': {'key': 'image', 'filename': 'image', 'contentType': 'application/octet-stream'}}, - {'input': {'filename': 'image.png', 'value': TINY_PNG, 'metadata': None}, - 'expectedOutput': {'key': 'image', 'filename': 'image.png', 'contentType': 'image/png'}}, - {'input': {'filename': 'image', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'application/octet-stream'}}, - 'expectedOutput': {'key': 'image', 'contentType': 'application/octet-stream'}}, - {'input': {'filename': 'image', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, - 'expectedOutput': {'key': 'image', 'filename': 'image', 'contentType': 'image/png'}}, - {'input': {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image.png', 'contentType': 'application/octet-stream'}}, - 'expectedOutput': {'key': 'image.png', 'contentType': 'application/octet-stream'}}, - {'input': {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image.png', 'contentType': 'image/png'}}, - 'expectedOutput': {'key': 'image.png', 'contentType': 'image/png'}}, - {'input': {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, - 'expectedOutput': {'key': 'image', 'contentType': 'image/png'}}, - {'input': {'filename': 'input', 'value': TINY_BYTES, 'metadata': None}, - 'expectedOutput': {'key': 'input', 'contentType': 'application/octet-stream'}}, - {'input': {'filename': 'input.json', 'value': TINY_DATA, 'metadata': None}, - 'expectedOutput': {'key': 'input', 'contentType': 'application/json'}}, - {'input': {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': None}, - 'expectedOutput': {'key': 'input', 'contentType': 'text/plain'}}, - {'input': {'filename': 'input.bin', 'value': TINY_BYTES, 'metadata': None}, - 'expectedOutput': {'key': 'input', 'contentType': 
'application/octet-stream'}}, - {'input': {'filename': 'input', 'value': TINY_BYTES, 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}}, - 'expectedOutput': {'key': 'input', 'contentType': 'application/octet-stream'}}, - {'input': {'filename': 'input.json', 'value': TINY_DATA, 'metadata': {'key': 'input', 'contentType': 'application/json'}}, - 'expectedOutput': {'key': 'input', 'contentType': 'application/json'}}, - {'input': {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': {'key': 'input', 'contentType': 'text/plain'}}, - 'expectedOutput': {'key': 'input', 'contentType': 'text/plain'}}, - {'input': {'filename': 'input.bin', 'value': TINY_BYTES, 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}}, - 'expectedOutput': {'key': 'input', 'contentType': 'application/octet-stream'}}, -]) -async def test_reads_correct_metadata(memory_storage_client: MemoryStorageClient, test_case: Dict) -> None: +@pytest.mark.parametrize( + 'test_case', + [ + { + 'input': {'filename': 'image', 'value': TINY_PNG, 'metadata': None}, + 'expectedOutput': {'key': 'image', 'filename': 'image', 'contentType': 'application/octet-stream'}, + }, + { + 'input': {'filename': 'image.png', 'value': TINY_PNG, 'metadata': None}, + 'expectedOutput': {'key': 'image', 'filename': 'image.png', 'contentType': 'image/png'}, + }, + { + 'input': {'filename': 'image', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'application/octet-stream'}}, + 'expectedOutput': {'key': 'image', 'contentType': 'application/octet-stream'}, + }, + { + 'input': {'filename': 'image', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, + 'expectedOutput': {'key': 'image', 'filename': 'image', 'contentType': 'image/png'}, + }, + { + 'input': {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image.png', 'contentType': 'application/octet-stream'}}, + 'expectedOutput': {'key': 'image.png', 'contentType': 'application/octet-stream'}, + }, + { + 'input': {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image.png', 'contentType': 'image/png'}}, + 'expectedOutput': {'key': 'image.png', 'contentType': 'image/png'}, + }, + { + 'input': {'filename': 'image.png', 'value': TINY_PNG, 'metadata': {'key': 'image', 'contentType': 'image/png'}}, + 'expectedOutput': {'key': 'image', 'contentType': 'image/png'}, + }, + { + 'input': {'filename': 'input', 'value': TINY_BYTES, 'metadata': None}, + 'expectedOutput': {'key': 'input', 'contentType': 'application/octet-stream'}, + }, + { + 'input': {'filename': 'input.json', 'value': TINY_DATA, 'metadata': None}, + 'expectedOutput': {'key': 'input', 'contentType': 'application/json'}, + }, + {'input': {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': None}, 'expectedOutput': {'key': 'input', 'contentType': 'text/plain'}}, + { + 'input': {'filename': 'input.bin', 'value': TINY_BYTES, 'metadata': None}, + 'expectedOutput': {'key': 'input', 'contentType': 'application/octet-stream'}, + }, + { + 'input': {'filename': 'input', 'value': TINY_BYTES, 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}}, + 'expectedOutput': {'key': 'input', 'contentType': 'application/octet-stream'}, + }, + { + 'input': {'filename': 'input.json', 'value': TINY_DATA, 'metadata': {'key': 'input', 'contentType': 'application/json'}}, + 'expectedOutput': {'key': 'input', 'contentType': 'application/json'}, + }, + { + 'input': {'filename': 'input.txt', 'value': TINY_TEXT, 'metadata': {'key': 'input', 
'contentType': 'text/plain'}}, + 'expectedOutput': {'key': 'input', 'contentType': 'text/plain'}, + }, + { + 'input': {'filename': 'input.bin', 'value': TINY_BYTES, 'metadata': {'key': 'input', 'contentType': 'application/octet-stream'}}, + 'expectedOutput': {'key': 'input', 'contentType': 'application/octet-stream'}, + }, + ], +) +async def test_reads_correct_metadata(memory_storage_client: MemoryStorageClient, test_case: dict) -> None: test_input = test_case['input'] expected_output = test_case['expectedOutput'] key_value_store_name = crypto_random_object_id() @@ -304,12 +364,12 @@ async def test_reads_correct_metadata(memory_storage_client: MemoryStorageClient # Write the store metadata to disk store_metadata_path = os.path.join(storage_path, '__metadata__.json') - with open(store_metadata_path, mode='wb') as store_metadata_file: + with open(store_metadata_path, mode='wb') as store_metadata_file: # noqa: ASYNC101 store_metadata_file.write(json_dumps(store_metadata).encode('utf-8')) # Write the test input item to the disk item_path = os.path.join(storage_path, test_input['filename']) - with open(item_path, 'wb') as item_file: + with open(item_path, 'wb') as item_file: # noqa: ASYNC101 if isinstance(test_input['value'], bytes): item_file.write(test_input['value']) elif isinstance(test_input['value'], str): @@ -320,11 +380,15 @@ async def test_reads_correct_metadata(memory_storage_client: MemoryStorageClient # Optionally write the metadata to disk if there is some if test_input['metadata'] is not None: metadata_path = os.path.join(storage_path, test_input['filename'] + '.__metadata__.json') - with open(metadata_path, 'w', encoding='utf-8') as metadata_file: - metadata_file.write(json_dumps({ - 'key': test_input['metadata']['key'], - 'contentType': test_input['metadata']['contentType'], - })) + with open(metadata_path, 'w', encoding='utf-8') as metadata_file: # noqa: ASYNC101 + metadata_file.write( + json_dumps( + { + 'key': test_input['metadata']['key'], + 'contentType': test_input['metadata']['contentType'], + } + ) + ) # Create the key-value store client to load the items from disk store_details = await memory_storage_client.key_value_stores().get_or_create(name=key_value_store_name) diff --git a/tests/unit/memory_storage/resource_clients/test_key_value_store_collection.py b/tests/unit/memory_storage/resource_clients/test_key_value_store_collection.py index 2a1fe763..f645df01 100644 --- a/tests/unit/memory_storage/resource_clients/test_key_value_store_collection.py +++ b/tests/unit/memory_storage/resource_clients/test_key_value_store_collection.py @@ -1,10 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import pytest -from apify._memory_storage import MemoryStorageClient -from apify._memory_storage.resource_clients import KeyValueStoreCollectionClient +if TYPE_CHECKING: + from apify._memory_storage import MemoryStorageClient + from apify._memory_storage.resource_clients import KeyValueStoreCollectionClient -@pytest.fixture +@pytest.fixture() def key_value_stores_client(memory_storage_client: MemoryStorageClient) -> KeyValueStoreCollectionClient: return memory_storage_client.key_value_stores() diff --git a/tests/unit/memory_storage/resource_clients/test_request_queue.py b/tests/unit/memory_storage/resource_clients/test_request_queue.py index 5547b175..c66bc68f 100644 --- a/tests/unit/memory_storage/resource_clients/test_request_queue.py +++ b/tests/unit/memory_storage/resource_clients/test_request_queue.py @@ -1,14 +1,18 @@ +from __future__ import 
annotations + import asyncio import os from datetime import datetime, timezone +from typing import TYPE_CHECKING import pytest -from apify._memory_storage import MemoryStorageClient -from apify._memory_storage.resource_clients import RequestQueueClient +if TYPE_CHECKING: + from apify._memory_storage import MemoryStorageClient + from apify._memory_storage.resource_clients import RequestQueueClient -@pytest.fixture +@pytest.fixture() async def request_queue_client(memory_storage_client: MemoryStorageClient) -> RequestQueueClient: request_queues_client = memory_storage_client.request_queues() rq_info = await request_queues_client.get_or_create(name='test') @@ -33,10 +37,12 @@ async def test_get(request_queue_client: RequestQueueClient) -> None: async def test_update(request_queue_client: RequestQueueClient) -> None: new_rq_name = 'test-update' - await request_queue_client.add_request({ - 'uniqueKey': 'https://apify.com', - 'url': 'https://apify.com', - }) + await request_queue_client.add_request( + { + 'uniqueKey': 'https://apify.com', + 'url': 'https://apify.com', + } + ) old_rq_info = await request_queue_client.get() assert old_rq_info is not None old_rq_directory = os.path.join(request_queue_client._memory_storage_client._request_queues_directory, old_rq_info['name']) @@ -59,10 +65,12 @@ async def test_update(request_queue_client: RequestQueueClient) -> None: async def test_delete(request_queue_client: RequestQueueClient) -> None: - await request_queue_client.add_request({ - 'uniqueKey': 'https://apify.com', - 'url': 'https://apify.com', - }) + await request_queue_client.add_request( + { + 'uniqueKey': 'https://apify.com', + 'url': 'https://apify.com', + } + ) rq_info = await request_queue_client.get() assert rq_info is not None @@ -79,14 +87,18 @@ async def test_delete(request_queue_client: RequestQueueClient) -> None: async def test_list_head(request_queue_client: RequestQueueClient) -> None: request_1_url = 'https://apify.com' request_2_url = 'https://example.com' - await request_queue_client.add_request({ - 'uniqueKey': request_1_url, - 'url': request_1_url, - }) - await request_queue_client.add_request({ - 'uniqueKey': request_2_url, - 'url': request_2_url, - }) + await request_queue_client.add_request( + { + 'uniqueKey': request_1_url, + 'url': request_1_url, + } + ) + await request_queue_client.add_request( + { + 'uniqueKey': request_2_url, + 'url': request_2_url, + } + ) list_head = await request_queue_client.list_head() assert len(list_head['items']) == 2 for item in list_head['items']: @@ -96,14 +108,20 @@ async def test_list_head(request_queue_client: RequestQueueClient) -> None: async def test_add_record(request_queue_client: RequestQueueClient) -> None: request_forefront_url = 'https://apify.com' request_not_forefront_url = 'https://example.com' - request_forefront_info = await request_queue_client.add_request({ - 'uniqueKey': request_forefront_url, - 'url': request_forefront_url, - }, forefront=True) - request_not_forefront_info = await request_queue_client.add_request({ - 'uniqueKey': request_not_forefront_url, - 'url': request_not_forefront_url, - }, forefront=False) + request_forefront_info = await request_queue_client.add_request( + { + 'uniqueKey': request_forefront_url, + 'url': request_forefront_url, + }, + forefront=True, + ) + request_not_forefront_info = await request_queue_client.add_request( + { + 'uniqueKey': request_not_forefront_url, + 'url': request_not_forefront_url, + }, + forefront=False, + ) assert request_forefront_info.get('requestId') is not None assert 
request_not_forefront_info.get('requestId') is not None @@ -118,10 +136,12 @@ async def test_add_record(request_queue_client: RequestQueueClient) -> None: async def test_get_record(request_queue_client: RequestQueueClient) -> None: request_url = 'https://apify.com' - request_info = await request_queue_client.add_request({ - 'uniqueKey': request_url, - 'url': request_url, - }) + request_info = await request_queue_client.add_request( + { + 'uniqueKey': request_url, + 'url': request_url, + } + ) request = await request_queue_client.get_request(request_info['requestId']) assert request is not None assert 'id' in request @@ -133,10 +153,12 @@ async def test_get_record(request_queue_client: RequestQueueClient) -> None: async def test_update_record(request_queue_client: RequestQueueClient) -> None: request_url = 'https://apify.com' - request_info = await request_queue_client.add_request({ - 'uniqueKey': request_url, - 'url': request_url, - }) + request_info = await request_queue_client.add_request( + { + 'uniqueKey': request_url, + 'url': request_url, + } + ) request = await request_queue_client.get_request(request_info['requestId']) assert request is not None @@ -156,15 +178,19 @@ async def test_update_record(request_queue_client: RequestQueueClient) -> None: async def test_delete_record(request_queue_client: RequestQueueClient) -> None: request_url = 'https://apify.com' - pending_request_info = await request_queue_client.add_request({ - 'uniqueKey': 'pending', - 'url': request_url, - }) - handled_request_info = await request_queue_client.add_request({ - 'uniqueKey': 'handled', - 'url': request_url, - 'handledAt': datetime.now(tz=timezone.utc), - }) + pending_request_info = await request_queue_client.add_request( + { + 'uniqueKey': 'pending', + 'url': request_url, + } + ) + handled_request_info = await request_queue_client.add_request( + { + 'uniqueKey': 'handled', + 'url': request_url, + 'handledAt': datetime.now(tz=timezone.utc), + } + ) rq_info_before_delete = await request_queue_client.get() assert rq_info_before_delete is not None @@ -197,11 +223,14 @@ async def test_forefront(request_queue_client: RequestQueueClient) -> None: request_url = f'http://example.com/{i}' forefront = i % 3 == 1 was_handled = i % 3 == 2 - await request_queue_client.add_request({ - 'uniqueKey': str(i), - 'url': request_url, - 'handledAt': datetime.now(timezone.utc) if was_handled else None, - }, forefront=forefront) + await request_queue_client.add_request( + { + 'uniqueKey': str(i), + 'url': request_url, + 'handledAt': datetime.now(timezone.utc) if was_handled else None, + }, + forefront=forefront, + ) # Check that the queue head (unhandled items) is in the right order queue_head = await request_queue_client.list_head() @@ -209,16 +238,21 @@ async def test_forefront(request_queue_client: RequestQueueClient) -> None: assert req_unique_keys == ['7', '4', '1', '0', '3', '6'] # Mark request #1 as handled - await request_queue_client.update_request({ - 'uniqueKey': '1', - 'url': 'http://example.com/1', - 'handledAt': datetime.now(timezone.utc), - }) + await request_queue_client.update_request( + { + 'uniqueKey': '1', + 'url': 'http://example.com/1', + 'handledAt': datetime.now(timezone.utc), + } + ) # Move request #3 to forefront - await request_queue_client.update_request({ - 'uniqueKey': '3', - 'url': 'http://example.com/3', - }, forefront=True) + await request_queue_client.update_request( + { + 'uniqueKey': '3', + 'url': 'http://example.com/3', + }, + forefront=True, + ) # Check that the queue head (unhandled items) 
is in the right order after the updates queue_head = await request_queue_client.list_head() diff --git a/tests/unit/memory_storage/resource_clients/test_request_queue_collection.py b/tests/unit/memory_storage/resource_clients/test_request_queue_collection.py index dc362d2e..3c33a2ac 100644 --- a/tests/unit/memory_storage/resource_clients/test_request_queue_collection.py +++ b/tests/unit/memory_storage/resource_clients/test_request_queue_collection.py @@ -1,10 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import pytest -from apify._memory_storage import MemoryStorageClient -from apify._memory_storage.resource_clients import RequestQueueCollectionClient +if TYPE_CHECKING: + from apify._memory_storage import MemoryStorageClient + from apify._memory_storage.resource_clients import RequestQueueCollectionClient -@pytest.fixture +@pytest.fixture() def request_queues_client(memory_storage_client: MemoryStorageClient) -> RequestQueueCollectionClient: return memory_storage_client.request_queues() diff --git a/tests/unit/memory_storage/test_memory_storage.py b/tests/unit/memory_storage/test_memory_storage.py index eb22173a..7a5f1c1c 100644 --- a/tests/unit/memory_storage/test_memory_storage.py +++ b/tests/unit/memory_storage/test_memory_storage.py @@ -1,11 +1,16 @@ +from __future__ import annotations + import os -from pathlib import Path +from typing import TYPE_CHECKING import pytest from apify._memory_storage import MemoryStorageClient from apify_shared.consts import ApifyEnvVars +if TYPE_CHECKING: + from pathlib import Path + async def test_write_metadata(tmp_path: Path) -> None: dataset_name = 'test' diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index 6596b772..ca3b1ca3 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import pytest from apify.storages import Dataset, KeyValueStore -@pytest.fixture +@pytest.fixture() async def dataset() -> Dataset: return await Dataset.open() diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index 3bba9058..042fd873 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import pytest from apify.storages import KeyValueStore -@pytest.fixture +@pytest.fixture() async def key_value_store() -> KeyValueStore: return await KeyValueStore.open() diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index 2bd4e65b..2922e5b8 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from datetime import datetime, timezone @@ -6,7 +8,7 @@ from apify.storages import RequestQueue -@pytest.fixture +@pytest.fixture() async def request_queue() -> RequestQueue: return await RequestQueue.open() @@ -50,10 +52,12 @@ async def test_drop() -> None: async def test_get_request(request_queue: RequestQueue) -> None: url = 'https://example.com' - add_request_info = await request_queue.add_request({ - 'uniqueKey': url, - 'url': url, - }) + add_request_info = await request_queue.add_request( + { + 'uniqueKey': url, + 'url': url, + } + ) request = await request_queue.get_request(add_request_info['requestId']) assert request is not None assert request['url'] == url @@ -64,21 +68,23 @@ async def 
diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py
index 2bd4e65b..2922e5b8 100644
--- a/tests/unit/storages/test_request_queue.py
+++ b/tests/unit/storages/test_request_queue.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 from datetime import datetime, timezone
 
@@ -6,7 +8,7 @@
 from apify.storages import RequestQueue
 
 
-@pytest.fixture
+@pytest.fixture()
 async def request_queue() -> RequestQueue:
     return await RequestQueue.open()
 
@@ -50,10 +52,12 @@ async def test_drop() -> None:
 
 async def test_get_request(request_queue: RequestQueue) -> None:
     url = 'https://example.com'
-    add_request_info = await request_queue.add_request({
-        'uniqueKey': url,
-        'url': url,
-    })
+    add_request_info = await request_queue.add_request(
+        {
+            'uniqueKey': url,
+            'url': url,
+        }
+    )
     request = await request_queue.get_request(add_request_info['requestId'])
     assert request is not None
     assert request['url'] == url
@@ -64,21 +68,23 @@ async def test_add_fetch_handle_request(request_queue: RequestQueue) -> None:
     assert await request_queue.is_empty() is True
     with pytest.raises(ValueError, match='"url" is required'):
         await request_queue.add_request({})
-    add_request_info = await request_queue.add_request({
-        'uniqueKey': url,
-        'url': url,
-    })
+    add_request_info = await request_queue.add_request(
+        {
+            'uniqueKey': url,
+            'url': url,
+        }
+    )
     assert add_request_info['wasAlreadyPresent'] is False
     assert add_request_info['wasAlreadyHandled'] is False
     assert await request_queue.is_empty() is False
 
     # Fetch the request
-    next = await request_queue.fetch_next_request()
-    assert next is not None
+    next_request = await request_queue.fetch_next_request()
+    assert next_request is not None
 
     # Mark it as handled
-    next['handledAt'] = datetime.now(timezone.utc)
-    queue_operation_info = await request_queue.mark_request_as_handled(next)
+    next_request['handledAt'] = datetime.now(timezone.utc)
+    queue_operation_info = await request_queue.mark_request_as_handled(next_request)
     assert queue_operation_info is not None
     assert queue_operation_info['uniqueKey'] == url
     assert await request_queue.is_finished() is True
@@ -86,17 +92,19 @@ async def test_add_fetch_handle_request(request_queue: RequestQueue) -> None:
 
 
 async def test_reclaim_request(request_queue: RequestQueue) -> None:
     url = 'https://example.com'
-    await request_queue.add_request({
-        'uniqueKey': url,
-        'url': url,
-    })
+    await request_queue.add_request(
+        {
+            'uniqueKey': url,
+            'url': url,
+        }
+    )
 
     # Fetch the request
-    next = await request_queue.fetch_next_request()
-    assert next is not None
-    assert next['uniqueKey'] == url
+    next_request = await request_queue.fetch_next_request()
+    assert next_request is not None
+    assert next_request['uniqueKey'] == url
 
     # Reclaim
-    await request_queue.reclaim_request(next)
+    await request_queue.reclaim_request(next_request)
     # Try to fetch again after a few secs
     await asyncio.sleep(4)  # 3 seconds is the consistency delay in request queue
     next_again = await request_queue.fetch_next_request()
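In the request-queue hunks above, the local variable `next` is renamed to `next_request`: it shadowed a Python builtin, which Ruff flags through its flake8-builtins rules (A001 for variables). A small illustration of why the shadowing matters, using hypothetical helpers that are not part of this codebase:

    def first_even_shadowed(numbers: list[int]) -> int:
        # The list below shadows the builtin 'next' for the rest of the function.
        next = [n for n in numbers if n % 2 == 0]
        return next(iter(next))  # TypeError: 'list' object is not callable


    def first_even(numbers: list[int]) -> int:
        evens = [n for n in numbers if n % 2 == 0]
        return next(iter(evens))  # the builtin is still usable


    assert first_even([1, 3, 4, 8]) == 4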
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index 3af2c2be..c1f77d86 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -1,14 +1,18 @@
-from datetime import datetime, timezone
+from __future__ import annotations
 
-import pytest
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING
 
 from apify.config import Configuration
 from apify_shared.consts import ActorEnvVars, ApifyEnvVars
 
+if TYPE_CHECKING:
+    import pytest
+
 
 class TestConfiguration:
     # Test that some config properties have some reasonable defaults
-    def test_configuration_defaults(self) -> None:
+    def test_configuration_defaults(self: TestConfiguration) -> None:
         config = Configuration()
         assert config.token is None
         assert config.proxy_password is None
@@ -23,7 +27,7 @@ def test_configuration_defaults(self) -> None:
         assert config.started_at is None
 
     # Test that defining properties via env vars works
-    def test_configuration_from_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_configuration_from_env_vars(self: TestConfiguration, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setenv(ApifyEnvVars.TOKEN, 'DUMMY_TOKEN')
         monkeypatch.setenv(ApifyEnvVars.PROXY_PASSWORD, 'DUMMY_PROXY_PASSWORD')
         monkeypatch.setenv(ApifyEnvVars.API_BASE_URL, 'DUMMY_API_BASE_URL')
@@ -50,7 +54,7 @@ def test_configuration_from_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> N
         assert config.started_at == datetime(2023, 1, 1, 12, 34, 56, 789000, tzinfo=timezone.utc)
 
     # Test that constructor arguments take precedence over env vars
-    def test_configuration_from_constructor_arguments(self, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_configuration_from_constructor_arguments(self: TestConfiguration, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setenv(ApifyEnvVars.TOKEN, 'DUMMY_TOKEN')
         monkeypatch.setenv(ApifyEnvVars.PROXY_PASSWORD, 'DUMMY_PROXY_PASSWORD')
         monkeypatch.setenv(ApifyEnvVars.API_BASE_URL, 'DUMMY_API_BASE_URL')
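The `self: TestConfiguration` annotations added above look unusual, but they are applied consistently to every test class in this patch; they satisfy a rule that requires `self` to be annotated (Ruff shipped this as ANN101 at the time — an inference here, since the pyproject.toml hunk is not shown in this part of the patch). The convention in a self-contained sketch:

    from __future__ import annotations  # lets a method name its class before the class exists


    class Counter:
        def __init__(self: Counter, start: int = 0) -> None:
            self.value = start

        def bump(self: Counter) -> int:
            self.value += 1
            return self.value


    counter = Counter()
    assert counter.bump() == 1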
diff --git a/tests/unit/test_crypto.py b/tests/unit/test_crypto.py
index 5c382adb..f820a59f 100644
--- a/tests/unit/test_crypto.py
+++ b/tests/unit/test_crypto.py
@@ -1,29 +1,61 @@
+from __future__ import annotations
+
 import base64
 
 import pytest
 
-from apify._crypto import _load_private_key, _load_public_key, crypto_random_object_id, private_decrypt, public_encrypt
+from apify._crypto import _load_public_key, crypto_random_object_id, load_private_key, private_decrypt, public_encrypt
 
 # NOTE: Uses the same keys as in:
 # https://github.com/apify/apify-shared-js/blob/master/test/crypto.test.ts
 PRIVATE_KEY_PEM_BASE64 = 'LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpQcm9jLVR5cGU6IDQsRU5DUllQVEVECkRFSy1JbmZvOiBERVMtRURFMy1DQkMsNTM1QURERjIzNUQ4QkFGOQoKMXFWUzl0S0FhdkVhVUVFMktESnpjM3plMk1lZkc1dmVEd2o1UVJ0ZkRaMXdWNS9VZmIvcU5sVThTSjlNaGhKaQp6RFdrWExueUUzSW0vcEtITVZkS0czYWZkcFRtcis2TmtidXptd0dVMk0vSWpzRjRJZlpad0lGbGJoY09jUnp4CmZmWVIvTlVyaHNrS1RpNGhGV0lBUDlLb3Z6VDhPSzNZY3h6eVZQWUxYNGVWbWt3UmZzeWkwUU5Xb0tGT3d0ZC8KNm9HYzFnd2piRjI5ZDNnUThZQjFGWmRLa1AyMTJGbkt1cTIrUWgvbE1zTUZrTHlTQTRLTGJ3ZG1RSXExbE1QUwpjbUNtZnppV3J1MlBtNEZoM0dmWlQyaE1JWHlIRFdEVzlDTkxKaERodExOZ2RRamFBUFpVT1E4V2hwSkE5MS9vCjJLZzZ3MDd5Z2RCcVd5dTZrc0pXcjNpZ1JpUEJ5QmVNWEpEZU5HY3NhaUZ3Q2c5eFlja1VORXR3NS90WlRsTjIKSEdZV0NpVU5Ed0F2WllMUHR1SHpIOFRFMGxsZm5HR0VuVC9QQlp1UHV4andlZlRleE1mdzFpbGJRU3lkcy9HMgpOOUlKKzkydms0N0ZXR2NOdGh1Q3lCbklva0NpZ0c1ZlBlV2IwQTdpdjk0UGtwRTRJZ3plc0hGQ0ZFQWoxWldLCnpQdFRBQlkwZlJrUzBNc3UwMHYxOXloTTUrdFUwYkVCZWo2eWpzWHRoYzlwS01hcUNIZWlQTC9TSHRkaWsxNVMKQmU4Sml4dVJxZitUeGlYWWVuNTg2aDlzTFpEYzA3cGpkUGp2NVNYRnBYQjhIMlVxQ0tZY2p4R3RvQWpTV0pjWApMNHc3RHNEby80bVg1N0htR09iamlCN1ZyOGhVWEJDdFh2V0dmQXlmcEFZNS9vOXowdm4zREcxaDc1NVVwdDluCkF2MFZrbm9qcmJVYjM1ZlJuU1lYTVltS01LSnpNRlMrdmFvRlpwV0ZjTG10cFRWSWNzc0JGUEYyZEo3V1c0WHMKK0d2Vkl2eFl3S2wyZzFPTE1TTXRZa09vekdlblBXTzdIdU0yMUVKVGIvbHNEZ25GaTkrYWRGZHBLY3R2cm0zdgpmbW1HeG5pRmhLU05GU0xtNms5YStHL2pjK3NVQVBhb2FZNEQ3NHVGajh0WGp0eThFUHdRRGxVUGRVZld3SE9PClF3bVgyMys1REh4V0VoQy91Tm8yNHNNY2ZkQzFGZUpBV281bUNuVU5vUVVmMStNRDVhMzNJdDhhMmlrNUkxUWoKeSs1WGpRaG0xd3RBMWhWTWE4aUxBR0toT09lcFRuK1VBZHpyS0hvNjVtYzNKbGgvSFJDUXJabnVxWkErK0F2WgpjeWU0dWZGWC8xdmRQSTdLb2Q0MEdDM2dlQnhweFFNYnp1OFNUcGpOcElJRkJvRVc5dFRhemUzeHZXWnV6dDc0CnFjZS8xWURuUHBLeW5lM0xGMk94VWoyYWVYUW5YQkpYcGhTZTBVTGJMcWJtUll4bjJKWkl1d09RNHV5dm94NjUKdG9TWGNac054dUs4QTErZXNXR3JSN3pVc0djdU9QQTFERE9Ja2JjcGtmRUxMNjk4RTJRckdqTU9JWnhrcWdxZQoySE5VNktWRmV2NzdZeEJDbm1VcVdXZEhYMjcyU2NPMUYzdWpUdFVnRVBNWGN0aEdBckYzTWxEaUw1Q0k0RkhqCnhHc3pVemxzalRQTmpiY2MzdUE2MjVZS3VVZEI2c1h1Rk5NUHk5UDgwTzBpRWJGTXl3MWxmN2VpdFhvaUUxWVoKc3NhMDVxTUx4M3pPUXZTLzFDdFpqaFp4cVJMRW5pQ3NWa2JVRlVYclpodEU4dG94bGpWSUtpQ25qbitORmtqdwo2bTZ1anpBSytZZHd2Nk5WMFB4S0gwUk5NYVhwb1lmQk1oUmZ3dGlaS3V3Y2hyRFB5UEhBQ2J3WXNZOXdtUE9rCnpwdDNxWi9JdDVYTmVqNDI0RzAzcGpMbk1sd1B1T1VzYmFQUWQ2VHU4TFhsckZReUVjTXJDNHdjUTA1SzFVN3kKM1NNN3RFaTlnbjV3RjY1YVI5eEFBR0grTUtMMk5WNnQrUmlTazJVaWs1clNmeDE4Mk9wYmpSQ2grdmQ4UXhJdwotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo='  # noqa: E501
 PRIVATE_KEY_PASSWORD = 'pwd1234'
 PUBLIC_KEY_PEM_BASE64 = 'LS0tLS1CRUdJTiBQVUJMSUMgS0VZLS0tLS0KTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQ0FROEFNSUlCQ2dLQ0FRRUF0dis3NlNXbklhOFFKWC94RUQxRQpYdnBBQmE3ajBnQnVYenJNUU5adjhtTW1RU0t2VUF0TmpOL2xacUZpQ0haZUQxU2VDcGV1MnFHTm5XbGRxNkhUCnh5cXJpTVZEbFNKaFBNT09QSENISVNVdFI4Tk5lR1Y1MU0wYkxJcENabHcyTU9GUjdqdENWejVqZFRpZ1NvYTIKQWxrRUlRZWQ4UVlDKzk1aGJoOHk5bGcwQ0JxdEdWN1FvMFZQR2xKQ0hGaWNuaWxLVFFZay9MZzkwWVFnUElPbwozbUppeFl5bWFGNmlMZTVXNzg1M0VHWUVFVWdlWmNaZFNjaGVBMEdBMGpRSFVTdnYvMEZjay9adkZNZURJOTVsCmJVQ0JoQjFDbFg4OG4wZUhzUmdWZE5vK0NLMDI4T2IvZTZTK1JLK09VaHlFRVdPTi90alVMdGhJdTJkQWtGcmkKOFFJREFRQUIKLS0tLS1FTkQgUFVCTElDIEtFWS0tLS0tCg=='  # noqa: E501
-PRIVATE_KEY = _load_private_key(
-    PRIVATE_KEY_PEM_BASE64,
-    PRIVATE_KEY_PASSWORD,
-)
+PRIVATE_KEY = load_private_key(PRIVATE_KEY_PEM_BASE64, PRIVATE_KEY_PASSWORD)
 PUBLIC_KEY = _load_public_key(PUBLIC_KEY_PEM_BASE64)
 
 
-class TestCrypto():
-    def test_encrypt_decrypt_varions_string(self) -> None:
-        for value in [crypto_random_object_id(10), '👍', '!', '@', '#', '$', '%', '^', '&', '*', '(', ')', '-', '_', '=', '+', '[', ']', '{', '}', '|', ';', ':', '"', "'", ',', '.', '<', '>', '?', '/', '~']:  # noqa: E501
+class TestCrypto:
+    def test_encrypt_decrypt_varions_string(self: TestCrypto) -> None:
+        for value in [
+            crypto_random_object_id(10),
+            '👍',
+            '!',
+            '@',
+            '#',
+            '$',
+            '%',
+            '^',
+            '&',
+            '*',
+            '(',
+            ')',
+            '-',
+            '_',
+            '=',
+            '+',
+            '[',
+            ']',
+            '{',
+            '}',
+            '|',
+            ';',
+            ':',
+            '"',
+            "'",
+            ',',
+            '.',
+            '<',
+            '>',
+            '?',
+            '/',
+            '~',
+        ]:
             encrypted = public_encrypt(value, public_key=PUBLIC_KEY)
             decrypted_value = private_decrypt(**encrypted, private_key=PRIVATE_KEY)
             assert decrypted_value == value
 
-    def test_throw_if_password_is_not_valid(self) -> None:
+    def test_throw_if_password_is_not_valid(self: TestCrypto) -> None:
         test_value = 'test'
         encrypted = public_encrypt(test_value, public_key=PUBLIC_KEY)
         encrypted['encrypted_password'] = base64.b64encode(b'invalid_password').decode('utf-8')
@@ -31,16 +63,17 @@ def test_throw_if_password_is_not_valid(self) -> None:
         with pytest.raises(ValueError, match='Ciphertext length must be equal to key size.'):
             private_decrypt(**encrypted, private_key=PRIVATE_KEY)
 
-    def test_throw_error_if_cipher_is_manipulated(self) -> None:
+    def test_throw_error_if_cipher_is_manipulated(self: TestCrypto) -> None:
         test_value = 'test2'
         encrypted = public_encrypt(test_value, public_key=PUBLIC_KEY)
         encrypted['encrypted_value'] = base64.b64encode(
-            b'invalid_cipher' + base64.b64decode(encrypted['encrypted_value'].encode('utf-8'))).decode('utf-8')
+            b'invalid_cipher' + base64.b64decode(encrypted['encrypted_value'].encode('utf-8')),
+        ).decode('utf-8')
 
         with pytest.raises(ValueError, match='Decryption failed, malformed encrypted value or password.'):
             private_decrypt(**encrypted, private_key=PRIVATE_KEY)
 
-    def test_same_encrypted_value_should_return_deffirent_cipher(self) -> None:
+    def test_same_encrypted_value_should_return_deffirent_cipher(self: TestCrypto) -> None:
         test_value = 'test3'
         encrypted1 = public_encrypt(test_value, public_key=PUBLIC_KEY)
         encrypted2 = public_encrypt(test_value, public_key=PUBLIC_KEY)
@@ -48,7 +81,7 @@ def test_same_encrypted_value_should_return_deffirent_cipher(self) -> None:
 
     # Check if the method is compatible with js version of the same method in:
     # https://github.com/apify/apify-shared-js/blob/master/packages/utilities/src/crypto.ts
-    def test_private_encrypt_node_js_encrypted_value(self) -> None:
+    def test_private_encrypt_node_js_encrypted_value(self: TestCrypto) -> None:
         value = 'encrypted_with_node_js'
         # This was encrypted with nodejs version of the same method.
         encrypted_value_with_node_js = {
@@ -62,7 +95,7 @@ def test_private_encrypt_node_js_encrypted_value(self) -> None:
 
         assert decrypted_value == value
 
-    def test_crypto_random_object_id(self) -> None:
+    def test_crypto_random_object_id(self: TestCrypto) -> None:
         assert len(crypto_random_object_id()) == 17
         assert len(crypto_random_object_id(5)) == 5
         long_random_object_id = crypto_random_object_id(1000)
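Most of the churn in the hunk above is the formatter exploding list literals and call arguments one per line. The trigger is the trailing comma ("magic trailing comma"): when the last element is followed by a comma, the construct is kept multi-line; when it is not, the construct may be collapsed onto a single line if it fits. A tiny sketch of the behavior, with illustrative values not taken from this patch:

    # With a trailing comma after '#', the formatter preserves one element per line.
    symbols = [
        '!',
        '@',
        '#',
    ]

    # Without a trailing comma, a short list is collapsed to a single line.
    compact = ['!', '@', '#']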
diff --git a/tests/unit/test_event_manager.py b/tests/unit/test_event_manager.py
index cc2c09e1..145c954c 100644
--- a/tests/unit/test_event_manager.py
+++ b/tests/unit/test_event_manager.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import logging
 import time
 from collections import defaultdict
 from pprint import pprint
-from typing import Any, Callable, Dict, Optional, Set
+from typing import Any, Callable
 
 import pytest
 import websockets
@@ -16,7 +18,7 @@ class TestEventManagerLocal:
-    async def test_lifecycle_local(self, caplog: pytest.LogCaptureFixture) -> None:
+    async def test_lifecycle_local(self: TestEventManagerLocal, caplog: pytest.LogCaptureFixture) -> None:
         caplog.set_level(logging.DEBUG, logger='apify')
 
         config = Configuration()
@@ -39,7 +41,7 @@ async def test_lifecycle_local(self, caplog: pytest.LogCaptureFixture) -> None:
 
         assert event_manager._initialized is False
 
-    async def test_event_handling_local(self) -> None:
+    async def test_event_handling_local(self: TestEventManagerLocal) -> None:
         config = Configuration()
         event_manager = EventManager(config)
 
@@ -47,10 +49,11 @@ async def test_event_handling_local(self) -> None:
 
         event_calls = defaultdict(list)
 
-        def on_event(event: ActorEventTypes, id: Optional[int] = None) -> Callable:
+        def on_event(event: ActorEventTypes, id: int | None = None) -> Callable:  # noqa: A002
             def event_handler(data: Any) -> None:
                 nonlocal event_calls
                 event_calls[event].append((id, data))
+
             return event_handler
 
         handler_system_info = on_event(ActorEventTypes.SYSTEM_INFO)
@@ -106,7 +109,7 @@ def event_handler(data: Any) -> None:
 
         await event_manager.close()
 
-    async def test_event_handler_argument_counts_local(self) -> None:
+    async def test_event_handler_argument_counts_local(self: TestEventManagerLocal) -> None:
         config = Configuration()
         event_manager = EventManager(config)
 
@@ -173,7 +176,7 @@ async def async_two_arguments_one_default(event_data: Any, _arg2: Any = 'default
         assert ('sync_two_arguments_one_default', 'DUMMY_SYSTEM_INFO') in event_calls
         assert ('async_two_arguments_one_default', 'DUMMY_SYSTEM_INFO') in event_calls
 
-    async def test_event_async_handling_local(self) -> None:
+    async def test_event_async_handling_local(self: TestEventManagerLocal) -> None:
         config = Configuration()
         event_manager = EventManager(config)
 
@@ -196,7 +199,10 @@ async def event_handler(data: Any) -> None:
 
         await event_manager.close()
 
-    async def test_wait_for_all_listeners_to_complete(self, caplog: pytest.LogCaptureFixture) -> None:
+    async def test_wait_for_all_listeners_to_complete(
+        self: TestEventManagerLocal,
+        caplog: pytest.LogCaptureFixture,
+    ) -> None:
         config = Configuration()
         event_manager = EventManager(config)
 
@@ -204,12 +210,13 @@ async def test_wait_for_all_listeners_to_complete(self, caplog: pytest.LogCaptur
 
         event_calls = []
 
-        def on_event(sleep_secs: Optional[int] = None) -> Callable:
+        def on_event(sleep_secs: int | None = None) -> Callable:
             async def event_handler(data: Any) -> None:
                 nonlocal event_calls
                 if sleep_secs:
                     await asyncio.sleep(sleep_secs)
                 event_calls.append(data)
+
             return event_handler
 
         # Create three handlers, all with a different sleep time, and add them
@@ -268,7 +275,10 @@ async def event_handler(data: Any) -> None:
 
 
 class TestEventManagerOnPlatform:
-    async def test_lifecycle_on_platform_without_websocket(self, monkeypatch: pytest.MonkeyPatch) -> None:
+    async def test_lifecycle_on_platform_without_websocket(
+        self: TestEventManagerOnPlatform,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
         monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, 'ws://localhost:56565')
 
         config = Configuration()
@@ -279,8 +289,8 @@ async def test_lifecycle_on_platform_without_websocket(self, monkeypatch: pytest
 
         assert event_manager._initialized is False
 
-    async def test_lifecycle_on_platform(self, monkeypatch: pytest.MonkeyPatch) -> None:
-        connected_ws_clients: Set[websockets.server.WebSocketServerProtocol] = set()
+    async def test_lifecycle_on_platform(self: TestEventManagerOnPlatform, monkeypatch: pytest.MonkeyPatch) -> None:
+        connected_ws_clients: set[websockets.server.WebSocketServerProtocol] = set()
 
         async def handler(websocket: websockets.server.WebSocketServerProtocol) -> None:
             connected_ws_clients.add(websocket)
@@ -307,8 +317,8 @@ async def handler(websocket: websockets.server.WebSocketServerProtocol) -> None:
 
         assert event_manager._initialized is False
 
-    async def test_event_handling_on_platform(self, monkeypatch: pytest.MonkeyPatch) -> None:
-        connected_ws_clients: Set[websockets.server.WebSocketServerProtocol] = set()
+    async def test_event_handling_on_platform(
+        self: TestEventManagerOnPlatform,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        connected_ws_clients: set[websockets.server.WebSocketServerProtocol] = set()
 
         async def handler(websocket: websockets.server.WebSocketServerProtocol) -> None:
             connected_ws_clients.add(websocket)
@@ -318,7 +331,7 @@ async def handler(websocket: websockets.server.WebSocketServerProtocol) -> None:
             connected_ws_clients.remove(websocket)
 
         async def send_platform_event(event_name: ActorEventTypes, data: Any = None) -> None:
-            message: Dict[str, Any] = {'name': event_name}
+            message: dict[str, Any] = {'name': event_name}
             if data:
                 message['data'] = data
diff --git a/tests/unit/test_lru_cache.py b/tests/unit/test_lru_cache.py
index bd492aa1..fe298ae6 100644
--- a/tests/unit/test_lru_cache.py
+++ b/tests/unit/test_lru_cache.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import pytest
 
 from apify._utils import LRUCache
 
 
-@pytest.fixture
+@pytest.fixture()
 def lru_cache() -> LRUCache[int]:
     cache = LRUCache[int](3)
     cache['a'] = 1
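Another change repeated above, and in the storage-test hunks earlier, is that bare `@pytest.fixture` decorators gain parentheses. This matches the flake8-pytest-style convention Ruff enforces as PT001, presumably at its then-default, which preferred the parenthesized form so that fixtures with and without arguments look alike. In use:

    import pytest


    @pytest.fixture()  # the parenthesized style enforced across this patch
    def numbers() -> list[int]:
        return [1, 2, 3]


    def test_sum(numbers: list[int]) -> None:
        assert sum(numbers) == 6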
diff --git a/tests/unit/test_proxy_configuration.py b/tests/unit/test_proxy_configuration.py
index deb5ceb4..1db2349a 100644
--- a/tests/unit/test_proxy_configuration.py
+++ b/tests/unit/test_proxy_configuration.py
@@ -1,22 +1,27 @@
+from __future__ import annotations
+
 import asyncio
 import re
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING
 
 import httpx
 import pytest
-from respx import MockRouter
 
-from apify.proxy_configuration import ProxyConfiguration, _is_url
+from apify.proxy_configuration import ProxyConfiguration, is_url
 from apify_client import ApifyClientAsync
 from apify_shared.consts import ApifyEnvVars
 
-from .conftest import ApifyClientAsyncPatcher
+if TYPE_CHECKING:
+    from respx import MockRouter
+
+    from .conftest import ApifyClientAsyncPatcher
+
 
 DUMMY_PASSWORD = 'DUMMY_PASSWORD'
 
 
 class TestProxyConfiguration:
-    def test_constructor_basic(self) -> None:
+    def test_constructor_basic(self: TestProxyConfiguration) -> None:
         groups = ['GROUP1', 'GROUP2']
         password = 'abcd1234'
         country_code = 'US'
@@ -29,7 +34,7 @@ def test_constructor_basic(self) -> None:
         assert proxy_configuration._password == password
         assert proxy_configuration._country_code == country_code
 
-    def test_constructor_fallback(self, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_constructor_fallback(self: TestProxyConfiguration, monkeypatch: pytest.MonkeyPatch) -> None:
         hostname = 'example.com'
         password = 'abcd1234'
         port = 1234
@@ -44,8 +49,8 @@ def test_constructor_fallback(self, monkeypatch: pytest.MonkeyPatch) -> None:
         assert proxy_configuration._password == password
         assert proxy_configuration._port == port
 
-    def test__fails_with_invalid_arguments(self) -> None:
-        for (invalid_groups, bad_group_index) in [
+    def test__fails_with_invalid_arguments(self: TestProxyConfiguration) -> None:
+        for invalid_groups, bad_group_index in [
             (['abc', 'de-f', 'geh'], 1),
             (['', 'def', 'geh'], 0),
             (['abc', 'DEF', 'geh$'], 2),
@@ -72,7 +77,7 @@ def test__fails_with_invalid_arguments(self) -> None:
 
 
 class TestProxyConfigurationNewUrl:
-    async def test_new_url_basic(self) -> None:
+    async def test_new_url_basic(self: TestProxyConfigurationNewUrl) -> None:
         groups = ['GROUP1', 'GROUP2']
         password = 'abcd1234'
         country_code = 'US'
@@ -89,7 +94,7 @@ async def test_new_url_basic(self) -> None:
 
         assert proxy_url == f'http://{expected_username}:{password}@{expected_hostname}:{expected_port}'
 
-    async def test_new_url_session_id(self) -> None:
+    async def test_new_url_session_id(self: TestProxyConfigurationNewUrl) -> None:
         groups = ['GROUP1', 'GROUP2']
         password = 'abcd1234'
         country_code = 'US'
@@ -99,9 +104,16 @@ async def test_new_url_session_id(self) -> None:
             country_code=country_code,
         )
 
-        session_ids: List[Union[str, int]] = [
-            'a', 'a_b', 'a_2', 'a_1_b', 'aaa~BBB',
-            '1', '0.34252352', 123456, 'XXXXXXXXXXxxxxxxxxxxXXXXXXXXXXxxxxxxxxxxXXXXXXXXXX',
+        session_ids: list[str | int] = [
+            'a',
+            'a_b',
+            'a_2',
+            'a_1_b',
+            'aaa~BBB',
+            '1',
+            '0.34252352',
+            123456,
+            'XXXXXXXXXXxxxxxxxxxxXXXXXXXXXXxxxxxxxxxxXXXXXXXXXX',
         ]
         for session_id in session_ids:
             expected_username = f'groups-{"+".join(groups)},session-{session_id},country-{country_code}'
@@ -116,7 +128,7 @@ async def test_new_url_session_id(self) -> None:
             with pytest.raises(ValueError, match=re.escape(str(invalid_session_id))):
                 await proxy_configuration.new_url(invalid_session_id)
 
-    async def test_rotating_custom_urls(self) -> None:
+    async def test_rotating_custom_urls(self: TestProxyConfigurationNewUrl) -> None:
         proxy_urls = ['http://proxy.com:1111', 'http://proxy.com:2222', 'http://proxy.com:3333']
         proxy_configuration = ProxyConfiguration(proxy_urls=proxy_urls)
 
@@ -127,7 +139,7 @@ async def test_rotating_custom_urls(self) -> None:
         assert await proxy_configuration.new_url() == proxy_urls[1]
         assert await proxy_configuration.new_url() == proxy_urls[2]
 
-    async def test_rotating_custom_urls_with_sessions(self) -> None:
+    async def test_rotating_custom_urls_with_sessions(self: TestProxyConfigurationNewUrl) -> None:
         sessions = ['sesssion_01', 'sesssion_02', 'sesssion_03', 'sesssion_04', 'sesssion_05', 'sesssion_06']
         proxy_urls = ['http://proxy.com:1111', 'http://proxy.com:2222', 'http://proxy.com:3333']
 
@@ -149,13 +161,17 @@ async def test_rotating_custom_urls_with_sessions(self) -> None:
         assert await proxy_configuration.new_url(sessions[1]) == proxy_urls[1]
         assert await proxy_configuration.new_url(sessions[3]) == proxy_urls[0]
 
-    async def test_custom_new_url_function(self) -> None:
+    async def test_custom_new_url_function(self: TestProxyConfigurationNewUrl) -> None:
         custom_urls = [
-            'http://proxy.com:1111', 'http://proxy.com:2222', 'http://proxy.com:3333',
-            'http://proxy.com:4444', 'http://proxy.com:5555', 'http://proxy.com:6666',
+            'http://proxy.com:1111',
+            'http://proxy.com:2222',
+            'http://proxy.com:3333',
+            'http://proxy.com:4444',
+            'http://proxy.com:5555',
+            'http://proxy.com:6666',
         ]
 
-        def custom_new_url_function(_session_id: Optional[str]) -> str:
+        def custom_new_url_function(_session_id: str | None) -> str:
             nonlocal custom_urls
             return custom_urls.pop()
 
@@ -164,13 +180,17 @@ def custom_new_url_function(_session_id: Optional[str]) -> str:
         for custom_url in reversed(custom_urls):
             assert await proxy_configuration.new_url() == custom_url
 
-    async def test_custom_new_url_function_async(self) -> None:
+    async def test_custom_new_url_function_async(self: TestProxyConfigurationNewUrl) -> None:
         custom_urls = [
-            'http://proxy.com:1111', 'http://proxy.com:2222', 'http://proxy.com:3333',
-            'http://proxy.com:4444', 'http://proxy.com:5555', 'http://proxy.com:6666',
+            'http://proxy.com:1111',
+            'http://proxy.com:2222',
+            'http://proxy.com:3333',
+            'http://proxy.com:4444',
+            'http://proxy.com:5555',
+            'http://proxy.com:6666',
         ]
 
-        async def custom_new_url_function(_session_id: Optional[str]) -> str:
+        async def custom_new_url_function(_session_id: str | None) -> str:
             nonlocal custom_urls
             await asyncio.sleep(0.1)
             return custom_urls.pop()
@@ -180,16 +200,16 @@ async def custom_new_url_function(_session_id: Optional[str]) -> str:
         for custom_url in reversed(custom_urls):
             assert await proxy_configuration.new_url() == custom_url
 
-    async def test_invalid_custom_new_url_function(self) -> None:
-        def custom_new_url_function(_session_id: Optional[str]) -> str:
-            raise ValueError()
+    async def test_invalid_custom_new_url_function(self: TestProxyConfigurationNewUrl) -> None:
+        def custom_new_url_function(_session_id: str | None) -> str:
+            raise ValueError
 
         proxy_configuration = ProxyConfiguration(new_url_function=custom_new_url_function)
 
         with pytest.raises(ValueError, match='The provided "new_url_function" did not return a valid URL'):
             await proxy_configuration.new_url()
 
-    async def test_proxy_configuration_not_sharing_references(self) -> None:
+    async def test_proxy_configuration_not_sharing_references(self: TestProxyConfigurationNewUrl) -> None:
         urls = [
             'http://proxy-example-1.com:8000',
             'http://proxy-example-2.com:8000',
@@ -216,7 +236,7 @@ async def test_proxy_configuration_not_sharing_references(self) -> None:
 
 
 class TestProxyConfigurationNewProxyInfo:
-    async def test_new_proxy_info_basic(self) -> None:
+    async def test_new_proxy_info_basic(self: TestProxyConfigurationNewProxyInfo) -> None:
         groups = ['GROUP1', 'GROUP2']
         password = 'abcd1234'
         country_code = 'US'
@@ -241,7 +261,7 @@ async def test_new_proxy_info_basic(self) -> None:
             'password': password,
         }
 
-    async def test_new_proxy_info_rotates_urls(self) -> None:
+    async def test_new_proxy_info_rotates_urls(self: TestProxyConfigurationNewProxyInfo) -> None:
         proxy_urls = ['http://proxy.com:1111', 'http://proxy.com:2222', 'http://proxy.com:3333']
         proxy_configuration = ProxyConfiguration(proxy_urls=proxy_urls)
 
@@ -252,7 +272,7 @@ async def test_new_proxy_info_rotates_urls(self) -> None:
         assert (await proxy_configuration.new_proxy_info())['url'] == proxy_urls[1]
         assert (await proxy_configuration.new_proxy_info())['url'] == proxy_urls[2]
 
-    async def test_new_proxy_info_rotates_urls_with_sessions(self) -> None:
+    async def test_new_proxy_info_rotates_urls_with_sessions(self: TestProxyConfigurationNewProxyInfo) -> None:
         sessions = ['sesssion_01', 'sesssion_02', 'sesssion_03', 'sesssion_04', 'sesssion_05', 'sesssion_06']
         proxy_urls = ['http://proxy.com:1111', 'http://proxy.com:2222', 'http://proxy.com:3333']
 
@@ -275,20 +295,24 @@ async def test_new_proxy_info_rotates_urls_with_sessions(self) -> None:
         assert (await proxy_configuration.new_proxy_info(sessions[3]))['url'] == proxy_urls[0]
 
 
-@pytest.fixture
+@pytest.fixture()
 def patched_apify_client(apify_client_async_patcher: ApifyClientAsyncPatcher) -> ApifyClientAsync:
-    apify_client_async_patcher.patch('user', 'get', return_value={
-        'proxy': {
-            'password': DUMMY_PASSWORD,
+    apify_client_async_patcher.patch(
+        'user',
+        'get',
+        return_value={
+            'proxy': {
+                'password': DUMMY_PASSWORD,
+            },
         },
-    })
+    )
     return ApifyClientAsync()
 
 
 class TestProxyConfigurationInitialize:
     async def test_initialize_basic(
-        self,
+        self: TestProxyConfigurationInitialize,
         monkeypatch: pytest.MonkeyPatch,
         respx_mock: MockRouter,
         patched_apify_client: ApifyClientAsync,
@@ -298,11 +322,16 @@ async def test_initialize_basic(
         monkeypatch.setenv(ApifyEnvVars.PROXY_STATUS_URL.value, dummy_proxy_status_url)
 
         route = respx_mock.get(dummy_proxy_status_url)
-        route.mock(httpx.Response(200, json={
-            'connected': True,
-            'connectionError': None,
-            'isManInTheMiddle': True,
-        }))
+        route.mock(
+            httpx.Response(
+                200,
+                json={
+                    'connected': True,
+                    'connectionError': None,
+                    'isManInTheMiddle': True,
+                },
+            )
+        )
 
         proxy_configuration = ProxyConfiguration(_apify_client=patched_apify_client)
 
@@ -314,25 +343,30 @@ async def test_initialize_basic(
         assert len(patched_apify_client.calls['user']['get']) == 1  # type: ignore
         assert len(route.calls) == 1
 
-    async def test_initialize_no_password_no_token(self) -> None:
+    async def test_initialize_no_password_no_token(self: TestProxyConfigurationInitialize) -> None:
         proxy_configuration = ProxyConfiguration()
 
         with pytest.raises(ValueError, match='Apify Proxy password must be provided'):
             await proxy_configuration.initialize()
 
     async def test_initialize_manual_password(
-        self,
+        self: TestProxyConfigurationInitialize,
         monkeypatch: pytest.MonkeyPatch,
         respx_mock: MockRouter,
     ) -> None:
         dummy_proxy_status_url = 'http://dummy-proxy-status-url.com'
         monkeypatch.setenv(ApifyEnvVars.PROXY_STATUS_URL.value, dummy_proxy_status_url)
 
-        respx_mock.get(dummy_proxy_status_url).mock(httpx.Response(200, json={
-            'connected': True,
-            'connectionError': None,
-            'isManInTheMiddle': False,
-        }))
+        respx_mock.get(dummy_proxy_status_url).mock(
+            httpx.Response(
+                200,
+                json={
+                    'connected': True,
+                    'connectionError': None,
+                    'isManInTheMiddle': False,
+                },
+            )
+        )
 
         proxy_configuration = ProxyConfiguration(password=DUMMY_PASSWORD)
 
@@ -342,7 +376,7 @@ async def test_initialize_manual_password(
 
         assert proxy_configuration.is_man_in_the_middle is False
 
     async def test_initialize_manual_password_different_than_user_one(
-        self,
+        self: TestProxyConfigurationInitialize,
         monkeypatch: pytest.MonkeyPatch,
         caplog: pytest.LogCaptureFixture,
         respx_mock: MockRouter,
@@ -354,11 +388,16 @@ async def test_initialize_manual_password_different_than_user_one(
         monkeypatch.setenv(ApifyEnvVars.PROXY_STATUS_URL.value, dummy_proxy_status_url)
         monkeypatch.setenv(ApifyEnvVars.PROXY_PASSWORD.value, different_dummy_password)
 
-        respx_mock.get(dummy_proxy_status_url).mock(httpx.Response(200, json={
-            'connected': True,
-            'connectionError': None,
-            'isManInTheMiddle': True,
-        }))
+        respx_mock.get(dummy_proxy_status_url).mock(
+            httpx.Response(
+                200,
+                json={
+                    'connected': True,
+                    'connectionError': None,
+                    'isManInTheMiddle': True,
+                },
+            )
+        )
 
         proxy_configuration = ProxyConfiguration(_apify_client=patched_apify_client)
 
@@ -372,7 +411,7 @@ async def test_initialize_manual_password_different_than_user_one(
 
         assert 'The Apify Proxy password you provided belongs to a different user' in caplog.records[0].message
 
     async def test_initialize_not_connected(
-        self,
+        self: TestProxyConfigurationInitialize,
         monkeypatch: pytest.MonkeyPatch,
         respx_mock: MockRouter,
     ) -> None:
@@ -380,10 +419,15 @@ async def test_initialize_not_connected(
         dummy_proxy_status_url = 'http://dummy-proxy-status-url.com'
         monkeypatch.setenv(ApifyEnvVars.PROXY_STATUS_URL.value, dummy_proxy_status_url)
 
-        respx_mock.get(dummy_proxy_status_url).mock(httpx.Response(200, json={
-            'connected': False,
-            'connectionError': dummy_connection_error,
-        }))
+        respx_mock.get(dummy_proxy_status_url).mock(
+            httpx.Response(
+                200,
+                json={
+                    'connected': False,
+                    'connectionError': dummy_connection_error,
+                },
+            )
+        )
 
         proxy_configuration = ProxyConfiguration(password=DUMMY_PASSWORD)
 
@@ -391,7 +435,7 @@ async def test_initialize_not_connected(
             await proxy_configuration.initialize()
 
     async def test_initialize_status_page_unavailable(
-        self,
+        self: TestProxyConfigurationInitialize,
         monkeypatch: pytest.MonkeyPatch,
         caplog: pytest.LogCaptureFixture,
         respx_mock: MockRouter,
@@ -410,7 +454,7 @@ async def test_initialize_status_page_unavailable(
 
         assert 'Apify Proxy access check timed out' in caplog.records[0].message
 
     async def test_initialize_not_called_non_apify_proxy(
-        self,
+        self: TestProxyConfigurationInitialize,
         monkeypatch: pytest.MonkeyPatch,
         respx_mock: MockRouter,
         patched_apify_client: ApifyClientAsync,
@@ -430,20 +474,20 @@ async def test_initialize_not_called_non_apify_proxy(
 
 
 class TestIsUrl:
-    def test__is_url(self) -> None:
-        assert _is_url('http://dummy-proxy.com:8000') is True
-        assert _is_url('https://example.com') is True
-        assert _is_url('http://localhost') is True
-        assert _is_url('https://12.34.56.78') is True
-        assert _is_url('http://12.34.56.78:9012') is True
-        assert _is_url('http://::1') is True
-        assert _is_url('https://2f45:4da6:8f56:af8c:5dce:c1de:14d2:8661') is True
-
-        assert _is_url('dummy-proxy.com:8000') is False
-        assert _is_url('gyfwgfhkjhljkfhdsf') is False
-        assert _is_url('http://') is False
-        assert _is_url('http://example') is False
-        assert _is_url('http:/example.com') is False
-        assert _is_url('12.34.56.78') is False
-        assert _is_url('::1') is False
-        assert _is_url('https://4da6:8f56:af8c:5dce:c1de:14d2:8661') is False
+    def test__is_url(self: TestIsUrl) -> None:
+        assert is_url('http://dummy-proxy.com:8000') is True
+        assert is_url('https://example.com') is True
+        assert is_url('http://localhost') is True
+        assert is_url('https://12.34.56.78') is True
+        assert is_url('http://12.34.56.78:9012') is True
+        assert is_url('http://::1') is True
+        assert is_url('https://2f45:4da6:8f56:af8c:5dce:c1de:14d2:8661') is True
+
+        assert is_url('dummy-proxy.com:8000') is False
+        assert is_url('gyfwgfhkjhljkfhdsf') is False
+        assert is_url('http://') is False
+        assert is_url('http://example') is False
+        assert is_url('http:/example.com') is False
+        assert is_url('12.34.56.78') is False
+        assert is_url('::1') is False
+        assert is_url('https://4da6:8f56:af8c:5dce:c1de:14d2:8661') is False
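One more substantive lint fix hides in the hunks above: `raise ValueError()` becomes `raise ValueError`. Calling an exception class with no arguments adds nothing, so the empty parentheses are dropped; Ruff's flake8-raise rule RSE102 covers exactly this (naming the rule is an inference, as the configuration hunk is not shown here). For instance:

    def require_positive(value: int) -> int:
        if value <= 0:
            raise ValueError  # equivalent to `raise ValueError()`, minus the no-op call
        return value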
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index d1914102..ae09df13 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -1,33 +1,38 @@
+from __future__ import annotations
+
 import asyncio
 import contextlib
 import os
 import time
 from collections import OrderedDict
 from datetime import datetime, timezone
-from pathlib import Path
+from typing import TYPE_CHECKING
 
 import pytest
 from aiofiles.os import mkdir
 
 from apify._utils import (
-    _budget_ow,
-    _fetch_and_parse_env_var,
-    _force_remove,
-    _force_rename,
-    _get_cpu_usage_percent,
-    _get_memory_usage_bytes,
-    _guess_file_extension,
-    _maybe_parse_bool,
-    _maybe_parse_datetime,
-    _maybe_parse_int,
-    _raise_on_duplicate_storage,
-    _raise_on_non_existing_storage,
-    _run_func_at_interval_async,
-    _unique_key_to_request_id,
+    budget_ow,
+    fetch_and_parse_env_var,
+    force_remove,
+    force_rename,
+    get_cpu_usage_percent,
+    get_memory_usage_bytes,
+    guess_file_extension,
+    maybe_parse_bool,
+    maybe_parse_datetime,
+    maybe_parse_int,
+    raise_on_duplicate_storage,
+    raise_on_non_existing_storage,
+    run_func_at_interval_async,
+    unique_key_to_request_id,
 )
-from apify.consts import _StorageTypes
+from apify.consts import StorageTypes
 from apify_shared.consts import ActorEnvVars, ApifyEnvVars
 
+if TYPE_CHECKING:
+    from pathlib import Path
+
 
 def test__fetch_and_parse_env_var(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setenv(ApifyEnvVars.IS_AT_HOME, 'True')
@@ -39,55 +44,53 @@ def test__fetch_and_parse_env_var(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setenv('DUMMY_INT', '1')
     monkeypatch.setenv('DUMMY_STRING', 'DUMMY')
 
-    assert _fetch_and_parse_env_var(ApifyEnvVars.IS_AT_HOME) is True
-    assert _fetch_and_parse_env_var(ActorEnvVars.MEMORY_MBYTES) == 1024
-    assert _fetch_and_parse_env_var(ApifyEnvVars.META_ORIGIN) == 'API'
-    assert _fetch_and_parse_env_var(ActorEnvVars.STARTED_AT) == \
-        datetime(2022, 12, 2, 15, 19, 34, 907000, tzinfo=timezone.utc)
+    assert fetch_and_parse_env_var(ApifyEnvVars.IS_AT_HOME) is True
+    assert fetch_and_parse_env_var(ActorEnvVars.MEMORY_MBYTES) == 1024
+    assert fetch_and_parse_env_var(ApifyEnvVars.META_ORIGIN) == 'API'
+    assert fetch_and_parse_env_var(ActorEnvVars.STARTED_AT) == datetime(2022, 12, 2, 15, 19, 34, 907000, tzinfo=timezone.utc)
 
-    assert _fetch_and_parse_env_var('DUMMY_BOOL') == '1'  # type: ignore
-    assert _fetch_and_parse_env_var('DUMMY_DATETIME') == '2022-12-02T15:19:34.907Z'  # type: ignore
-    assert _fetch_and_parse_env_var('DUMMY_INT') == '1'  # type: ignore
-    assert _fetch_and_parse_env_var('DUMMY_STRING') == 'DUMMY'  # type: ignore
-    assert _fetch_and_parse_env_var('NONEXISTENT_ENV_VAR') is None  # type: ignore
-    assert _fetch_and_parse_env_var('NONEXISTENT_ENV_VAR', 'default') == 'default'  # type: ignore
+    assert fetch_and_parse_env_var('DUMMY_BOOL') == '1'  # type: ignore
+    assert fetch_and_parse_env_var('DUMMY_DATETIME') == '2022-12-02T15:19:34.907Z'  # type: ignore
+    assert fetch_and_parse_env_var('DUMMY_INT') == '1'  # type: ignore
+    assert fetch_and_parse_env_var('DUMMY_STRING') == 'DUMMY'  # type: ignore
+    assert fetch_and_parse_env_var('NONEXISTENT_ENV_VAR') is None  # type: ignore
+    assert fetch_and_parse_env_var('NONEXISTENT_ENV_VAR', 'default') == 'default'  # type: ignore
 
 
 def test__get_cpu_usage_percent() -> None:
-    assert _get_cpu_usage_percent() >= 0
-    assert _get_cpu_usage_percent() <= 100
+    assert get_cpu_usage_percent() >= 0
+    assert get_cpu_usage_percent() <= 100
 
 
 def test__get_memory_usage_bytes() -> None:
-    assert _get_memory_usage_bytes() >= 0
-    assert _get_memory_usage_bytes() <= 1024 * 1024 * 1024 * 1024
+    assert get_memory_usage_bytes() >= 0
+    assert get_memory_usage_bytes() <= 1024 * 1024 * 1024 * 1024
 
 
 def test__maybe_parse_bool() -> None:
-    assert _maybe_parse_bool('True') is True
-    assert _maybe_parse_bool('true') is True
-    assert _maybe_parse_bool('1') is True
-    assert _maybe_parse_bool('False') is False
-    assert _maybe_parse_bool('false') is False
-    assert _maybe_parse_bool('0') is False
-    assert _maybe_parse_bool(None) is False
-    assert _maybe_parse_bool('bflmpsvz') is False
+    assert maybe_parse_bool('True') is True
+    assert maybe_parse_bool('true') is True
+    assert maybe_parse_bool('1') is True
+    assert maybe_parse_bool('False') is False
+    assert maybe_parse_bool('false') is False
+    assert maybe_parse_bool('0') is False
+    assert maybe_parse_bool(None) is False
+    assert maybe_parse_bool('bflmpsvz') is False
 
 
 def test__maybe_parse_datetime() -> None:
-    assert _maybe_parse_datetime('2022-12-02T15:19:34.907Z') == \
-        datetime(2022, 12, 2, 15, 19, 34, 907000, tzinfo=timezone.utc)
-    assert _maybe_parse_datetime('2022-12-02T15:19:34.907') == '2022-12-02T15:19:34.907'
-    assert _maybe_parse_datetime('anything') == 'anything'
+    assert maybe_parse_datetime('2022-12-02T15:19:34.907Z') == datetime(2022, 12, 2, 15, 19, 34, 907000, tzinfo=timezone.utc)
+    assert maybe_parse_datetime('2022-12-02T15:19:34.907') == '2022-12-02T15:19:34.907'
+    assert maybe_parse_datetime('anything') == 'anything'
 
 
 def test__maybe_parse_int() -> None:
-    assert _maybe_parse_int('0') == 0
-    assert _maybe_parse_int('1') == 1
-    assert _maybe_parse_int('-1') == -1
-    assert _maybe_parse_int('136749825') == 136749825
-    assert _maybe_parse_int('') is None
-    assert _maybe_parse_int('abcd') is None
+    assert maybe_parse_int('0') == 0
+    assert maybe_parse_int('1') == 1
+    assert maybe_parse_int('-1') == -1
+    assert maybe_parse_int('136749825') == 136749825
+    assert maybe_parse_int('') is None
+    assert maybe_parse_int('abcd') is None
 
 
 async def test__run_func_at_interval_async__sync_function() -> None:
@@ -103,7 +106,7 @@ def sync_increment() -> None:
         test_var += 1
 
     started_at = time.perf_counter()
-    sync_increment_task = asyncio.create_task(_run_func_at_interval_async(sync_increment, interval))
+    sync_increment_task = asyncio.create_task(run_func_at_interval_async(sync_increment, interval))
 
     try:
         await asyncio.sleep(initial_delay)
@@ -140,7 +143,7 @@ async def async_increment() -> None:
         test_var += 1
 
     started_at = time.perf_counter()
-    async_increment_task = asyncio.create_task(_run_func_at_interval_async(async_increment, interval))
+    async_increment_task = asyncio.create_task(run_func_at_interval_async(async_increment, interval))
 
     try:
         await asyncio.sleep(initial_delay)
@@ -167,47 +170,47 @@ async def test__force_remove(tmp_path: Path) -> None:
     test_file_path = os.path.join(tmp_path, 'test.txt')
     # Does not crash/raise when the file does not exist
     assert os.path.exists(test_file_path) is False
-    await _force_remove(test_file_path)
+    await force_remove(test_file_path)
     assert os.path.exists(test_file_path) is False
 
     # Removes the file if it exists
-    with open(test_file_path, 'a', encoding='utf-8'):
+    with open(test_file_path, 'a', encoding='utf-8'):  # noqa: ASYNC101
         pass
     assert os.path.exists(test_file_path) is True
-    await _force_remove(test_file_path)
+    await force_remove(test_file_path)
     assert os.path.exists(test_file_path) is False
 
 
 def test__raise_on_non_existing_storage() -> None:
     with pytest.raises(ValueError, match='Dataset with id "kckxQw6j6AtrgyA09" does not exist.'):
-        _raise_on_non_existing_storage(_StorageTypes.DATASET, 'kckxQw6j6AtrgyA09')
+        raise_on_non_existing_storage(StorageTypes.DATASET, 'kckxQw6j6AtrgyA09')
 
 
 def test__raise_on_duplicate_storage() -> None:
     with pytest.raises(ValueError, match='Dataset with name "test" already exists.'):
-        _raise_on_duplicate_storage(_StorageTypes.DATASET, 'name', 'test')
+        raise_on_duplicate_storage(StorageTypes.DATASET, 'name', 'test')
 
 
 def test__guess_file_extension() -> None:
     # Can guess common types properly
-    assert _guess_file_extension('application/json') == 'json'
-    assert _guess_file_extension('application/xml') == 'xml'
-    assert _guess_file_extension('text/plain') == 'txt'
+    assert guess_file_extension('application/json') == 'json'
+    assert guess_file_extension('application/xml') == 'xml'
+    assert guess_file_extension('text/plain') == 'txt'
 
     # Can handle unusual formats
-    assert _guess_file_extension(' application/json ') == 'json'
-    assert _guess_file_extension('APPLICATION/JSON') == 'json'
-    assert _guess_file_extension('application/json;charset=utf-8') == 'json'
+    assert guess_file_extension(' application/json ') == 'json'
+    assert guess_file_extension('APPLICATION/JSON') == 'json'
+    assert guess_file_extension('application/json;charset=utf-8') == 'json'
 
     # Returns None for non-existent content types
-    assert _guess_file_extension('clearly not a content type') is None
-    assert _guess_file_extension('') is None
+    assert guess_file_extension('clearly not a content type') is None
+    assert guess_file_extension('') is None
 
 
 def test__unique_key_to_request_id() -> None:
     # Right side from `uniqueKeyToRequestId` in Crawlee
-    assert _unique_key_to_request_id('abc') == 'ungWv48BzpBQUDe'
-    assert _unique_key_to_request_id('test') == 'n4bQgYhMfWWaLqg'
+    assert unique_key_to_request_id('abc') == 'ungWv48BzpBQUDe'
+    assert unique_key_to_request_id('test') == 'n4bQgYhMfWWaLqg'
@@ -217,20 +220,20 @@ async def test__force_rename(tmp_path: Path) -> None:
     dst_file = os.path.join(dst_dir, 'dst_dir.txt')
     # Won't crash if source directory does not exist
     assert os.path.exists(src_dir) is False
-    await _force_rename(src_dir, dst_dir)
+    await force_rename(src_dir, dst_dir)
 
     # Will remove dst_dir if it exists (also covers normal case)
     # Create the src_dir with a file in it
     await mkdir(src_dir)
-    with open(src_file, 'a', encoding='utf-8'):
+    with open(src_file, 'a', encoding='utf-8'):  # noqa: ASYNC101
         pass
     # Create the dst_dir with a file in it
     await mkdir(dst_dir)
-    with open(dst_file, 'a', encoding='utf-8'):
+    with open(dst_file, 'a', encoding='utf-8'):  # noqa: ASYNC101
         pass
     assert os.path.exists(src_file) is True
     assert os.path.exists(dst_file) is True
-    await _force_rename(src_dir, dst_dir)
+    await force_rename(src_dir, dst_dir)
     assert os.path.exists(src_dir) is False
     assert os.path.exists(dst_file) is False
     # src_dir.txt should exist in dst_dir
@@ -238,22 +241,28 @@ def test__budget_ow() -> None:
-    _budget_ow({
-        'a': 123,
-        'b': 'string',
-        'c': datetime.now(timezone.utc),
-    }, {
-        'a': (int, True),
-        'b': (str, False),
-        'c': (datetime, True),
-    })
+    budget_ow(
+        {
+            'a': 123,
+            'b': 'string',
+            'c': datetime.now(timezone.utc),
+        },
+        {
+            'a': (int, True),
+            'b': (str, False),
+            'c': (datetime, True),
+        },
+    )
     with pytest.raises(ValueError, match='required'):
-        _budget_ow({}, {'id': (str, True)})
+        budget_ow({}, {'id': (str, True)})
     with pytest.raises(ValueError, match='must be of type'):
-        _budget_ow({'id': 123}, {'id': (str, True)})
+        budget_ow({'id': 123}, {'id': (str, True)})
 
     # Check if subclasses pass the check
-    _budget_ow({
-        'ordered_dict': OrderedDict(),
-    }, {
-        'ordered_dict': (dict, False),
-    })
+    budget_ow(
+        {
+            'ordered_dict': OrderedDict(),
+        },
+        {
+            'ordered_dict': (dict, False),
+        },
+    )
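The last file also shows the other theme of this patch: helpers such as `_budget_ow`, `_is_url`, and `_StorageTypes` lose their leading underscore, so the tests consume them as public names instead of reaching into private API. For readers unfamiliar with the helper exercised in the final hunk, here is a rough reimplementation of the validation behavior the test asserts; the real `budget_ow` lives in `apify._utils` and may differ in details:

    def budget_ow(value: dict, spec: dict) -> None:
        # 'spec' maps field name -> (expected type, is required);
        # isinstance() means subclasses such as OrderedDict pass a 'dict' check.
        for key, (expected_type, required) in spec.items():
            if key not in value:
                if required:
                    raise ValueError(f'"{key}" is required')
                continue
            if not isinstance(value[key], expected_type):
                raise ValueError(f'"{key}" must be of type "{expected_type.__name__}"')


    budget_ow({'id': 'abc'}, {'id': (str, True)})  # passes silently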