Skip to content

Commit

Permalink
Fix some unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bbrondel committed Jan 7, 2025
1 parent 55d211a commit f56d4bc
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 37 deletions.
6 changes: 3 additions & 3 deletions alembic-autogenerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import re
import sys

from alembic.config import Config
from alembic import command
from felis.tests.postgresql import setup_postgres_test_db
from sqlalchemy.sql import text

from felis.tests.postgresql import setup_postgres_test_db
from alembic import command
from alembic.config import Config

if len(sys.argv) <= 1:
print(
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/consdb/cdb_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from enum import StrEnum
import logging
from enum import StrEnum
from typing import Generator

import sqlalchemy
import sqlalchemy.dialects.postgresql
from sqlalchemy.orm import Session
from packaging.version import Version
from sqlalchemy.orm import Session

from .exceptions import BadValueException

Expand Down
2 changes: 2 additions & 0 deletions python/lsst/consdb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import re
import sys

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings

Expand All @@ -11,6 +12,7 @@

class Configuration(BaseSettings):
"""Configuration for consdb."""

name: str = Field("pqserver", title="Application name")

version: str = Field("noversion", title="Application version number")
Expand Down
8 changes: 3 additions & 5 deletions python/lsst/consdb/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import logging
from typing import Annotated

from fastapi import Path, Request
from pydantic import AfterValidator
from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import sessionmaker
from typing import Annotated

from .config import config
from .cdb_schema import InstrumentTable
from .config import config
from .exceptions import UnknownInstrumentException

__all__ = ["get_logger", "get_db"]
Expand Down Expand Up @@ -98,9 +98,7 @@ def get_instrument_table(instrument: str):
if instrument in instrument_tables:
instrument_table = instrument_tables[instrument]
else:
instrument_table = InstrumentTable(
engine=engine, instrument=instrument, get_db=get_db, logger=logger
)
instrument_table = InstrumentTable(engine=engine, instrument=instrument, get_db=get_db, logger=logger)
instrument_tables[instrument] = instrument_table

return instrument_table
Expand Down
42 changes: 22 additions & 20 deletions python/lsst/consdb/handlers/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import astropy
import sqlalchemy
import sqlalchemy.dialects.postgresql
from sqlalchemy.orm import Session
from fastapi import APIRouter, Body, Depends, Path, Query
from packaging.version import Version
from sqlalchemy.orm import Session

from ..cdb_schema import (
AllowedFlexType,
Expand All @@ -38,7 +38,7 @@
convert_to_flex_type,
)
from ..config import config
from ..dependencies import get_db, get_logger, get_instrument_list, get_instrument_table, InstrumentName
from ..dependencies import InstrumentName, get_db, get_instrument_list, get_instrument_table, get_logger
from ..exceptions import BadValueException
from ..models import (
AddKeyRequestModel,
Expand All @@ -53,6 +53,13 @@
QueryResponseModel,
)

# Check SQLAlchemy version
required_version = (2, 0, 0)
current_version = tuple(map(int, sqlalchemy.__version__.split(".")[:3]))
assert (
current_version >= required_version
), f"SQLAlchemy version must be >= 2.0.0, but found {sqlalchemy.__version__}"

external_router = APIRouter()
"""FastAPI router for all external handlers."""

Expand Down Expand Up @@ -221,9 +228,7 @@ def insert_flexible_metadata(
index_elements=["day_obs", "seq_num", "key"], set_={"value": value_str}
)
else:
stmt = stmt.on_conflict_do_update(
index_elements=["obs_id", "key"], set_={"value": value_str}
)
stmt = stmt.on_conflict_do_update(index_elements=["obs_id", "key"], set_={"value": value_str})

logger.debug(str(stmt))
_ = db.execute(stmt)
Expand Down Expand Up @@ -355,20 +360,21 @@ def insert_multiple(
timestamp = astropy.time.Time(timestamp, format="isot", scale="tai")
valdict[column] = timestamp.to_datetime()

stmt: sqlalchemy.sql.dml.Insert
stmt = sqlalchemy.dialects.postgresql.insert(table_obj).values(valdict)
if u != 0:
stmt = stmt.on_conflict_do_update(index_elements=[obs_id_colname], set_=valdict)
bulk_data.append(valdict)

try:
if bulk_data:
db.execute(insert(table_obj).values(bulk_data))
stmt = sqlalchemy.dialects.postgresql.insert(table_obj).values(bulk_data)
if u != 0:
# Specify update behavior for conflicts
update_dict = {col: stmt.excluded[col] for col in table_obj.columns.keys()}
stmt = stmt.on_conflict_do_update(index_elements=[obs_id_colname], set_=update_dict)

db.execute(stmt)
db.commit()

except Exception:
db.rollback()
logger.exception("Failed to insert data")
logger.exception("Failed to insert or update data")
raise

return InsertMultipleResponseModel(
Expand Down Expand Up @@ -455,14 +461,10 @@ def query(
columns = []
rows = []

cursor = db.exec_driver_sql(data.query)
first = True
for row in cursor:
logger.debug(row)
if first:
columns.extend(row.keys())
first = False
rows.append(list(row))
with db.connection() as connection:
result = connection.exec_driver_sql(data.query)
columns = result.keys()
rows = [list(row) for row in result]

return QueryResponseModel(
columns=columns,
Expand Down
3 changes: 2 additions & 1 deletion python/lsst/consdb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from typing import Any

import astropy
from pydantic import BaseModel, Field, field_validator
from safir.metadata import Metadata
from typing import Any

from .cdb_schema import AllowedFlexType, AllowedFlexTypeEnum, ObservationIdType, ObsTypeEnum
from .dependencies import InstrumentName
Expand Down
1 change: 0 additions & 1 deletion python/lsst/consdb/pqserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def unknown_instrument_exception_handler(request: Request, exc: UnknownInstrumen

@app.exception_handler(BadValueException)
def bad_value_exception_handler(request: Request, exc: BadValueException):
exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
return JSONResponse(content=exc.to_dict(), status_code=status.HTTP_404_NOT_FOUND)


Expand Down
11 changes: 6 additions & 5 deletions tests/test_pqserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,12 @@ def test_flexible_metadata(lsstcomcamsim):
assert result == {"baz": 2.71828}

response = client.post("/consdb/flex/latiss/exposure/obs/7024052800003", json={})
_assert_http_status(response, 404)
result = response.json()
assert "Validation error" in result["message"]
assert result["detail"][0]["type"] == "missing"
assert "values" in result["detail"][0]["loc"]
_assert_http_status(response, 422)
result = response.json()["detail"][0]
assert "Field required" in result["msg"]
assert result["type"] == "missing"
assert "values" in result["loc"]
assert "body" in result["loc"]

response = client.post(
"/consdb/insert/latiss/exposure/obs/2024032100003",
Expand Down

0 comments on commit f56d4bc

Please sign in to comment.