Skip to content

Commit

Permalink
add tests for protected routes (#14)
Browse files Browse the repository at this point in the history
* add tests for protected routes

* add tmp dir
  • Loading branch information
jacobfilik authored Jun 25, 2024
1 parent 3272746 commit c0d520b
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 11 deletions.
6 changes: 3 additions & 3 deletions xas-standards-api/src/xas_standards_api/crud.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import uuid

from fastapi import HTTPException, Query
Expand Down Expand Up @@ -28,8 +29,7 @@
size=Query(10, ge=1, le=100),
)


pvc_location = "/scratch/xas-standards-pretend-pvc/"
pvc_location = os.environ.get("PVC_LOCATION", "/scratch/xas-standards-pretend-pvc/")


def patch_standard_review(
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_standard(session, id) -> XASStandard:

if standard:
if standard.review_status != ReviewStatus.approved:
raise HTTPException(status_code=401, detail="Standard not available")
raise HTTPException(status_code=403, detail="Standard not available")
return standard
else:
raise HTTPException(status_code=404, detail=f"No standard with id={id}")
Expand Down
8 changes: 6 additions & 2 deletions xas-standards-api/tests/test_admin_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

import xas_standards_api.crud
from utils import build_test_database
from xas_standards_api.app import app
from xas_standards_api.auth import get_current_user
from xas_standards_api.database import get_session
from xas_standards_api.models.response_models import AdminXASStandardResponse


def test_admin_read_permissions():
def test_admin_read_permissions(tmpdir):

xas_standards_api.crud.pvc_location = str(tmpdir)

engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
Expand Down Expand Up @@ -45,7 +49,7 @@ def get_admin_user():

# check cant get data from open endpoint
response = client.get("/api/data/2")
assert response.status_code == 401
assert response.status_code == 403

# now try admin user
app.dependency_overrides.clear()
Expand Down
10 changes: 7 additions & 3 deletions xas-standards-api/tests/test_open_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

import xas_standards_api.crud
from utils import build_test_database
from xas_standards_api.app import app
from xas_standards_api.database import get_session
Expand All @@ -11,7 +12,10 @@
)


def test_read_metadata():
def test_read_metadata(tmpdir):

xas_standards_api.crud.pvc_location = str(tmpdir)

engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
Expand Down Expand Up @@ -63,7 +67,7 @@ def get_session_override():

# check cant get unreviewed data from open endpoint
response = client.get("/api/data/2")
assert response.status_code == 401
assert response.status_code == 403

# check cant get id that doesnt exist
response = client.get("/api/data/3")
Expand All @@ -73,7 +77,7 @@ def get_session_override():
assert response.status_code == 200

response = client.get("/api/standards/2")
assert response.status_code == 401
assert response.status_code == 403

response = client.get("/api/standards/3")
assert response.status_code == 404
64 changes: 61 additions & 3 deletions xas-standards-api/tests/test_protected_router.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import datetime

from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

import xas_standards_api.crud
from utils import build_test_database
from xas_standards_api.app import app
from xas_standards_api.auth import get_current_user
from xas_standards_api.database import get_session
from xas_standards_api.models.models import XASStandard


def test_protected_router(tmpdir):

xas_standards_api.crud.pvc_location = str(tmpdir)

def test_read_person():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
Expand Down Expand Up @@ -54,5 +61,56 @@ def get_admin_user():
assert r["user"] == "admin"
assert r["admin"]

# TODO check post of standard
# TODO check patch of standard
unique_sample_name = f"Test sample {datetime.datetime.now()}"

formdata = {
"element_id": 1,
"edge_id": 1,
"beamline_id": 1,
"sample_name": unique_sample_name,
"sample_prep": "test",
"doi": "doi",
"citation": "citation",
"comments": "comments",
"date": str(datetime.datetime.min),
"licence": "cc_by",
"sample_comp": "H",
}

with open("test.xdi") as fh:
xditext = fh.read()

response = client.post(
"/api/standards", data=formdata, files={"xdi_file": xditext}
)

assert response.status_code == 200

rjson = response.json()

xass = XASStandard.model_validate(rjson)

assert xass.sample_name == unique_sample_name

print(xass.id)

# not reviewed, should fail
response = client.get(f"/api/standards/{xass.id}")
assert response.status_code == 403

# get and review
app.dependency_overrides.clear()
app.dependency_overrides[get_session] = get_session_override
app.dependency_overrides[get_current_user] = get_admin_user

review_json = {
"reviewer_comments": "reviewer",
"review_status": "approved",
"standard_id": 3,
}

response = client.patch("/api/standards", json=review_json)
assert response.status_code == 200

response = client.get(f"/api/standards/{xass.id}")
assert response.status_code == 200

0 comments on commit c0d520b

Please sign in to comment.