Skip to content

Commit

Permalink
Merge pull request #81 from TheJacksonLaboratory/G3-75-fix-user-depen…
Browse files Browse the repository at this point in the history
…dency-function

G3 75 fix user dependency function
  • Loading branch information
bergsalex authored Jul 15, 2024
2 parents 4509faf + 07f7e69 commit be0f083
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "geneweaver-api"
version = "0.7.0a9"
version = "0.7.0a10"
description = "The Geneweaver API"
authors = [
"Alexander Berger <alexander.berger@jax.org>",
Expand Down
4 changes: 2 additions & 2 deletions src/geneweaver/api/controller/genesets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

@router.get("")
def get_visible_genesets(
user: UserInternal = Security(deps.optional_full_user),
cursor: Optional[deps.Cursor] = Depends(deps.cursor),
cursor: deps.CursorDep,
user: deps.OptionalFullUserDep,
gs_id: Annotated[
Optional[int],
Query(
Expand Down
30 changes: 22 additions & 8 deletions src/geneweaver/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import logging
from contextlib import asynccontextmanager
from tempfile import TemporaryDirectory
from typing import Generator, Optional
from typing import Annotated, Optional

import psycopg
from fastapi import Depends, FastAPI, Request
from geneweaver.api.core.config import settings
from geneweaver.api.core.exceptions import AuthenticationMismatch
from geneweaver.api.core.security import Auth0, UserInternal
from geneweaver.db import user as db_user
from psycopg.rows import dict_row
from psycopg.rows import DictRow, dict_row
from psycopg_pool import ConnectionPool

auth = Auth0(
Expand All @@ -34,7 +34,11 @@ async def lifespan(app: FastAPI) -> None:
:param app: The FastAPI application (dependency injection).
"""
logger.info("Opening DB Connection Pool.")
app.pool = ConnectionPool(settings.DB.URI)
app.pool = ConnectionPool(
settings.DB.URI,
connection_class=psycopg.Connection[DictRow],
kwargs={"row_factory": dict_row},
)
app.pool.open()
app.pool.wait()
with app.pool.connection() as conn:
Expand All @@ -50,13 +54,17 @@ async def lifespan(app: FastAPI) -> None:
app.pool.close()


def cursor(request: Request) -> Generator:
async def cursor(request: Request) -> Cursor:
"""Get a cursor from the connection pool."""
logger.debug("Getting cursor from pool.")
with request.app.pool.connection() as conn:
with conn.cursor(row_factory=dict_row) as cur:
with conn.cursor() as cur:
yield cur


CursorDep = Annotated[Cursor, Depends(cursor)]


def _get_user_details(cursor: Cursor, user: UserInternal) -> UserInternal:
"""Get the user details.
Expand Down Expand Up @@ -100,8 +108,11 @@ async def full_user(
yield _get_user_details(cursor, user)


FullUserDep = Annotated[UserInternal, Depends(full_user)]


async def optional_full_user(
cursor: Cursor = Depends(cursor),
cursor: CursorDep,
user: Optional[UserInternal] = Depends(auth.get_user),
) -> Optional[UserInternal]:
"""Get the full user object, if request is logged in.
Expand All @@ -115,8 +126,11 @@ async def optional_full_user(
@param user: GW user.
"""
if user is not None:
yield _get_user_details(cursor, user)
yield None
return _get_user_details(cursor, user)
return None


OptionalFullUserDep = Annotated[Optional[UserInternal], Depends(optional_full_user)]


async def get_temp_dir() -> TemporaryDirectory:
Expand Down

0 comments on commit be0f083

Please sign in to comment.