Skip to content

Commit

Permalink
Fix query param for session
Browse files Browse the repository at this point in the history
  • Loading branch information
TheSuperiorStanislav committed Mar 21, 2024
1 parent 1589e32 commit b6f66f2
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions saritasa_sqlalchemy_tools/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def get_async_engine(
host: str,
port: int,
database: str,
query: dict[str, tuple[str, ...] | str],
echo: bool = False,
on_connect: collections.abc.Sequence[SessionOnConnect] = (),
**query,
) -> sqlalchemy.ext.asyncio.AsyncEngine:
"""Set up engine for working with database."""
db_engine = sqlalchemy.ext.asyncio.create_async_engine(
Expand Down Expand Up @@ -56,6 +56,7 @@ def get_async_session_factory(
host: str,
port: int,
database: str,
query: dict[str, tuple[str, ...] | str],
echo: bool = False,
on_connect: collections.abc.Sequence[SessionOnConnect] = (),
autocommit: bool = False,
Expand All @@ -66,7 +67,6 @@ def get_async_session_factory(
# the Session.autoflush parameter.
autoflush: bool = False,
expire_on_commit: bool = False,
**query,
) -> sqlalchemy.ext.asyncio.async_sessionmaker[
sqlalchemy.ext.asyncio.AsyncSession
]:
Expand All @@ -81,7 +81,7 @@ def get_async_session_factory(
database=database,
on_connect=on_connect,
echo=echo,
**query,
query=query,
),
autocommit=autocommit,
autoflush=autoflush,
Expand All @@ -96,12 +96,12 @@ async def get_async_db_session(
host: str,
port: int,
database: str,
query: dict[str, tuple[str, ...] | str],
echo: bool = False,
on_connect: collections.abc.Sequence[SessionOnConnect] = (),
autocommit: bool = False,
autoflush: bool = False,
expire_on_commit: bool = False,
**query,
) -> collections.abc.AsyncIterator[Session]:
"""Set up and get db session."""
async with get_async_session_factory(
Expand All @@ -116,7 +116,7 @@ async def get_async_db_session(
autocommit=autocommit,
autoflush=autoflush,
expire_on_commit=expire_on_commit,
**query,
query=query,
)() as session:
try:
yield session
Expand All @@ -135,12 +135,12 @@ async def get_async_db_session_context(
host: str,
port: int,
database: str,
query: dict[str, tuple[str, ...] | str],
echo: bool = False,
on_connect: collections.abc.Sequence[SessionOnConnect] = (),
autocommit: bool = False,
autoflush: bool = False,
expire_on_commit: bool = False,
**query,
) -> collections.abc.AsyncIterator[Session]:
"""Init db session."""
db_iterator = get_async_db_session(
Expand All @@ -155,7 +155,7 @@ async def get_async_db_session_context(
autocommit=autocommit,
autoflush=autoflush,
expire_on_commit=expire_on_commit,
**query,
query=query,
)
try:
yield await anext(db_iterator) # type: ignore
Expand Down

0 comments on commit b6f66f2

Please sign in to comment.