diff --git a/xas-standards-api/src/xas_standards_api/app.py b/xas-standards-api/src/xas_standards_api/app.py index d79c464..47f6a18 100644 --- a/xas-standards-api/src/xas_standards_api/app.py +++ b/xas-standards-api/src/xas_standards_api/app.py @@ -1,53 +1,13 @@ -import datetime import os -from typing import Annotated, List, Optional -import requests from fastapi import ( - Depends, FastAPI, - File, - Form, - HTTPException, - Query, - UploadFile, - status, ) -from fastapi.responses import HTMLResponse -from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer from fastapi.staticfiles import StaticFiles from fastapi_pagination import add_pagination -from fastapi_pagination.cursor import CursorPage -from fastapi_pagination.ext.sqlalchemy import paginate -from sqlmodel import Session, create_engine, select from starlette.responses import RedirectResponse -from .crud import ( - add_new_standard, - get_data, - get_file, - get_file_as_text, - get_metadata, - get_standard, - select_all, - select_or_create_person, - update_review, -) -from .schemas import ( - AdminXASStandardResponse, - Beamline, - BeamlineResponse, - Edge, - Element, - LicenceType, - MetadataResponse, - Person, - ReviewStatus, - XASStandard, - XASStandardAdminReviewInput, - XASStandardInput, - XASStandardResponse, -) +from .routers import admin, open, protected dev = False @@ -57,29 +17,14 @@ print("RUNNING IN DEV MODE") dev = True -get_bearer_token = HTTPBearer(auto_error=True) -url = os.environ.get("POSTGRESURL") build_dir = os.environ.get("FRONTEND_BUILD_DIR") -oidc_user_info_endpoint = os.environ.get("OIDC_USER_INFO_ENDPOINT") - - -if url: - engine = create_engine(url) -else: - print("URL not set - unit tests only") - - -def get_session(): - with Session(engine) as session: - yield session - app = FastAPI() -CursorPage = CursorPage.with_custom_options( - size=Query(10, ge=1, le=100), -) +app.include_router(open.router) +app.include_router(protected.router) +app.include_router(admin.router) add_pagination(app) @@ -90,229 +35,5 @@ async def redirect_home(): return "/" -async def get_current_user( - auth: HTTPAuthorizationCredentials = Depends(get_bearer_token), -): - - if auth is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid user token", - ) - - if dev: - return auth.credentials - - if oidc_user_info_endpoint is None: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="User info endpoint error", - ) - - response = requests.get( - url=oidc_user_info_endpoint, - headers={"Authorization": f"Bearer {auth.credentials}"}, - ) - - if response.status_code == 401: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid user token", - ) - - return response.json()["id"] - - -@app.get("/api/user") -async def check( - session: Session = Depends(get_session), user_id: str = Depends(get_current_user) -): - - statement = select(Person).where(Person.identifier == user_id) - person = session.exec(statement).first() - - admin = person is not None and person.admin - - return {"user": user_id, "admin": admin} - - -@app.get("/api/metadata") -def read_metadata(session: Session = Depends(get_session)) -> MetadataResponse: - return get_metadata(session) - - -@app.get("/api/licences") -def read_licences(session: Session = Depends(get_session)) -> List[LicenceType]: - return list(LicenceType) - - -@app.get("/api/beamlines") -def read_beamlines(session: Session = Depends(get_session)) -> List[BeamlineResponse]: - bl = select_all(session, Beamline) - return bl - - -@app.get("/api/elements") -def read_elements(session: Session = Depends(get_session)) -> List[Element]: - e = select_all(session, Element) - return e - - -@app.get("/api/edges") -def read_edges(session: Session = Depends(get_session)) -> List[Edge]: - e = select_all(session, Edge) - return e - - -@app.get("/api/standards") -def read_standards( - session: Session = Depends(get_session), - element: str | None = None, -) -> CursorPage[XASStandardResponse]: - - statement = select(XASStandard).where( - XASStandard.review_status == ReviewStatus.approved - ) - - if element: - statement = statement.join(Element, XASStandard.element_z == Element.z).where( - Element.symbol == element - ) - - return paginate( - session, - statement.order_by(XASStandard.id), - ) - - -@app.get("/api/admin/standards") -def read_standards_admin( - session: Session = Depends(get_session), - user_id: str = Depends(get_current_user), -) -> CursorPage[AdminXASStandardResponse]: - - statement = select(Person).where(Person.identifier == user_id) - person = session.exec(statement).first() - - if person is None or not person.admin: - raise HTTPException(status_code=401, detail=f"No standard with id={user_id}") - - if not person.admin: - raise HTTPException(status_code=401, detail=f"User {user_id} not admin") - - statement = select(XASStandard).where( - XASStandard.review_status == ReviewStatus.pending - ) - - return paginate(session, statement.order_by(XASStandard.id)) - - -@app.get("/api/standards/{id}") -async def read_standard( - id: int, session: Session = Depends(get_session) -) -> XASStandardResponse: - return get_standard(session, id) - - -@app.post("/api/standards") -def add_standard_file( - xdi_file: UploadFile, - element_id: Annotated[str, Form()], - edge_id: Annotated[str, Form()], - beamline_id: Annotated[int, Form()], - sample_name: Annotated[str, Form()], - sample_prep: Annotated[str, Form()], - doi: Annotated[str, Form()], - citation: Annotated[str, Form()], - comments: Annotated[str, Form()], - date: Annotated[str, Form()], - licence: Annotated[str, Form()], - additional_files: Optional[list[UploadFile]] = Form(None), - sample_comp: Optional[str] = Form(None), - user_id: str = Depends(get_current_user), - session: Session = Depends(get_session), -) -> XASStandard: - - if additional_files: - print(f"Additional files {len(additional_files)}") - - person = select_or_create_person(session, user_id) - - form_input = XASStandardInput( - submitter_id=person.id, - beamline_id=beamline_id, - doi=doi, - element_z=element_id, - edge_id=edge_id, - sample_name=sample_name, - sample_prep=sample_prep, - submitter_comments=comments, - citation=citation, - licence=licence, - collection_date=date, - submission_date=datetime.datetime.now(), - sample_comp=sample_comp, - ) - - return add_new_standard(session, xdi_file, form_input, additional_files) - - -@app.patch("/api/standards") -def submit_review( - review: XASStandardAdminReviewInput, - session: Session = Depends(get_session), - user_id: str = Depends(get_current_user), -): - - statement = select(Person).where(Person.identifier == user_id) - person = session.exec(statement).first() - - if person is None or not person.admin: - raise HTTPException(status_code=401, detail=f"No standard with id={user_id}") - - if not person.admin: - raise HTTPException(status_code=401, detail=f"User {user_id} not admin") - return update_review(session, review, person.id) - - -@app.get("/api/data/{id}") -async def read_data( - id: int, format: Optional[str] = "json", session: Session = Depends(get_session) -): - - if format == "xdi": - return get_file(session, id) - - return get_data(session, id) - - -@app.get("/api/admin/data/{id}") -async def read_admin_data(id: int, session: Session = Depends(get_session)): - - return get_file_as_text(session, id) - - -@app.post("/uploadfiles/") -async def create_upload_files( - files: Annotated[ - list[UploadFile], File(description="Multiple files as UploadFile") - ], -): - return {"filenames": [file.filename for file in files]} - - -@app.get("/test") -async def main(): - content = """ - -
- - -
- - """ - return HTMLResponse(content=content) - - if build_dir: app.mount("/", StaticFiles(directory="/client/dist", html=True), name="site") diff --git a/xas-standards-api/src/xas_standards_api/auth.py b/xas-standards-api/src/xas_standards_api/auth.py new file mode 100644 index 0000000..112dc86 --- /dev/null +++ b/xas-standards-api/src/xas_standards_api/auth.py @@ -0,0 +1,52 @@ +import os + +import requests +from fastapi import ( + Depends, + HTTPException, + status, +) +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + +get_bearer_token = HTTPBearer(auto_error=True) + +oidc_user_info_endpoint = os.environ.get("OIDC_USER_INFO_ENDPOINT") +dev = False + +env_value = os.environ.get("FASTAPI_APP_ENV") + +if env_value and env_value == "development": + print("RUNNING IN DEV MODE") + dev = True + +async def get_current_user( + auth: HTTPAuthorizationCredentials = Depends(get_bearer_token), +): + + if auth is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid user token", + ) + + if dev: + return auth.credentials + + if oidc_user_info_endpoint is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="User info endpoint error", + ) + + response = requests.get( + url=oidc_user_info_endpoint, + headers={"Authorization": f"Bearer {auth.credentials}"}, + ) + + if response.status_code == 401: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid user token", + ) + + return response.json()["id"] diff --git a/xas-standards-api/src/xas_standards_api/crud.py b/xas-standards-api/src/xas_standards_api/crud.py index d35bad2..27a6f41 100644 --- a/xas-standards-api/src/xas_standards_api/crud.py +++ b/xas-standards-api/src/xas_standards_api/crud.py @@ -1,27 +1,71 @@ import uuid -from fastapi import HTTPException +from fastapi import HTTPException, Query from fastapi.responses import FileResponse, PlainTextResponse +from fastapi_pagination.cursor import CursorPage +from fastapi_pagination.ext.sqlalchemy import paginate from larch.io import xdi from larch.xafs import pre_edge, set_xafsGroup -from sqlmodel import select +from sqlmodel import Session, select -from .schemas import ( +from .models.models import ( Beamline, Edge, Element, LicenceType, Person, PersonInput, + ReviewStatus, XASStandard, + XASStandardAdminReviewInput, XASStandardData, XASStandardDataInput, XASStandardInput, ) +from .models.response_models import XASStandardResponse + +CursorPage = CursorPage.with_custom_options( + size=Query(10, ge=1, le=100), +) + pvc_location = "/scratch/xas-standards-pretend-pvc/" +def patch_standard_review( + review: XASStandardAdminReviewInput, session: Session, user_id: str +): + statement = select(Person).where(Person.identifier == user_id) + person = session.exec(statement).first() + + if person is None or not person.admin: + raise HTTPException(status_code=401, detail=f"No standard with id={user_id}") + + if not person.admin: + raise HTTPException(status_code=401, detail=f"User {user_id} not admin") + return update_review(session, review, person.id) + + +def read_standards_page( + session: Session, + element: str | None = None, +) -> CursorPage[XASStandardResponse]: + + statement = select(XASStandard).where( + XASStandard.review_status == ReviewStatus.approved + ) + + if element: + statement = statement.join(Element, XASStandard.element_z == Element.z).where( + Element.symbol == element + ) + + return paginate( + session, + statement.order_by(XASStandard.id), + ) + + def get_beamline_names(session): results = session.exec(select(Beamline.name, Beamline.id)).all() return results @@ -62,6 +106,15 @@ def update_review(session, review, reviewer_id): return standard +def get_user(session, user_id): + statement = select(Person).where(Person.identifier == user_id) + person = session.exec(statement).first() + + admin = person is not None and person.admin + + return {"user": user_id, "admin": admin} + + def select_or_create_person(session, identifier): p = PersonInput(identifier=identifier) @@ -178,3 +231,23 @@ def get_data(session, id): "mufluor": fluor_out, "murefer": ref_out, } + + +def get_standards_admin( + session: Session, + user_id: str, +): + statement = select(Person).where(Person.identifier == user_id) + person = session.exec(statement).first() + + if person is None or not person.admin: + raise HTTPException(status_code=401, detail=f"No standard with id={user_id}") + + if not person.admin: + raise HTTPException(status_code=401, detail=f"User {user_id} not admin") + + statement = select(XASStandard).where( + XASStandard.review_status == ReviewStatus.pending + ) + + return paginate(session, statement.order_by(XASStandard.id)) diff --git a/xas-standards-api/src/xas_standards_api/database.py b/xas-standards-api/src/xas_standards_api/database.py new file mode 100644 index 0000000..131ac67 --- /dev/null +++ b/xas-standards-api/src/xas_standards_api/database.py @@ -0,0 +1,21 @@ +import os + +from sqlmodel import Session, create_engine + +url = os.environ.get("POSTGRESURL") + +engine = None + +if url: + engine = create_engine(url) +else: + print("URL not set - unit tests only") + + +def get_session(): + + if engine is None: + raise Exception("Database engine is None, has url been set?") + + with Session(engine) as session: + yield session diff --git a/xas-standards-api/src/xas_standards_api/models/__init__.py b/xas-standards-api/src/xas_standards_api/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/xas-standards-api/src/xas_standards_api/schemas.py b/xas-standards-api/src/xas_standards_api/models/models.py similarity index 79% rename from xas-standards-api/src/xas_standards_api/schemas.py rename to xas-standards-api/src/xas_standards_api/models/models.py index d86d625..d774291 100644 --- a/xas-standards-api/src/xas_standards_api/schemas.py +++ b/xas-standards-api/src/xas_standards_api/models/models.py @@ -2,19 +2,19 @@ import enum from typing import List, Optional -from pydantic import BaseModel from sqlmodel import Column, Enum, Field, Relationship, SQLModel -class Mono(BaseModel): - name: Optional[str] = None - d_spacing: Optional[str] = None - +class ReviewStatus(enum.Enum): + pending = "pending" + approved = "approved" + rejected = "rejected" -class Sample(BaseModel): - name: Optional[str] = None - prep: Optional[str] = None +class LicenceType(enum.Enum): + cc_by = "cc_by" + cc_0 = "cc_0" + logged_in_only = "logged_in_only" class PersonInput(SQLModel): identifier: str = Field(index=True, unique=True) @@ -25,24 +25,19 @@ class Person(PersonInput, table=True): admin: bool = False -class ElementInput(SQLModel): - symbol: str = Field(unique=True) - -class Element(ElementInput, table=True): +class Element(SQLModel, table=True): __tablename__: str = "element" z: int = Field(primary_key=True, unique=True) + symbol: str = Field(unique=True) name: str = Field(unique=True) -class EdgeInput(SQLModel): - name: str = Field(unique=True) - - -class Edge(EdgeInput, table=True): +class Edge(SQLModel, table=True): __tablename__: str = "edge" + name: str = Field(unique=True) id: int = Field(primary_key=True) level: str = Field(unique=True) @@ -78,27 +73,6 @@ class Beamline(SQLModel, table=True): ) -class FacilityResponse(SQLModel): - fullname: str - name: str - city: str - country: str - - -class BeamlineResponse(SQLModel): - id: int - name: str - notes: str - facility: FacilityResponse - - -class MetadataResponse(SQLModel): - beamlines: List[BeamlineResponse] - elements: List[Element] - edges: List[Edge] - licences: List[str] - - class XASStandardDataInput(SQLModel): original_filename: str transmission: bool @@ -116,18 +90,6 @@ class XASStandardData(XASStandardDataInput, table=True): xas_standard: "XASStandard" = Relationship(back_populates="xas_standard_data") -class ReviewStatus(enum.Enum): - pending = "pending" - approved = "approved" - rejected = "rejected" - - -class LicenceType(enum.Enum): - cc_by = "cc_by" - cc_0 = "cc_0" - logged_in_only = "logged_in_only" - - class XASStandardInput(SQLModel): submitter_id: int = Field(foreign_key="person.id") @@ -165,19 +127,6 @@ class XASStandard(XASStandardInput, table=True): } ) - -class XASStandardResponse(XASStandardInput): - id: int | None - element: ElementInput - edge: EdgeInput - beamline: BeamlineResponse - submitter_id: int - - -class AdminXASStandardResponse(XASStandardResponse): - submitter: Person - - class XASStandardAdminReviewInput(SQLModel): reviewer_comments: Optional[str] = None review_status: ReviewStatus diff --git a/xas-standards-api/src/xas_standards_api/models/response_models.py b/xas-standards-api/src/xas_standards_api/models/response_models.py new file mode 100644 index 0000000..47b071c --- /dev/null +++ b/xas-standards-api/src/xas_standards_api/models/response_models.py @@ -0,0 +1,39 @@ +from typing import List + +from sqlmodel import SQLModel + +from .models import Edge, Element, Person, XASStandardInput + + +class FacilityResponse(SQLModel): + fullname: str + name: str + city: str + country: str + +class BeamlineResponse(SQLModel): + id: int + name: str + notes: str + facility: FacilityResponse + +class XASStandardResponse(XASStandardInput): + id: int | None + element: Element + edge: Edge + beamline: BeamlineResponse + submitter_id: int + + +class AdminXASStandardResponse(XASStandardResponse): + submitter: Person + +class MetadataResponse(SQLModel): + beamlines: List[BeamlineResponse] + elements: List[Element] + edges: List[Edge] + licences: List[str] + + + + diff --git a/xas-standards-api/src/xas_standards_api/routers/__init__.py b/xas-standards-api/src/xas_standards_api/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/xas-standards-api/src/xas_standards_api/routers/admin.py b/xas-standards-api/src/xas_standards_api/routers/admin.py new file mode 100644 index 0000000..f27b6ba --- /dev/null +++ b/xas-standards-api/src/xas_standards_api/routers/admin.py @@ -0,0 +1,25 @@ +from fastapi import APIRouter, Depends +from fastapi_pagination.cursor import CursorPage +from sqlmodel import Session + +from ..auth import get_current_user +from ..crud import get_file_as_text, get_standards_admin +from ..database import get_session +from ..models.response_models import AdminXASStandardResponse + +router = APIRouter() + + +@router.get("/api/admin/data/{id}") +async def read_admin_data(id: int, session: Session = Depends(get_session)): + + return get_file_as_text(session, id) + + +@router.get("/api/admin/standards") +def read_standards_admin( + session: Session = Depends(get_session), + user_id: str = Depends(get_current_user), +) -> CursorPage[AdminXASStandardResponse]: + + return get_standards_admin(session,user_id) diff --git a/xas-standards-api/src/xas_standards_api/routers/open.py b/xas-standards-api/src/xas_standards_api/routers/open.py new file mode 100644 index 0000000..65abcbd --- /dev/null +++ b/xas-standards-api/src/xas_standards_api/routers/open.py @@ -0,0 +1,45 @@ +from typing import Optional + +from fastapi import APIRouter, Depends +from fastapi_pagination.cursor import CursorPage +from sqlmodel import Session + +from ..crud import get_data, get_file, get_metadata, get_standard, read_standards_page +from ..database import get_session +from ..models.response_models import ( + MetadataResponse, + XASStandardResponse, +) + +router = APIRouter() + +@router.get("/api/standards/{id}") +async def read_standard( + id: int, session: Session = Depends(get_session) +) -> XASStandardResponse: + return get_standard(session, id) + + +@router.get("/api/metadata") +def read_metadata(session: Session = Depends(get_session)) -> MetadataResponse: + return get_metadata(session) + + +@router.get("/api/standards") +def read_standards( + session: Session = Depends(get_session), + element: str | None = None, +) -> CursorPage[XASStandardResponse]: + + return read_standards_page(session,element) + +@router.get("/api/data/{id}") +async def read_data( + id: int, format: Optional[str] = "json", session: Session = Depends(get_session) +): + + if format == "xdi": + return get_file(session, id) + + return get_data(session, id) + diff --git a/xas-standards-api/src/xas_standards_api/routers/protected.py b/xas-standards-api/src/xas_standards_api/routers/protected.py new file mode 100644 index 0000000..935eb79 --- /dev/null +++ b/xas-standards-api/src/xas_standards_api/routers/protected.py @@ -0,0 +1,81 @@ +import datetime +from typing import Annotated, Optional + +from fastapi import ( + APIRouter, + Depends, + Form, + UploadFile, +) +from sqlmodel import Session + +from ..auth import get_current_user +from ..crud import ( + add_new_standard, + get_user, + patch_standard_review, + select_or_create_person, +) +from ..database import get_session +from ..models.models import XASStandard, XASStandardAdminReviewInput, XASStandardInput + +router = APIRouter() + +@router.get("/api/user") +async def check( + session: Session = Depends(get_session), user_id: str = Depends(get_current_user) +): + + return get_user(session,user_id) + + +@router.post("/api/standards") +def add_standard_file( + xdi_file: UploadFile, + element_id: Annotated[str, Form()], + edge_id: Annotated[str, Form()], + beamline_id: Annotated[int, Form()], + sample_name: Annotated[str, Form()], + sample_prep: Annotated[str, Form()], + doi: Annotated[str, Form()], + citation: Annotated[str, Form()], + comments: Annotated[str, Form()], + date: Annotated[str, Form()], + licence: Annotated[str, Form()], + additional_files: Optional[list[UploadFile]] = Form(None), + sample_comp: Optional[str] = Form(None), + user_id: str = Depends(get_current_user), + session: Session = Depends(get_session), +) -> XASStandard: + + if additional_files: + print(f"Additional files {len(additional_files)}") + + person = select_or_create_person(session, user_id) + + form_input = XASStandardInput( + submitter_id=person.id, + beamline_id=beamline_id, + doi=doi, + element_z=element_id, + edge_id=edge_id, + sample_name=sample_name, + sample_prep=sample_prep, + submitter_comments=comments, + citation=citation, + licence=licence, + collection_date=date, + submission_date=datetime.datetime.now(), + sample_comp=sample_comp, + ) + + return add_new_standard(session, xdi_file, form_input, additional_files) + +@router.patch("/api/standards") +def submit_review( + review: XASStandardAdminReviewInput, + session: Session = Depends(get_session), + user_id: str = Depends(get_current_user), +): + + return patch_standard_review(review, session, user_id) diff --git a/xas-standards-api/tests/test_app.py b/xas-standards-api/tests/test_app.py index e69de29..8bbab74 100644 --- a/xas-standards-api/tests/test_app.py +++ b/xas-standards-api/tests/test_app.py @@ -0,0 +1,102 @@ +from fastapi.testclient import TestClient +from sqlmodel import Session, SQLModel, create_engine +from sqlmodel.pool import StaticPool + +from xas_standards_api.app import app +from xas_standards_api.auth import get_current_user +from xas_standards_api.database import get_session +from xas_standards_api.models.models import Beamline, Edge, Element, Facility, Person +from xas_standards_api.models.response_models import MetadataResponse + +client = TestClient(app) + + +def test_read_item(): + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + + session.add(Element(name="Hydrogen", z=1, symbol="H")) + session.add(Edge(name="K", id=1, level="sp")) + session.add( + Facility( + id=1, + name="synchrotron", + notes="a place", + fullname="a synchrotron", + city="somewhere", + region="someplace", + laboratory="a lab", + country="somecountry", + ) + ) + + session.add( + Beamline( + facility_id=1, + id=1, + name="my beamline", + notes="a beamline", + xray_source="BM", + ) + ) + session.commit() + + def get_session_override(): + return session + + app.dependency_overrides[get_session] = get_session_override + + client = TestClient(app) + + response = client.get("/api/metadata/") + app.dependency_overrides.clear() + + print(response) + + mr = MetadataResponse.model_validate(response.json()) + + assert response.status_code == 200 + assert mr.elements[0].symbol == "H" + assert mr.edges[0].name == "K" + assert mr.beamlines[0].name == "my beamline" + assert mr.beamlines[0].facility.name == "synchrotron" + + +def test_read_person(): + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + + session.add(Person(id=1, identifier="abc123", admin=False)) + + session.commit() + + def get_session_override(): + return session + + def get_current_user_override(): + return "abc123" + + app.dependency_overrides[get_session] = get_session_override + app.dependency_overrides[get_current_user] = get_current_user_override + + client = TestClient(app) + + response = client.get("/api/user/") + app.dependency_overrides.clear() + + r = response.json() + + assert r["user"] == "abc123" + assert not r["admin"] diff --git a/xas-standards-api/tests/test_crud.py b/xas-standards-api/tests/test_crud.py index 37870c8..f30b088 100644 --- a/xas-standards-api/tests/test_crud.py +++ b/xas-standards-api/tests/test_crud.py @@ -1,10 +1,11 @@ from unittest.mock import Mock, call, create_autospec import pytest +from fastapi import HTTPException from sqlmodel import Session from xas_standards_api import crud -from xas_standards_api.schemas import XASStandard +from xas_standards_api.models.models import XASStandard def test_get_standard(): @@ -13,17 +14,45 @@ def test_get_standard(): test_id = 0 result = XASStandard() - mock_session.get = Mock(return_value=None) + # Session returns None, i.e. no standard for id + mock_session.get = Mock(return_value=None) expected_session_calls = [call.get(XASStandard, test_id)] - with pytest.raises(Exception): + with pytest.raises(HTTPException): crud.get_standard(mock_session, test_id) mock_session.get.assert_has_calls(expected_session_calls) + # Session returns a standard mock_session.get = Mock(return_value=result) output = crud.get_standard(mock_session, test_id) assert output == result + + +# def test_get_standards(): + +# mock_session = create_autospec(Session, instance=True) + + +# crud.read_standards_page(mock_session) + + +# def test_get_metadata(): + +# mock_session = create_autospec(Session, instance=True) + +# result = MetadataResponse(beamlines=[], edges=[], elements=[], licences=[]) +# mock_session.get = Mock(return_value=None) + +# expected_session_calls = [call.get()] + +# # mock_session.get.assert_has_calls(expected_session_calls) + +# # mock_session.get = Mock(return_value=result) + +# output = crud.get_metadata(mock_session) + +# assert output == result