From 472ad2138734c88dc2a152e6e53b202155b42d1e Mon Sep 17 00:00:00 2001 From: phi Date: Sat, 9 Nov 2024 22:07:54 +0900 Subject: [PATCH] fix: sync --- src/tests/test_sync.py | 1 - src/typed_diskcache/database/connect.py | 15 +++--- src/typed_diskcache/database/connection.py | 52 ++++++++++++++++--- .../implement/cache/default/main.py | 20 +++---- .../implement/cache/default/utils.py | 41 ++++++++++----- src/typed_diskcache/implement/sync/lock.py | 43 ++++++++++----- .../implement/sync/semaphore.py | 15 +++--- 7 files changed, 130 insertions(+), 57 deletions(-) diff --git a/src/tests/test_sync.py b/src/tests/test_sync.py index 1dcb61b..4361a6b 100644 --- a/src/tests/test_sync.py +++ b/src/tests/test_sync.py @@ -94,7 +94,6 @@ async def worker() -> None: assert state["num"] == 2 -@pytest.mark.only async def test_async_rlock(cache): state = {"num": 0} rlock = typed_diskcache.AsyncRLock(cache, "demo") diff --git a/src/typed_diskcache/database/connect.py b/src/typed_diskcache/database/connect.py index 30e4f4f..7e43b3e 100644 --- a/src/typed_diskcache/database/connect.py +++ b/src/typed_diskcache/database/connect.py @@ -9,7 +9,6 @@ ) from typing import TYPE_CHECKING, Any, Callable, Protocol, overload, runtime_checkable -import anyio import sqlalchemy as sa from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect from sqlalchemy.engine import Connection, Engine, create_engine @@ -53,7 +52,6 @@ _TIMEOUT = 10 _TIMEOUT_MS = _TIMEOUT * 1000 -_LOCK = anyio.Lock() logger = get_logger() @@ -243,6 +241,7 @@ def ensure_sqlite_async_engine( def sync_transact(conn: SyncConnT) -> Generator[SyncConnT, None, None]: is_begin = conn.info.get(CONNECTION_BEGIN_INFO_KEY, False) if is_begin is False: + logger.debug("enter transaction, session: `%d`", id(conn)) conn.execute(sa.text("BEGIN IMMEDIATE;")) conn.info[CONNECTION_BEGIN_INFO_KEY] = True @@ -252,17 +251,18 @@ def sync_transact(conn: SyncConnT) -> Generator[SyncConnT, None, None]: conn.rollback() raise finally: + logger.debug("exit transaction, session: `%d`", id(conn)) with suppress(ResourceClosedError): conn.info[CONNECTION_BEGIN_INFO_KEY] = False @asynccontextmanager async def async_transact(conn: AsyncConnT) -> AsyncGenerator[AsyncConnT, None]: - async with _LOCK: - is_begin = conn.info.get(CONNECTION_BEGIN_INFO_KEY, False) - if is_begin is False: - await conn.execute(sa.text("BEGIN IMMEDIATE;")) - conn.info[CONNECTION_BEGIN_INFO_KEY] = True + is_begin = conn.info.get(CONNECTION_BEGIN_INFO_KEY, False) + if is_begin is False: + logger.debug("enter transaction, session: `%d`", id(conn)) + await conn.execute(sa.text("BEGIN IMMEDIATE;")) + conn.info[CONNECTION_BEGIN_INFO_KEY] = True try: yield conn @@ -270,6 +270,7 @@ async def async_transact(conn: AsyncConnT) -> AsyncGenerator[AsyncConnT, None]: await conn.rollback() raise finally: + logger.debug("exit transaction, session: `%d`", id(conn)) with suppress(ResourceClosedError): conn.info[CONNECTION_BEGIN_INFO_KEY] = False diff --git a/src/typed_diskcache/database/connection.py b/src/typed_diskcache/database/connection.py index a6138f3..8a9ba59 100644 --- a/src/typed_diskcache/database/connection.py +++ b/src/typed_diskcache/database/connection.py @@ -16,6 +16,7 @@ from typed_diskcache.core.types import EvictionPolicy from typed_diskcache.database import connect as db_connect from typed_diskcache.database.model import Cache +from typed_diskcache.log import get_logger if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Mapping @@ -28,6 +29,8 @@ __all__ = ["Connection"] +logger = get_logger() + class Connection: """Database connection.""" @@ -112,40 +115,72 @@ def _async_engine(self) -> AsyncEngine: return db_connect.set_listeners(engine, self._settings.sqlite_settings) @contextmanager - def _connect(self) -> Generator[SAConnection, None, None]: + def _connect(self, *, stacklevel: int = 1) -> Generator[SAConnection, None, None]: with self._sync_engine.connect() as connection: + logger.debug( + "Creating connection: `%d`", id(connection), stacklevel=stacklevel + ) yield connection + logger.debug( + "Closing connection: `%d`", id(connection), stacklevel=stacklevel + ) @contextmanager - def session(self) -> Generator[Session, None, None]: + def session(self, *, stacklevel: int = 1) -> Generator[Session, None, None]: """Connect to the database.""" session = self._context.get() if session is not None: + logger.debug("Reusing session: `%d`", id(session), stacklevel=stacklevel) yield session return - with self._connect() as connection: + with self._connect(stacklevel=stacklevel + 2) as connection: with Session(connection, autoflush=False) as session: + logger.debug( + "Creating session: `%d`", id(session), stacklevel=stacklevel + ) yield session + logger.debug( + "Closing session: `%d`", id(session), stacklevel=stacklevel + ) @asynccontextmanager - async def _aconnect(self) -> AsyncGenerator[AsyncConnection, None]: + async def _aconnect( + self, *, stacklevel: int = 1 + ) -> AsyncGenerator[AsyncConnection, None]: """Connect to the database.""" async with self._async_engine.connect() as connection: + logger.debug( + "Creating async connection: `%d`", id(connection), stacklevel=stacklevel + ) yield connection + logger.debug( + "Closing async connection: `%d`", id(connection), stacklevel=stacklevel + ) @asynccontextmanager - async def asession(self) -> AsyncGenerator[AsyncSession, None]: + async def asession( + self, *, stacklevel: int = 1 + ) -> AsyncGenerator[AsyncSession, None]: """Connect to the database.""" session = self._acontext.get() if session is not None: + logger.debug( + "Reusing async session: `%d`", id(session), stacklevel=stacklevel + ) await anyio.lowlevel.checkpoint() yield session return - async with self._aconnect() as connection: + async with self._aconnect(stacklevel=stacklevel + 2) as connection: async with AsyncSession(connection, autoflush=False) as session: + logger.debug( + "Creating async session: `%d`", id(session), stacklevel=stacklevel + ) yield session + logger.debug( + "Closing async session: `%d`", id(session), stacklevel=stacklevel + ) def close(self) -> None: """Close the connection.""" @@ -195,6 +230,11 @@ def enter_session( context_var = ( self._acontext if isinstance(session, AsyncSession) else self._context ) + logger.debug( + "Entering session context: `%s`, session: `%d`", + context_var.name, + id(session), + ) with enter_session(session, context_var) as context: # pyright: ignore[reportArgumentType] yield context diff --git a/src/typed_diskcache/implement/cache/default/main.py b/src/typed_diskcache/implement/cache/default/main.py index d56508d..49b5c04 100644 --- a/src/typed_diskcache/implement/cache/default/main.py +++ b/src/typed_diskcache/implement/cache/default/main.py @@ -108,7 +108,7 @@ def __init__( @context("Cache.length") @override def __len__(self) -> int: - with self.conn.session() as session: + with self.conn.session(stacklevel=4) as session: return session.scalars( sa.select(Metadata.value).where(Metadata.key == MetadataKey.COUNT) ).one() @@ -128,7 +128,7 @@ def __getitem__(self, key: Any) -> Container[Any]: @override def __contains__(self, key: Any) -> bool: db_key, raw = self.disk.put(key) - with self.conn.session() as session: + with self.conn.session(stacklevel=4) as session: row = session.scalars( sa.select(CacheTable.id).where( CacheTable.key == db_key, @@ -241,7 +241,7 @@ def get( and self.settings.eviction_policy == EvictionPolicy.NONE ): logger.debug("Cache statistics disabled or eviction policy is NONE") - with self.conn.session() as session: + with self.conn.session(stacklevel=4) as session: row = session.scalars( select_stmt, {"expire_time": time.time()} ).one_or_none() @@ -328,7 +328,7 @@ async def aget( and self.settings.eviction_policy == EvictionPolicy.NONE ): logger.debug("Cache statistics disabled or eviction policy is NONE") - async with self.conn.asession() as session: + async with self.conn.asession(stacklevel=4) as session: row_fetch = await session.scalars( select_stmt, {"expire_time": time.time()} ) @@ -685,7 +685,7 @@ async def _async_cull( @context @override def volume(self) -> int: - with self.conn.session() as session: + with self.conn.session(stacklevel=4) as session: page_count: int = session.execute( sa.text("PRAGMA page_count;") ).scalar_one() @@ -698,7 +698,7 @@ def volume(self) -> int: @context @override async def avolume(self) -> int: - async with self.conn.asession() as session: + async with self.conn.asession(stacklevel=4) as session: page_count_fetch = await session.execute(sa.text("PRAGMA page_count;")) page_count: int = page_count_fetch.scalar_one() size_fetch = await session.scalars( @@ -1068,7 +1068,7 @@ def filter( lower_bound = 0 tags_count = len(tags) while True: - with self.conn.session() as session: + with self.conn.session(stacklevel=4) as session: rows = session.execute( stmt, { @@ -1105,7 +1105,7 @@ async def afilter( lower_bound = 0 tags_count = len(tags) while True: - async with self.conn.asession() as session: + async with self.conn.asession(stacklevel=4) as session: rows_fetch = await session.execute( stmt, { @@ -1941,7 +1941,7 @@ async def acheck( @override def iterkeys(self, *, reverse: bool = False) -> Generator[Any, None, None]: select_stmt, iter_stmt = default_utils.prepare_iterkeys_stmt(reverse=reverse) - with self.conn.session() as session: + with self.conn.session(stacklevel=4) as session: row = session.execute(select_stmt).one_or_none() if not row: @@ -1965,7 +1965,7 @@ def iterkeys(self, *, reverse: bool = False) -> Generator[Any, None, None]: @override async def aiterkeys(self, *, reverse: bool = False) -> AsyncGenerator[Any, None]: select_stmt, iter_stmt = default_utils.prepare_iterkeys_stmt(reverse=reverse) - async with self.conn.asession() as session: + async with self.conn.asession(stacklevel=4) as session: row_fetch = await session.execute(select_stmt) row = row_fetch.one_or_none() diff --git a/src/typed_diskcache/implement/cache/default/utils.py b/src/typed_diskcache/implement/cache/default/utils.py index 60b7ded..900e5fc 100644 --- a/src/typed_diskcache/implement/cache/default/utils.py +++ b/src/typed_diskcache/implement/cache/default/utils.py @@ -152,16 +152,17 @@ def prepare_cull_stmt( return filenames_select_stmt, filenames_delete_stmt, select_stmt -def transact_process( +def transact_process( # noqa: PLR0913 stack: ExitStack, conn: Connection, disk: DiskProtocol, *, retry: bool = False, filename: str | PathLike[str] | None = None, + stacklevel: int = 3, ) -> Session | None: try: - session = stack.enter_context(conn.session()) + session = stack.enter_context(conn.session(stacklevel=stacklevel)) session = stack.enter_context(database_transact(session)) except OperationalError as exc: stack.close() @@ -174,16 +175,17 @@ def transact_process( return session -async def async_transact_process( +async def async_transact_process( # noqa: PLR0913 stack: AsyncExitStack, conn: Connection, disk: DiskProtocol, *, retry: bool = False, filename: str | PathLike[str] | None = None, + stacklevel: int = 3, ) -> AsyncSession | None: try: - session = await stack.enter_async_context(conn.asession()) + session = await stack.enter_async_context(conn.asession(stacklevel=stacklevel)) session = await stack.enter_async_context(database_transact(session)) except OperationalError as exc: await stack.aclose() @@ -218,7 +220,7 @@ def iter_disk( ) while True: - with conn.session() as session: + with conn.session(stacklevel=4) as session: rows = session.execute( stmt, {"left_bound": rowid, "right_bound": bound} @@ -254,7 +256,7 @@ async def aiter_disk( ) while True: - async with conn.asession() as session: + async with conn.asession(stacklevel=4) as session: rows_fetch = await session.execute( stmt, {"left_bound": rowid, "right_bound": bound} ) @@ -422,7 +424,12 @@ def transact( while session is None: stack.close() session = transact_process( - stack, conn, disk, retry=retry, filename=filename + stack, + conn, + disk, + retry=retry, + filename=filename, + stacklevel=stacklevel + 4, ) logger.debug("Enter transaction `%s`", filename, stacklevel=stacklevel) @@ -461,12 +468,20 @@ async def async_transact( while session is None: await stack.aclose() session = await async_transact_process( - stack, conn, disk, retry=retry, filename=filename + stack, + conn, + disk, + retry=retry, + filename=filename, + stacklevel=stacklevel + 4, ) logger.debug("Enter async transaction `%s`", filename, stacklevel=stacklevel) stack.callback( - logger.debug, "Exit async transaction `%s`", filename, stacklevel=stacklevel + logger.debug, + "Exit async transaction `%s`", + filename, + stacklevel=stacklevel + 2, ) try: stack.enter_context(receive) @@ -599,12 +614,12 @@ def prepare_filter_stmt( def find_max_id(conn: Connection) -> int | None: - with conn.session() as session: + with conn.session(stacklevel=4) as session: return session.scalar(sa.select(sa.func.max(CacheTable.id))) async def async_find_max_id(conn: Connection) -> int | None: - async with conn.asession() as session: + async with conn.asession(stacklevel=4) as session: return await session.scalar(sa.select(sa.func.max(CacheTable.id))) @@ -976,7 +991,7 @@ def prepare_iterkeys_stmt( async def acheck_integrity(*, conn: Connection, fix: bool, stacklevel: int = 2) -> None: - async with conn.asession() as session: + async with conn.asession(stacklevel=4) as session: integrity_fetch = await session.execute(sa.text("PRAGMA integrity_check;")) integrity = integrity_fetch.scalars().all() @@ -1168,7 +1183,7 @@ async def acheck_metadata_size( def check_integrity(*, conn: Connection, fix: bool, stacklevel: int = 2) -> None: - with conn.session() as session: + with conn.session(stacklevel=4) as session: integrity = session.execute(sa.text("PRAGMA integrity_check;")).scalars().all() if len(integrity) != 1 or integrity[0] != "ok": diff --git a/src/typed_diskcache/implement/sync/lock.py b/src/typed_diskcache/implement/sync/lock.py index 6f67331..2bb30ec 100644 --- a/src/typed_diskcache/implement/sync/lock.py +++ b/src/typed_diskcache/implement/sync/lock.py @@ -165,16 +165,19 @@ def main() -> None: @override def acquire(self) -> None: pid = os.getpid() - tid = threading.get_ident() + tid = threading.get_native_id() pid_tid = f"{pid}-{tid}" start = time.monotonic() timeout = 0 with ExitStack() as stack: + session = stack.enter_context(self._cache.conn.session(stacklevel=4)) + sub_stack = stack.enter_context(ExitStack()) while timeout < self.timeout: - session = stack.enter_context(self._cache.conn.session()) - stack.enter_context(transact(session)) - context = stack.enter_context(self._cache.conn.enter_session(session)) + sub_stack.enter_context(transact(session)) + context = sub_stack.enter_context( + self._cache.conn.enter_session(session) + ) container = context.run( self._cache.get, self.key, default=("default", 0) ) @@ -194,7 +197,12 @@ def acquire(self) -> None: tags=self.tags, ) return - stack.close() + logger.debug( + "Invalid lock: expected: `%s`, value: `%s`", + pid_tid, + container_value, + ) + sub_stack.close() time.sleep(SPIN_LOCK_SLEEP) timeout = time.monotonic() - start @@ -204,11 +212,12 @@ def acquire(self) -> None: @override def release(self) -> None: pid = os.getpid() - tid = threading.get_ident() + tid = threading.get_native_id() pid_tid = f"{pid}-{tid}" with ExitStack() as stack: - session = stack.enter_context(self._cache.conn.session()) + logger.debug("releasing lock: %s", pid_tid) + session = stack.enter_context(self._cache.conn.session(stacklevel=4)) stack.enter_context(transact(session)) context = stack.enter_context(self._cache.conn.enter_session(session)) container = context.run(self._cache.get, self.key, default=("default", 0)) @@ -378,17 +387,17 @@ async def acquire(self) -> None: import anyio pid = os.getpid() - tid = threading.get_ident() + tid = threading.get_native_id() pid_tid = f"{pid}-{tid}" try: async with AsyncExitStack() as stack: stack.enter_context(anyio.fail_after(self.timeout)) + session = await stack.enter_async_context( + self._cache.conn.asession(stacklevel=4) + ) sub_stack = await stack.enter_async_context(AsyncExitStack()) while True: - session = await sub_stack.enter_async_context( - self._cache.conn.asession() - ) await sub_stack.enter_async_context(transact(session)) context = stack.enter_context( self._cache.conn.enter_session(session) @@ -412,6 +421,11 @@ async def acquire(self) -> None: tags=self.tags, ) return + logger.debug( + "Invalid lock: expected: `%s`, value: `%s`", + pid_tid, + container_value, + ) await sub_stack.aclose() await anyio.sleep(SPIN_LOCK_SLEEP) except TimeoutError as exc: @@ -422,11 +436,14 @@ async def acquire(self) -> None: async def release(self) -> None: """Release lock by decrementing count.""" pid = os.getpid() - tid = threading.get_ident() + tid = threading.get_native_id() pid_tid = f"{pid}-{tid}" async with AsyncExitStack() as stack: - session = await stack.enter_async_context(self._cache.conn.asession()) + logger.debug("releasing lock: %s", pid_tid) + session = await stack.enter_async_context( + self._cache.conn.asession(stacklevel=4) + ) await stack.enter_async_context(transact(session)) context = stack.enter_context(self._cache.conn.enter_session(session)) container = await self._cache.aget(self.key, default=("default", 0)) diff --git a/src/typed_diskcache/implement/sync/semaphore.py b/src/typed_diskcache/implement/sync/semaphore.py index cf72219..c142a7d 100644 --- a/src/typed_diskcache/implement/sync/semaphore.py +++ b/src/typed_diskcache/implement/sync/semaphore.py @@ -106,10 +106,13 @@ def acquire(self) -> None: start = time.monotonic() timeout = 0 with ExitStack() as stack: + session = stack.enter_context(self._cache.conn.session()) + sub_stack = stack.enter_context(ExitStack()) while timeout < self.timeout: - session = stack.enter_context(self._cache.conn.session()) - stack.enter_context(transact(session)) - context = stack.enter_context(self._cache.conn.enter_session(session)) + sub_stack.enter_context(transact(session)) + context = sub_stack.enter_context( + self._cache.conn.enter_session(session) + ) container = context.run(self._cache.get, self.key, default=self._value) container_value = validate_semaphore_value(container.value) if container_value > 0: @@ -121,7 +124,7 @@ def acquire(self) -> None: tags=self.tags, ) return - stack.close() + sub_stack.close() time.sleep(SPIN_LOCK_SLEEP) timeout = time.monotonic() - start @@ -251,11 +254,9 @@ async def acquire(self) -> None: try: async with AsyncExitStack() as stack: stack.enter_context(anyio.fail_after(self.timeout)) + session = await stack.enter_async_context(self._cache.conn.asession()) sub_stack = await stack.enter_async_context(AsyncExitStack()) while True: - session = await sub_stack.enter_async_context( - self._cache.conn.asession() - ) await sub_stack.enter_async_context(transact(session)) context = stack.enter_context( self._cache.conn.enter_session(session)