Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] SA2.0, model/data-access edits, unit testing #17551

Closed
wants to merge 39 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0c2b388
[TMP] Combined commits from #17180
jdavcs Dec 12, 2023
dc2aec7
Initial data access testing infrastructure
jdavcs Feb 28, 2024
1b5d947
Add first library data access tests
jdavcs Feb 28, 2024
060361d
Add first user data access tests
jdavcs Feb 28, 2024
ebb738a
Start replacing test_galaxy_mapping w/data access tests
jdavcs Feb 28, 2024
0ae692a
security.get_npns_roles: test + factor out
jdavcs Feb 28, 2024
e6f84de
security.get_private_user_role: test + factor out
jdavcs Feb 28, 2024
52c07ca
webapps.galaxy.services.user.get_users_for_index: test + factor out
jdavcs Feb 28, 2024
639419d
Move user data access method from webapps to managers
jdavcs Feb 28, 2024
ab45c5e
More test-galaxy-mapping conversions
jdavcs Feb 28, 2024
aca6e44
Convert another test
jdavcs Mar 1, 2024
92f334f
Convert test_ratings
jdavcs Mar 4, 2024
d5cfe8a
Refactor model fixtures
jdavcs Mar 4, 2024
1913bfd
Drop test
jdavcs Mar 7, 2024
1c3746d
Drop test
jdavcs Mar 7, 2024
7e2379e
Convert test_history_contents
jdavcs Mar 7, 2024
bf16e50
Convert test_current_galaxy_session
jdavcs Mar 7, 2024
a819692
Convert hid tests
jdavcs Mar 7, 2024
4fc7014
Convert test_get_display_name
jdavcs Mar 7, 2024
37a321b
Drop test_tags
jdavcs Mar 7, 2024
90d9bee
Drop incomplete test
jdavcs Mar 7, 2024
5700627
Drop test_basic
jdavcs Mar 7, 2024
75493ff
Convert test_metadata_spec
jdavcs Mar 7, 2024
773a1d7
Add Decimal to accepted types by util.nice_size()
jdavcs Mar 5, 2024
c00a2c7
Drop unused fixture arg
jdavcs Mar 7, 2024
c3d824b
Convert + improve test_job/task_metrics
jdavcs Mar 7, 2024
604fd7a
Drop test_tasks
jdavcs Mar 7, 2024
edb732b
Use DeclarativeBase as base class
jdavcs Mar 8, 2024
951209f
Remove future, autocommit args from session and engine creation
jdavcs Mar 8, 2024
9da049c
Fix table/__table__/HasTable for models
jdavcs Mar 8, 2024
399e34e
Remove fix for model constructors (fixed in SA2.0)
jdavcs Mar 8, 2024
1a09430
Add missing type hints to mapped_column in the model
jdavcs Mar 8, 2024
85d8862
Fix bug: WorkflowInvocationMessage is not serializable
jdavcs Mar 8, 2024
fd9d22d
Add type hint to 1 column_property to fix a new mypy error
jdavcs Mar 8, 2024
260f49d
Fix bug: there's no user attr on WorkflowInvocation
jdavcs Mar 8, 2024
feae41e
Fix typing in CustosAuthmzToken model
jdavcs Mar 8, 2024
754adfe
Fix uuid types in the model
jdavcs Mar 8, 2024
503ad9b
Fix type error in workflow unit test
jdavcs Mar 8, 2024
8bf1f7a
Start trimming model definition syntax
jdavcs Mar 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/galaxy/app_unittest_utils/galaxy_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(self, config=None, **kwargs) -> None:
self[ShortTermStorageMonitor] = sts_manager # type: ignore[type-abstract]
self[galaxy_scoped_session] = self.model.context
self.visualizations_registry = MockVisualizationsRegistry()
self.tag_handler = tags.GalaxyTagHandler(self.model.context)
self.tag_handler = tags.GalaxyTagHandler(self.model.session)
self[tags.GalaxyTagHandler] = self.tag_handler
self.quota_agent = quota.DatabaseQuotaAgent(self.model)
self.job_config = Bunch(
Expand Down
45 changes: 17 additions & 28 deletions lib/galaxy/celery/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from sqlalchemy.dialects.postgresql import insert as ps_insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from galaxy.model import CeleryUserRateLimit
from galaxy.model.base import transaction
Expand Down Expand Up @@ -70,7 +69,7 @@ def __call__(self, task: Task, task_id, args, kwargs):

@abstractmethod
def calculate_task_start_time(
self, user_id: int, sa_session: Session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
return now

Expand All @@ -81,38 +80,28 @@ class GalaxyTaskBeforeStartUserRateLimitPostgres(GalaxyTaskBeforeStartUserRateLi
We take advantage of efficiencies in its dialect.
"""

_update_stmt = (
update(CeleryUserRateLimit)
.where(CeleryUserRateLimit.user_id == bindparam("userid"))
.values(last_scheduled_time=text("greatest(last_scheduled_time + ':interval second', " ":now) "))
.returning(CeleryUserRateLimit.last_scheduled_time)
)

_insert_stmt = (
ps_insert(CeleryUserRateLimit)
.values(user_id=bindparam("userid"), last_scheduled_time=bindparam("now"))
.returning(CeleryUserRateLimit.last_scheduled_time)
)

_upsert_stmt = _insert_stmt.on_conflict_do_update(
index_elements=["user_id"], set_=dict(last_scheduled_time=bindparam("sched_time"))
)

def calculate_task_start_time( # type: ignore
self, user_id: int, sa_session: Session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
with transaction(sa_session):
result = sa_session.execute(
self._update_stmt, {"userid": user_id, "interval": task_interval_secs, "now": now}
update_stmt = (
update(CeleryUserRateLimit)
.where(CeleryUserRateLimit.user_id == user_id)
.values(last_scheduled_time=text("greatest(last_scheduled_time + ':interval second', " ":now) "))
.returning(CeleryUserRateLimit.last_scheduled_time)
)
if result.rowcount == 0:
result = sa_session.execute(update_stmt, {"interval": task_interval_secs, "now": now}).all()
if not result:
sched_time = now + datetime.timedelta(seconds=task_interval_secs)
result = sa_session.execute(
self._upsert_stmt, {"userid": user_id, "now": now, "sched_time": sched_time}
upsert_stmt = (
ps_insert(CeleryUserRateLimit) # type:ignore[attr-defined]
.values(user_id=user_id, last_scheduled_time=now)
.returning(CeleryUserRateLimit.last_scheduled_time)
.on_conflict_do_update(index_elements=["user_id"], set_=dict(last_scheduled_time=sched_time))
)
for row in result:
return row[0]
result = sa_session.execute(upsert_stmt).all()
sa_session.commit()
return result[0][0]


class GalaxyTaskBeforeStartUserRateLimitStandard(GalaxyTaskBeforeStartUserRateLimit):
Expand All @@ -138,7 +127,7 @@ class GalaxyTaskBeforeStartUserRateLimitStandard(GalaxyTaskBeforeStartUserRateLi
)

def calculate_task_start_time(
self, user_id: int, sa_session: Session, task_interval_secs: float, now: datetime.datetime
self, user_id: int, sa_session: galaxy_scoped_session, task_interval_secs: float, now: datetime.datetime
) -> datetime.datetime:
last_scheduled_time = None
with transaction(sa_session):
Expand Down
5 changes: 4 additions & 1 deletion lib/galaxy/celery/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def set_metadata(
try:
if overwrite:
hda_manager.overwrite_metadata(dataset_instance)
dataset_instance.datatype.set_meta(dataset_instance)
dataset_instance.datatype.set_meta(dataset_instance) # type:ignore [arg-type]
dataset_instance.set_peek()
# Reset SETTING_METADATA state so the dataset instance getter picks the dataset state
dataset_instance.set_metadata_success_state()
Expand Down Expand Up @@ -228,6 +228,7 @@ def setup_fetch_data(
):
tool = cached_create_tool_from_representation(app=app, raw_tool_source=raw_tool_source)
job = sa_session.get(Job, job_id)
assert job
# self.request.hostname is the actual worker name given by the `-n` argument, not the hostname as you might think.
job.handler = self.request.hostname
job.job_runner_name = "celery"
Expand Down Expand Up @@ -260,6 +261,7 @@ def finish_job(
):
tool = cached_create_tool_from_representation(app=app, raw_tool_source=raw_tool_source)
job = sa_session.get(Job, job_id)
assert job
# TODO: assert state ?
mini_job_wrapper = MinimalJobWrapper(job=job, app=app, tool=tool)
mini_job_wrapper.finish("", "")
Expand Down Expand Up @@ -320,6 +322,7 @@ def fetch_data(
task_user_id: Optional[int] = None,
) -> str:
job = sa_session.get(Job, job_id)
assert job
mini_job_wrapper = MinimalJobWrapper(job=job, app=app)
mini_job_wrapper.change_state(model.Job.states.RUNNING, flush=True, job=job)
return abort_when_job_stops(_fetch_data, session=sa_session, job_id=job_id, setup_return=setup_return)
Expand Down
31 changes: 0 additions & 31 deletions lib/galaxy/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,6 @@ class GalaxyAppConfiguration(BaseAppConfiguration, CommonConfigurationMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._override_tempdir(kwargs)
self._configure_sqlalchemy20_warnings(kwargs)
self._process_config(kwargs)
self._set_dependent_defaults()

Expand All @@ -760,36 +759,6 @@ def _set_dependent_defaults(self):
f"{dependent_config_param}, {config_param}"
)

def _configure_sqlalchemy20_warnings(self, kwargs):
"""
This method should be deleted after migration to SQLAlchemy 2.0 is complete.
To enable warnings, set `GALAXY_CONFIG_SQLALCHEMY_WARN_20=1`,
"""
warn = string_as_bool(kwargs.get("sqlalchemy_warn_20", False))
if warn:
import sqlalchemy

sqlalchemy.util.deprecations.SQLALCHEMY_WARN_20 = True
self._setup_sqlalchemy20_warnings_filters()

def _setup_sqlalchemy20_warnings_filters(self):
import warnings

from sqlalchemy.exc import RemovedIn20Warning

# Always display RemovedIn20Warning warnings.
warnings.filterwarnings("always", category=RemovedIn20Warning)
# Optionally, enable filters for specific warnings (raise error, or log, etc.)
# messages = [
# r"replace with warning text to match",
# ]
# for msg in messages:
# warnings.filterwarnings('error', message=msg, category=RemovedIn20Warning)
#
# See documentation:
# https://docs.python.org/3.7/library/warnings.html#the-warnings-filter
# https://docs.sqlalchemy.org/en/14/changelog/migration_20.html#migration-to-2-0-step-three-resolve-all-removedin20warnings

def _load_schema(self):
return AppSchema(GALAXY_CONFIG_SCHEMA_PATH, GALAXY_APP_NAME)

Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/dependencies/pinned-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ sniffio==1.3.1 ; python_version >= "3.8" and python_version < "3.13"
social-auth-core[openidconnect]==4.5.3 ; python_version >= "3.8" and python_version < "3.13"
sortedcontainers==2.4.0 ; python_version >= "3.8" and python_version < "3.13"
spython==0.3.13 ; python_version >= "3.8" and python_version < "3.13"
sqlalchemy==1.4.51 ; python_version >= "3.8" and python_version < "3.13"
sqlalchemy==2.0.25 ; python_version >= "3.8" and python_version < "3.13"
sqlitedict==2.1.0 ; python_version >= "3.8" and python_version < "3.13"
sqlparse==0.4.4 ; python_version >= "3.8" and python_version < "3.13"
starlette-context==0.3.6 ; python_version >= "3.8" and python_version < "3.13"
Expand Down
6 changes: 4 additions & 2 deletions lib/galaxy/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,9 @@ def galaxy_url(self):
return self.get_destination_configuration("galaxy_infrastructure_url")

def get_job(self) -> model.Job:
return self.sa_session.get(Job, self.job_id)
job = self.sa_session.get(Job, self.job_id)
assert job
return job

def get_id_tag(self):
# For compatibility with drmaa, which uses job_id right now, and TaskWrapper
Expand Down Expand Up @@ -1552,7 +1554,7 @@ def change_state(self, state, info=False, flush=True, job=None):
def get_state(self) -> str:
job = self.get_job()
self.sa_session.refresh(job)
return job.state
return job.state # type:ignore[return-value]

def set_runner(self, runner_url, external_id):
log.warning("set_runner() is deprecated, use set_job_destination()")
Expand Down
1 change: 0 additions & 1 deletion lib/galaxy/jobs/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def __init__(
self.self_handler_tags = self_handler_tags
self.max_grab = max_grab
self.handler_tags = handler_tags
self._grab_conn_opts = {"autocommit": False}
self._grab_query = None
self._supports_returning = self.app.application_stack.supports_returning()

Expand Down
8 changes: 4 additions & 4 deletions lib/galaxy/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def get_object(trans, id, class_name, check_ownership=False, check_accessible=Fa


# =============================================================================
U = TypeVar("U", bound=model._HasTable)
U = TypeVar("U", bound=model.Base)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -318,9 +318,9 @@ def _one_with_recast_errors(self, query: Query) -> U:
# overridden to raise serializable errors
try:
return query.one()
except sqlalchemy.orm.exc.NoResultFound:
except sqlalchemy.exc.NoResultFound:
raise exceptions.ObjectNotFound(f"{self.model_class.__name__} not found")
except sqlalchemy.orm.exc.MultipleResultsFound:
except sqlalchemy.exc.MultipleResultsFound:
raise exceptions.InconsistentDatabase(f"found more than one {self.model_class.__name__}")

# NOTE: at this layer, all ids are expected to be decoded and in int form
Expand Down Expand Up @@ -999,7 +999,7 @@ class ModelFilterParser(HasAModelManager):
# (as the model informs how the filter params are parsed)
# I have no great idea where this 'belongs', so it's here for now

model_class: Type[model._HasTable]
model_class: Type[Union[model.Base, model.DatasetInstance]]
parsed_filter = parsed_filter
orm_filter_parsers: OrmFilterParsersType
fn_filter_parsers: FunctionFilterParsersType
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def get_collection_contents(self, trans: ProvidesAppContext, parent_id, limit=No
def _get_collection_contents_qry(self, parent_id, limit=None, offset=None):
"""Build query to find first level of collection contents by containing collection parent_id"""
DCE = model.DatasetCollectionElement
qry = Query(DCE).filter(DCE.dataset_collection_id == parent_id)
qry = Query(DCE).filter(DCE.dataset_collection_id == parent_id) # type:ignore[var-annotated]
qry = qry.order_by(DCE.element_index)
qry = qry.options(
joinedload(model.DatasetCollectionElement.child_collection), joinedload(model.DatasetCollectionElement.hda)
Expand Down
6 changes: 3 additions & 3 deletions lib/galaxy/managers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def purge_datasets(self, request: PurgeDatasetsTaskRequest):
self.error_unless_dataset_purge_allowed()
with self.session().begin():
for dataset_id in request.dataset_ids:
dataset: Dataset = self.session().get(Dataset, dataset_id)
if dataset.user_can_purge:
dataset: Optional[Dataset] = self.session().get(Dataset, dataset_id)
if dataset and dataset.user_can_purge:
try:
dataset.full_delete()
except Exception:
Expand Down Expand Up @@ -339,7 +339,7 @@ def serialize_permissions(self, item, key, user=None, **context):

# ============================================================================= AKA DatasetInstanceManager
class DatasetAssociationManager(
base.ModelManager[model.DatasetInstance],
base.ModelManager[model.DatasetInstance], # type:ignore[type-var]
secured.AccessibleManagerMixin,
secured.OwnableManagerMixin,
deletable.PurgableManagerMixin,
Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/managers/dbkeys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
)

from sqlalchemy import select
from sqlalchemy.orm import Session

from galaxy.model import HistoryDatasetAssociation
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.util import (
galaxy_directory,
sanitize_lists_to_string,
Expand Down Expand Up @@ -166,6 +166,6 @@ def get_chrom_info(self, dbkey, trans=None, custom_build_hack_get_len_from_fasta
return (chrom_info, db_dataset)


def get_len_files_by_history(session: Session, history_id: int):
def get_len_files_by_history(session: galaxy_scoped_session, history_id: int):
stmt = select(HistoryDatasetAssociation).filter_by(history_id=history_id, extension="len", deleted=False)
return session.scalars(stmt)
6 changes: 3 additions & 3 deletions lib/galaxy/managers/export_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
and_,
select,
)
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm.scoping import scoped_session

from galaxy.exceptions import ObjectNotFound
Expand Down Expand Up @@ -44,7 +44,7 @@ def set_export_association_metadata(self, export_association_id: int, export_met
export_association: StoreExportAssociation = self.session.execute(stmt).scalars().one()
except NoResultFound:
raise ObjectNotFound("Cannot set export metadata. Reason: Export association not found")
export_association.export_metadata = export_metadata.json()
export_association.export_metadata = export_metadata.json() # type:ignore[assignment]
with transaction(self.session):
self.session.commit()

Expand Down Expand Up @@ -76,4 +76,4 @@ def get_object_exports(
stmt = stmt.offset(offset)
if limit:
stmt = stmt.limit(limit)
return self.session.execute(stmt).scalars()
return self.session.execute(stmt).scalars() # type:ignore[return-value]
9 changes: 5 additions & 4 deletions lib/galaxy/managers/folders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
or_,
select,
)
from sqlalchemy.orm import aliased
from sqlalchemy.orm.exc import (
from sqlalchemy.exc import (
MultipleResultsFound,
NoResultFound,
)
from sqlalchemy.orm import aliased

from galaxy import (
model,
Expand Down Expand Up @@ -505,7 +505,7 @@ def _get_contained_datasets_statement(
stmt = stmt.where(
or_(
func.lower(ldda.name).contains(search_text, autoescape=True),
func.lower(ldda.message).contains(search_text, autoescape=True),
func.lower(ldda.message).contains(search_text, autoescape=True), # type:ignore[attr-defined]
)
)
sort_column = LDDA_SORT_COLUMN_MAP[payload.order_by](ldda, associated_dataset)
Expand Down Expand Up @@ -536,7 +536,7 @@ def _filter_by_include_deleted(

def build_folder_path(
self, sa_session: galaxy_scoped_session, folder: model.LibraryFolder
) -> List[Tuple[str, str]]:
) -> List[Tuple[int, Optional[str]]]:
"""
Returns the folder path from root to the given folder.

Expand All @@ -546,6 +546,7 @@ def build_folder_path(
path_to_root = [(current_folder.id, current_folder.name)]
while current_folder.parent_id is not None:
parent_folder = sa_session.get(LibraryFolder, current_folder.parent_id)
assert parent_folder
current_folder = parent_folder
path_to_root.insert(0, (current_folder.id, current_folder.name))
return path_to_root
Expand Down
9 changes: 6 additions & 3 deletions lib/galaxy/managers/forms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from sqlalchemy import select
from sqlalchemy.orm import exc as sqlalchemy_exceptions
from sqlalchemy.exc import (
MultipleResultsFound,
NoResultFound,
)

from galaxy.exceptions import (
InconsistentDatabase,
Expand Down Expand Up @@ -59,9 +62,9 @@ def get(self, trans: ProvidesUserContext, form_id: int) -> FormDefinitionCurrent
try:
stmt = select(FormDefinitionCurrent).where(FormDefinitionCurrent.id == form_id)
form = self.session().execute(stmt).scalar_one()
except sqlalchemy_exceptions.MultipleResultsFound:
except MultipleResultsFound:
raise InconsistentDatabase("Multiple forms found with the same id.")
except sqlalchemy_exceptions.NoResultFound:
except NoResultFound:
raise RequestParameterInvalidException("No accessible form found with the id provided.")
except Exception as e:
raise InternalServerError(f"Error loading from the database.{unicodify(e)}")
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/managers/genomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _create_genome_filter(model_class=None):
if self.database_connection.startswith("postgres"):
column = text("convert_from(metadata, 'UTF8')::json ->> 'dbkey'")
else:
column = func.json_extract(model_class.table.c._metadata, "$.dbkey")
column = func.json_extract(model_class.table.c._metadata, "$.dbkey") # type:ignore[assignment]
lower_val = val.lower() # Ignore case
# dbkey can either be "hg38" or '["hg38"]', so we need to check both
if op == "eq":
Expand Down
Loading
Loading