Skip to content

Commit

Permalink
Merge pull request #14 from simonsobs/db
Browse files Browse the repository at this point in the history
Switch to asynchronous DB access
  • Loading branch information
TaiSakuma authored Dec 14, 2023
2 parents bab4a80 + 64f0067 commit 960a8bd
Show file tree
Hide file tree
Showing 17 changed files with 141 additions and 181 deletions.
2 changes: 1 addition & 1 deletion src/nextline_rdb/alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# output_encoding = utf-8

# sqlalchemy.url = driver://user:pass@localhost/dbname
sqlalchemy.url = sqlite:///migration.sqlite3
sqlalchemy.url = sqlite+aiosqlite:///migration.sqlite3


[post_write_hooks]
Expand Down
38 changes: 25 additions & 13 deletions src/nextline_rdb/alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import logging
import logging.config

from alembic import context
from sqlalchemy import create_engine
from sqlalchemy import Connection
from sqlalchemy.ext.asyncio import create_async_engine

from nextline_rdb import models

Expand Down Expand Up @@ -52,25 +54,35 @@ def run_migrations_offline() -> None:
context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata,
render_as_batch=True,
)

with context.begin_transaction():
context.run_migrations()


async def run_async_migrations() -> None:
assert url is not None
connectable = create_async_engine(url)

async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)

await connectable.dispose()


def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
assert url is not None
connectable = create_engine(url)

with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
render_as_batch=True,
)

with context.begin_transaction():
context.run_migrations()
asyncio.run(run_async_migrations())


if context.is_offline_mode():
Expand Down
2 changes: 1 addition & 1 deletion src/nextline_rdb/default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# https://github.com/rochacbruno/learndynaconf/tree/main/configs

[db]
url = "sqlite:///:memory:?check_same_thread=false"
url = "sqlite+aiosqlite://"

[logging.loggers.nextline_rdb]
handlers = ["default"]
Expand Down
14 changes: 7 additions & 7 deletions src/nextline_rdb/pagination.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import NamedTuple, Optional, Type, TypeVar, cast

from sqlalchemy import func, select
from sqlalchemy.orm import aliased
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased, DeclarativeBase
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.selectable import Select

from . import models as db_models

# import sqlparse

Expand All @@ -24,9 +24,9 @@ class SortField(NamedTuple):
_Id = TypeVar("_Id")


