Skip to content

Commit

Permalink
fix: sync — drop the global anyio lock from async transactions and add session/connection lifecycle logging
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed Nov 9, 2024
1 parent eec91fb commit 472ad21
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 57 deletions.
1 change: 0 additions & 1 deletion src/tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ async def worker() -> None:
assert state["num"] == 2


@pytest.mark.only
async def test_async_rlock(cache):
state = {"num": 0}
rlock = typed_diskcache.AsyncRLock(cache, "demo")
Expand Down
15 changes: 8 additions & 7 deletions src/typed_diskcache/database/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
)
from typing import TYPE_CHECKING, Any, Callable, Protocol, overload, runtime_checkable

import anyio
import sqlalchemy as sa
from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
from sqlalchemy.engine import Connection, Engine, create_engine
Expand Down Expand Up @@ -53,7 +52,6 @@

_TIMEOUT = 10
_TIMEOUT_MS = _TIMEOUT * 1000
_LOCK = anyio.Lock()

logger = get_logger()

Expand Down Expand Up @@ -243,6 +241,7 @@ def ensure_sqlite_async_engine(
def sync_transact(conn: SyncConnT) -> Generator[SyncConnT, None, None]:
is_begin = conn.info.get(CONNECTION_BEGIN_INFO_KEY, False)
if is_begin is False:
logger.debug("enter transaction, session: `%d`", id(conn))
conn.execute(sa.text("BEGIN IMMEDIATE;"))
conn.info[CONNECTION_BEGIN_INFO_KEY] = True

Expand All @@ -252,24 +251,26 @@ def sync_transact(conn: SyncConnT) -> Generator[SyncConnT, None, None]:
conn.rollback()
raise
finally:
logger.debug("exit transaction, session: `%d`", id(conn))
with suppress(ResourceClosedError):
conn.info[CONNECTION_BEGIN_INFO_KEY] = False


@asynccontextmanager
async def async_transact(conn: AsyncConnT) -> AsyncGenerator[AsyncConnT, None]:
    """Yield *conn* inside an explicit SQLite ``BEGIN IMMEDIATE`` transaction.

    Re-entrant: when the connection is already flagged as inside a
    transaction (via ``CONNECTION_BEGIN_INFO_KEY`` in ``conn.info``) the
    connection is yielded as-is and no nested ``BEGIN`` is issued.

    On any exception the transaction is rolled back and the exception is
    re-raised. The "begun" flag is always cleared on exit, except when the
    underlying resource is already closed (``ResourceClosedError`` is
    suppressed).
    """
    is_begin = conn.info.get(CONNECTION_BEGIN_INFO_KEY, False)
    if is_begin is False:
        logger.debug("enter transaction, session: `%d`", id(conn))
        # BEGIN IMMEDIATE acquires SQLite's write lock up front, so lock
        # contention surfaces here instead of mid-transaction.
        await conn.execute(sa.text("BEGIN IMMEDIATE;"))
        conn.info[CONNECTION_BEGIN_INFO_KEY] = True

    try:
        yield conn
    except Exception:
        await conn.rollback()
        raise
    finally:
        logger.debug("exit transaction, session: `%d`", id(conn))
        # the connection may already be closed by the time cleanup runs
        with suppress(ResourceClosedError):
            conn.info[CONNECTION_BEGIN_INFO_KEY] = False

Expand Down
52 changes: 46 additions & 6 deletions src/typed_diskcache/database/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typed_diskcache.core.types import EvictionPolicy
from typed_diskcache.database import connect as db_connect
from typed_diskcache.database.model import Cache
from typed_diskcache.log import get_logger

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator, Mapping
Expand All @@ -28,6 +29,8 @@

__all__ = ["Connection"]

logger = get_logger()


class Connection:
"""Database connection."""
Expand Down Expand Up @@ -112,40 +115,72 @@ def _async_engine(self) -> AsyncEngine:
return db_connect.set_listeners(engine, self._settings.sqlite_settings)

@contextmanager
def _connect(self, *, stacklevel: int = 1) -> Generator[SAConnection, None, None]:
    """Open a synchronous engine connection, logging its lifecycle.

    Parameters
    ----------
    stacklevel
        Forwarded to ``logger.debug`` so the emitted records are
        attributed to the caller rather than this helper.
    """
    with self._sync_engine.connect() as connection:
        logger.debug(
            "Creating connection: `%d`", id(connection), stacklevel=stacklevel
        )
        yield connection
        logger.debug(
            "Closing connection: `%d`", id(connection), stacklevel=stacklevel
        )

@contextmanager
def session(self, *, stacklevel: int = 1) -> Generator[Session, None, None]:
    """Connect to the database.

    Reuses the session stored in the context variable when one is active;
    otherwise opens a fresh connection and wraps it in a ``Session`` with
    ``autoflush=False``. ``stacklevel`` is forwarded to ``logger.debug``
    so log records point at the caller.
    """
    session = self._context.get()
    if session is not None:
        logger.debug("Reusing session: `%d`", id(session), stacklevel=stacklevel)
        yield session
        return

    # NOTE(review): +2 presumably skips this frame and the context-manager
    # machinery so _connect's logs attribute to our caller — confirm offset.
    with self._connect(stacklevel=stacklevel + 2) as connection:
        with Session(connection, autoflush=False) as session:
            logger.debug(
                "Creating session: `%d`", id(session), stacklevel=stacklevel
            )
            yield session
            logger.debug(
                "Closing session: `%d`", id(session), stacklevel=stacklevel
            )

@asynccontextmanager
async def _aconnect(
    self, *, stacklevel: int = 1
) -> AsyncGenerator[AsyncConnection, None]:
    """Open an async engine connection, logging its lifecycle.

    Parameters
    ----------
    stacklevel
        Forwarded to ``logger.debug`` so the emitted records are
        attributed to the caller rather than this helper.
    """
    async with self._async_engine.connect() as connection:
        logger.debug(
            "Creating async connection: `%d`", id(connection), stacklevel=stacklevel
        )
        yield connection
        logger.debug(
            "Closing async connection: `%d`", id(connection), stacklevel=stacklevel
        )

@asynccontextmanager
async def asession(
    self, *, stacklevel: int = 1
) -> AsyncGenerator[AsyncSession, None]:
    """Connect to the database.

    Reuses the async session stored in the context variable when one is
    active (yielding after an explicit ``anyio`` checkpoint so the fast
    path still cooperates with the scheduler); otherwise opens a fresh
    async connection and wraps it in an ``AsyncSession`` with
    ``autoflush=False``.
    """
    session = self._acontext.get()
    if session is not None:
        logger.debug(
            "Reusing async session: `%d`", id(session), stacklevel=stacklevel
        )
        # keep the reuse fast-path cooperative: yield control once
        await anyio.lowlevel.checkpoint()
        yield session
        return

    # NOTE(review): +2 presumably skips this frame and the context-manager
    # machinery so _aconnect's logs attribute to our caller — confirm offset.
    async with self._aconnect(stacklevel=stacklevel + 2) as connection:
        async with AsyncSession(connection, autoflush=False) as session:
            logger.debug(
                "Creating async session: `%d`", id(session), stacklevel=stacklevel
            )
            yield session
            logger.debug(
                "Closing async session: `%d`", id(session), stacklevel=stacklevel
            )

def close(self) -> None:
"""Close the connection."""
Expand Down Expand Up @@ -195,6 +230,11 @@ def enter_session(
context_var = (
self._acontext if isinstance(session, AsyncSession) else self._context
)
logger.debug(
"Entering session context: `%s`, session: `%d`",
context_var.name,
id(session),
)
with enter_session(session, context_var) as context: # pyright: ignore[reportArgumentType]
yield context

Expand Down
20 changes: 10 additions & 10 deletions src/typed_diskcache/implement/cache/default/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
@context("Cache.length")
@override
def __len__(self) -> int:
with self.conn.session() as session:
with self.conn.session(stacklevel=4) as session:
return session.scalars(
sa.select(Metadata.value).where(Metadata.key == MetadataKey.COUNT)
).one()
Expand All @@ -128,7 +128,7 @@ def __getitem__(self, key: Any) -> Container[Any]:
@override
def __contains__(self, key: Any) -> bool:
db_key, raw = self.disk.put(key)
with self.conn.session() as session:
with self.conn.session(stacklevel=4) as session:
row = session.scalars(
sa.select(CacheTable.id).where(
CacheTable.key == db_key,
Expand Down Expand Up @@ -241,7 +241,7 @@ def get(
and self.settings.eviction_policy == EvictionPolicy.NONE
):
logger.debug("Cache statistics disabled or eviction policy is NONE")
with self.conn.session() as session:
with self.conn.session(stacklevel=4) as session:
row = session.scalars(
select_stmt, {"expire_time": time.time()}
).one_or_none()
Expand Down Expand Up @@ -328,7 +328,7 @@ async def aget(
and self.settings.eviction_policy == EvictionPolicy.NONE
):
logger.debug("Cache statistics disabled or eviction policy is NONE")
async with self.conn.asession() as session:
async with self.conn.asession(stacklevel=4) as session:
row_fetch = await session.scalars(
select_stmt, {"expire_time": time.time()}
)
Expand Down Expand Up @@ -685,7 +685,7 @@ async def _async_cull(
@context
@override
def volume(self) -> int:
with self.conn.session() as session:
with self.conn.session(stacklevel=4) as session:
page_count: int = session.execute(
sa.text("PRAGMA page_count;")
).scalar_one()
Expand All @@ -698,7 +698,7 @@ def volume(self) -> int:
@context
@override
async def avolume(self) -> int:
async with self.conn.asession() as session:
async with self.conn.asession(stacklevel=4) as session:
page_count_fetch = await session.execute(sa.text("PRAGMA page_count;"))
page_count: int = page_count_fetch.scalar_one()
size_fetch = await session.scalars(
Expand Down Expand Up @@ -1068,7 +1068,7 @@ def filter(
lower_bound = 0
tags_count = len(tags)
while True:
with self.conn.session() as session:
with self.conn.session(stacklevel=4) as session:
rows = session.execute(
stmt,
{
Expand Down Expand Up @@ -1105,7 +1105,7 @@ async def afilter(
lower_bound = 0
tags_count = len(tags)
while True:
async with self.conn.asession() as session:
async with self.conn.asession(stacklevel=4) as session:
rows_fetch = await session.execute(
stmt,
{
Expand Down Expand Up @@ -1941,7 +1941,7 @@ async def acheck(
@override
def iterkeys(self, *, reverse: bool = False) -> Generator[Any, None, None]:
select_stmt, iter_stmt = default_utils.prepare_iterkeys_stmt(reverse=reverse)
with self.conn.session() as session:
with self.conn.session(stacklevel=4) as session:
row = session.execute(select_stmt).one_or_none()

if not row:
Expand All @@ -1965,7 +1965,7 @@ def iterkeys(self, *, reverse: bool = False) -> Generator[Any, None, None]:
@override
async def aiterkeys(self, *, reverse: bool = False) -> AsyncGenerator[Any, None]:
select_stmt, iter_stmt = default_utils.prepare_iterkeys_stmt(reverse=reverse)
async with self.conn.asession() as session:
async with self.conn.asession(stacklevel=4) as session:
row_fetch = await session.execute(select_stmt)
row = row_fetch.one_or_none()

Expand Down
41 changes: 28 additions & 13 deletions src/typed_diskcache/implement/cache/default/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,17 @@ def prepare_cull_stmt(
return filenames_select_stmt, filenames_delete_stmt, select_stmt


def transact_process(
def transact_process( # noqa: PLR0913
stack: ExitStack,
conn: Connection,
disk: DiskProtocol,
*,
retry: bool = False,
filename: str | PathLike[str] | None = None,
stacklevel: int = 3,
) -> Session | None:
try:
session = stack.enter_context(conn.session())
session = stack.enter_context(conn.session(stacklevel=stacklevel))
session = stack.enter_context(database_transact(session))
except OperationalError as exc:
stack.close()
Expand All @@ -174,16 +175,17 @@ def transact_process(
return session


async def async_transact_process(
async def async_transact_process( # noqa: PLR0913
stack: AsyncExitStack,
conn: Connection,
disk: DiskProtocol,
*,
retry: bool = False,
filename: str | PathLike[str] | None = None,
stacklevel: int = 3,
) -> AsyncSession | None:
try:
session = await stack.enter_async_context(conn.asession())
session = await stack.enter_async_context(conn.asession(stacklevel=stacklevel))
session = await stack.enter_async_context(database_transact(session))
except OperationalError as exc:
await stack.aclose()
Expand Down Expand Up @@ -218,7 +220,7 @@ def iter_disk(
)

while True:
with conn.session() as session:
with conn.session(stacklevel=4) as session:
rows = session.execute(
stmt,
{"left_bound": rowid, "right_bound": bound}
Expand Down Expand Up @@ -254,7 +256,7 @@ async def aiter_disk(
)

while True:
async with conn.asession() as session:
async with conn.asession(stacklevel=4) as session:
rows_fetch = await session.execute(
stmt, {"left_bound": rowid, "right_bound": bound}
)
Expand Down Expand Up @@ -422,7 +424,12 @@ def transact(
while session is None:
stack.close()
session = transact_process(
stack, conn, disk, retry=retry, filename=filename
stack,
conn,
disk,
retry=retry,
filename=filename,
stacklevel=stacklevel + 4,
)

logger.debug("Enter transaction `%s`", filename, stacklevel=stacklevel)
Expand Down Expand Up @@ -461,12 +468,20 @@ async def async_transact(
while session is None:
await stack.aclose()
session = await async_transact_process(
stack, conn, disk, retry=retry, filename=filename
stack,
conn,
disk,
retry=retry,
filename=filename,
stacklevel=stacklevel + 4,
)

logger.debug("Enter async transaction `%s`", filename, stacklevel=stacklevel)
stack.callback(
logger.debug, "Exit async transaction `%s`", filename, stacklevel=stacklevel
logger.debug,
"Exit async transaction `%s`",
filename,
stacklevel=stacklevel + 2,
)
try:
stack.enter_context(receive)
Expand Down Expand Up @@ -599,12 +614,12 @@ def prepare_filter_stmt(


def find_max_id(conn: Connection) -> int | None:
    """Return the largest ``CacheTable.id``, or ``None`` for an empty table.

    ``stacklevel=4`` is forwarded so the session-lifecycle debug logs are
    attributed to this function's caller.
    """
    with conn.session(stacklevel=4) as session:
        # SQL MAX over an empty table is NULL, hence the Optional return
        return session.scalar(sa.select(sa.func.max(CacheTable.id)))


async def async_find_max_id(conn: Connection) -> int | None:
    """Async variant of ``find_max_id``.

    Return the largest ``CacheTable.id``, or ``None`` for an empty table.
    ``stacklevel=4`` is forwarded so the session-lifecycle debug logs are
    attributed to this function's caller.
    """
    async with conn.asession(stacklevel=4) as session:
        # SQL MAX over an empty table is NULL, hence the Optional return
        return await session.scalar(sa.select(sa.func.max(CacheTable.id)))


Expand Down Expand Up @@ -976,7 +991,7 @@ def prepare_iterkeys_stmt(


async def acheck_integrity(*, conn: Connection, fix: bool, stacklevel: int = 2) -> None:
async with conn.asession() as session:
async with conn.asession(stacklevel=4) as session:
integrity_fetch = await session.execute(sa.text("PRAGMA integrity_check;"))
integrity = integrity_fetch.scalars().all()

Expand Down Expand Up @@ -1168,7 +1183,7 @@ async def acheck_metadata_size(


def check_integrity(*, conn: Connection, fix: bool, stacklevel: int = 2) -> None:
with conn.session() as session:
with conn.session(stacklevel=4) as session:
integrity = session.execute(sa.text("PRAGMA integrity_check;")).scalars().all()

if len(integrity) != 1 or integrity[0] != "ok":
Expand Down
Loading

0 comments on commit 472ad21

Please sign in to comment.