def load_models(
session,
Model: Type[db_models.Model],
async def load_models(
session: AsyncSession,
Model: Type[DeclarativeBase],
id_field: str,
*,
sort: Optional[Sort] = None,
Expand All @@ -45,12 +45,12 @@ def load_models(
last=last,
)

models = session.scalars(stmt)
models = await session.scalars(stmt)
return models


def compose_statement(
Model: Type[db_models.Model],
Model: Type[DeclarativeBase],
id_field: str,
*,
sort: Optional[Sort] = None,
Expand Down
24 changes: 12 additions & 12 deletions src/nextline_rdb/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from nextline import Nextline
from nextlinegraphql.hook import spec
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession

from . import models
from .db import DB
from .db import AsyncDB
from .schema import Mutation, Query, Subscription
from .write import write_db
from .write import async_write_db

HERE = Path(__file__).resolve().parent
DEFAULT_CONFIG_PATH = HERE / 'default.toml'
Expand Down Expand Up @@ -50,32 +50,32 @@ def schema(self) -> tuple[type, type | None, type | None]:
@asynccontextmanager
async def lifespan(self, context: Mapping) -> AsyncIterator[None]:
nextline = context['nextline']
self._db = DB(self._url)
with self._db:
self._db = AsyncDB(self._url)
async with self._db:
await self._initialize_nextline(nextline)
async with write_db(nextline, self._db):
async with async_write_db(nextline, self._db):
yield

async def _initialize_nextline(self, nextline: Nextline) -> None:
run_no, script = self._last_run_no_and_script()
run_no, script = await self._last_run_no_and_script()
if run_no is not None:
run_no += 1
if run_no >= nextline._init_options.run_no_start_from:
nextline._init_options.run_no_start_from = run_no
if script is not None:
nextline._init_options.statement = script

def _last_run_no_and_script(self) -> tuple[Optional[int], Optional[str]]:
with self._db.session() as session:
last_run = self._last_run(session)
async def _last_run_no_and_script(self) -> tuple[Optional[int], Optional[str]]:
async with self._db.session() as session:
last_run = await self._last_run(session)
if last_run is None:
return None, None
else:
return last_run.run_no, last_run.script

def _last_run(self, session: Session) -> Optional[models.Run]:
async def _last_run(self, session: AsyncSession) -> Optional[models.Run]:
stmt = select(models.Run, func.max(models.Run.run_no))
if model := session.execute(stmt).scalar_one_or_none():
if model := (await session.execute(stmt)).scalar_one_or_none():
return model
else:
logger = getLogger(__name__)
Expand Down
13 changes: 7 additions & 6 deletions src/nextline_rdb/schema/pagination/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
Relay doc: https://relay.dev/graphql/connections.htm
"""

from typing import Callable, Generic, Optional, TypeVar
from collections.abc import Callable, Coroutine
from typing import Any, Generic, Optional, TypeVar

import strawberry
from strawberry.types import Info
Expand Down Expand Up @@ -32,9 +33,9 @@ class Connection(Generic[_T]):
edges: list[Edge[_T]]


def query_connection(
async def query_connection(
info: Info,
query_edges: Callable[..., list[Edge[_T]]],
query_edges: Callable[..., Coroutine[Any, Any, list[Edge[_T]]]],
before: Optional[str] = None,
after: Optional[str] = None,
first: Optional[int] = None,
Expand All @@ -49,19 +50,19 @@ def query_connection(
if forward:
if first is not None:
first += 1 # add one for has_next_page
edges = query_edges(info=info, after=after, first=first)
edges = await query_edges(info=info, after=after, first=first)
has_previous_page = not not after
if has_next_page := len(edges) == first:
edges = edges[:-1]
elif backward:
if last is not None:
last += 1 # add one for has_previous_page
edges = query_edges(info=info, before=before, last=last)
edges = await query_edges(info=info, before=before, last=last)
if has_previous_page := len(edges) == last:
edges = edges[1:]
has_next_page = not not before
else:
edges = query_edges(info)
edges = await query_edges(info)
has_previous_page = False
has_next_page = False

Expand Down
8 changes: 4 additions & 4 deletions src/nextline_rdb/schema/pagination/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def decode_id(cursor: str) -> int:
_T = TypeVar("_T")


def load_connection(
async def load_connection(
info: Info,
Model: Type[db_models.Model],
id_field: str,
Expand All @@ -43,7 +43,7 @@ def load_connection(
create_node_from_model=create_node_from_model,
)

return query_connection(
return await query_connection(
info,
query_edges,
before,
Expand All @@ -53,7 +53,7 @@ def load_connection(
)


def load_edges(
async def load_edges(
info: Info,
Model: Type[db_models.Model],
id_field: str,
Expand All @@ -66,7 +66,7 @@ def load_edges(
) -> list[Edge[_T]]:
session = info.context["session"]

models = load_models(
models = await load_models(
session,
Model,
id_field,
Expand Down
4 changes: 2 additions & 2 deletions src/nextline_rdb/schema/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class History:
@strawberry.type
class Query:
@strawberry.field
def history(self, info: Info) -> History:
async def history(self, info: Info) -> History:
db = info.context["db"]
with db.session() as session:
async with db.session() as session:
info.context["session"] = session
return History()
20 changes: 10 additions & 10 deletions src/nextline_rdb/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .pagination import Connection, load_connection


def query_connection_run(
async def query_connection_run(
info: Info,
before: Optional[str] = None,
after: Optional[str] = None,
Expand All @@ -21,10 +21,10 @@ def query_connection_run(
) -> Connection[RunHistory]:
Model = db_models.Run
NodeType = RunHistory
return query_connection(info, before, after, first, last, Model, NodeType)
return await query_connection(info, before, after, first, last, Model, NodeType)


def query_connection_trace(
async def query_connection_trace(
info: Info,
before: Optional[str] = None,
after: Optional[str] = None,
Expand All @@ -33,10 +33,10 @@ def query_connection_trace(
) -> Connection[TraceHistory]:
Model = db_models.Trace
NodeType = TraceHistory
return query_connection(info, before, after, first, last, Model, NodeType)
return await query_connection(info, before, after, first, last, Model, NodeType)


def query_connection_prompt(
async def query_connection_prompt(
info: Info,
before: Optional[str] = None,
after: Optional[str] = None,
Expand All @@ -45,10 +45,10 @@ def query_connection_prompt(
) -> Connection[PromptHistory]:
Model = db_models.Prompt
NodeType = PromptHistory
return query_connection(info, before, after, first, last, Model, NodeType)
return await query_connection(info, before, after, first, last, Model, NodeType)


def query_connection_stdout(
async def query_connection_stdout(
info: Info,
before: Optional[str] = None,
after: Optional[str] = None,
Expand All @@ -57,13 +57,13 @@ def query_connection_stdout(
) -> Connection[StdoutHistory]:
Model = db_models.Stdout
NodeType = StdoutHistory
return query_connection(info, before, after, first, last, Model, NodeType)
return await query_connection(info, before, after, first, last, Model, NodeType)


_T = TypeVar("_T")


def query_connection(
async def query_connection(
info: Info,
before: Optional[str],
after: Optional[str],
Expand All @@ -77,7 +77,7 @@ def query_connection(

create_node_from_model = NodeType.from_model # type: ignore

return load_connection(
return await load_connection(
info,
Model,
id_field,
Expand Down
4 changes: 2 additions & 2 deletions tests/alembic/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def alembic_config() -> Config:
@pytest.fixture
def alembic_config_in_memory(alembic_config: Config) -> Config:
config = alembic_config
url = 'sqlite://'
url = 'sqlite+aiosqlite://'
config.set_main_option('sqlalchemy.url', url)
return config

Expand All @@ -27,6 +27,6 @@ def alembic_config_temp_sqlite(
) -> Config:
config = alembic_config
dir = tmp_path_factory.mktemp('db')
url = f'sqlite:///{dir}/db.sqlite'
url = f'sqlite+aiosqlite:///{dir}/db.sqlite'
config.set_main_option('sqlalchemy.url', url)
return config
Loading

0 comments on commit 960a8bd

Please sign in to comment.