From b1ef5c108b7eaad6fbb4088893711a8bb85eecbe Mon Sep 17 00:00:00 2001 From: John Davis Date: Thu, 22 Feb 2024 12:09:48 -0500 Subject: [PATCH 1/2] Fix UserAuthnzToken's social_core interface methods --- lib/galaxy/model/__init__.py | 88 ++++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 19 deletions(-) diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py index f2bac5f47e96..68b84b722397 100644 --- a/lib/galaxy/model/__init__.py +++ b/lib/galaxy/model/__init__.py @@ -9695,30 +9695,63 @@ def save(self): with transaction(self.sa_session): self.sa_session.commit() - @classmethod - def username_max_length(cls): - # Note: This is the maximum field length set for the username column of the galaxy_user table. - # A better alternative is to retrieve this number from the table, instead of this const value. - return 255 - @classmethod def changed(cls, user): + """ + The given user instance is ready to be saved + (Required by social_core.storage.UserMixin interface) + """ cls.sa_session.add(user) with transaction(cls.sa_session): cls.sa_session.commit() @classmethod def get_username(cls, user): + """ + Return the username for given user + (Required by social_core.storage.UserMixin interface) + """ return getattr(user, "username", None) + @classmethod + def user_model(cls): + """ + Return the user model + (Required by social_core.storage.UserMixin interface) + """ + return User + + @classmethod + def username_max_length(cls): + """ + Return the max length for username + (Required by social_core.storage.UserMixin interface) + """ + # Note: This is the maximum field length set for the username column of the galaxy_user table. + # A better alternative is to retrieve this number from the table, instead of this const value. + return 255 + + @classmethod + def user_exists(cls, *args, **kwargs): + """ + Return True/False if a User instance exists with the given arguments. + Arguments are directly passed to filter() manager method. + (Required by social_core.storage.UserMixin interface) + """ + stmt_user = select(User).filter_by(*args, **kwargs) + stmt_count = select(func.count()).select_from(stmt_user) + return cls.sa_session.scalar(stmt_count) > 0 + @classmethod def create_user(cls, *args, **kwargs): """ This is used by PSA authnz, do not use directly. Prefer using the user manager. + (Required by social_core.storage.UserMixin interface) """ - instance = User(*args, **kwargs) - if cls.email_exists(instance.email): + model = cls.user_model() + instance = model(*args, **kwargs) + if cls.get_users_by_email(instance.email): raise Exception(f"User with this email '{instance.email}' already exists.") instance.set_random_password() cls.sa_session.add(instance) @@ -9728,33 +9761,50 @@ def create_user(cls, *args, **kwargs): @classmethod def get_user(cls, pk): - return UserAuthnzToken.sa_session.get(User, pk) + """ + Return user instance for given id + (Required by social_core.storage.UserMixin interface) + """ + return cls.sa_session.get(User, pk) @classmethod - def email_exists(cls, email): - stmt = select(User).where(func.lower(User.email) == email.lower()).limit(1) - return bool(cls.sa_session.scalars(stmt).first()) + def get_users_by_email(cls, email): + """ + Return users instances for given email address + (Required by social_core.storage.UserMixin interface) + """ + stmt = select(User).where(func.lower(User.email) == email.lower()) + return cls.sa_session.scalars(stmt).all() @classmethod def get_social_auth(cls, provider, uid): + """ + Return UserSocialAuth for given provider and uid + (Required by social_core.storage.UserMixin interface) + """ uid = str(uid) - try: - stmt = select(UserAuthnzToken).filter_by(provider=provider, uid=uid).limit(1) - return cls.sa_session.scalars(stmt).first() - except IndexError: - return None + stmt = select(cls).filter_by(provider=provider, uid=uid).limit(1) + return cls.sa_session.scalars(stmt).first() @classmethod def get_social_auth_for_user(cls, user, provider=None, id=None): - stmt = select(UserAuthnzToken).filter_by(user_id=user.id) + """ + Return all the UserSocialAuth instances for given user + (Required by social_core.storage.UserMixin interface) + """ + stmt = select(cls).filter_by(user_id=user.id) if provider: stmt = stmt.filter_by(provider=provider) if id: stmt = stmt.filter_by(id=id) - return cls.sa_session.scalars(stmt) + return cls.sa_session.scalars(stmt).all() @classmethod def create_social_auth(cls, user, uid, provider): + """ + Create a UserSocialAuth instance for given user + (Required by social_core.storage.UserMixin interface) + """ uid = str(uid) instance = cls(user=user, uid=uid, provider=provider) cls.sa_session.add(instance) From e50c6154e0a1b4250577ce5528181fdfac8467cf Mon Sep 17 00:00:00 2001 From: John Davis Date: Thu, 22 Feb 2024 12:18:23 -0500 Subject: [PATCH 2/2] Add comments to ensure methods don't get dropped These methods are used by social_core --- lib/galaxy/model/__init__.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py index 68b84b722397..8df7bc4cdb91 100644 --- a/lib/galaxy/model/__init__.py +++ b/lib/galaxy/model/__init__.py @@ -9526,6 +9526,11 @@ def save(self): @classmethod def store(cls, server_url, association): + """ + Create an Association instance + (Required by social_core.storage.AssociationMixin interface) + """ + def get_or_create(): stmt = select(PSAAssociation).filter_by(server_url=server_url, handle=association.handle).limit(1) assoc = cls.sa_session.scalars(stmt).first() @@ -9542,11 +9547,19 @@ def get_or_create(): @classmethod def get(cls, *args, **kwargs): + """ + Get an Association instance + (Required by social_core.storage.AssociationMixin interface) + """ stmt = select(PSAAssociation).filter_by(*args, **kwargs) return cls.sa_session.scalars(stmt).all() @classmethod def remove(cls, ids_to_delete): + """ + Remove an Association instance + (Required by social_core.storage.AssociationMixin interface) + """ stmt = ( delete(PSAAssociation) .where(PSAAssociation.id.in_(ids_to_delete)) @@ -9577,6 +9590,9 @@ def save(self): @classmethod def get_code(cls, code): + """ + (Required by social_core.storage.CodeMixin interface) + """ stmt = select(PSACode).where(PSACode.code == code).limit(1) return cls.sa_session.scalars(stmt).first() @@ -9604,6 +9620,10 @@ def save(self): @classmethod def use(cls, server_url, timestamp, salt): + """ + Create a Nonce instance + (Required by social_core.storage.NonceMixin interface) + """ try: stmt = select(PSANonce).where(server_url=server_url, timestamp=timestamp, salt=salt).limit(1) return cls.sa_session.scalars(stmt).first() @@ -9640,11 +9660,17 @@ def save(self): @classmethod def load(cls, token): + """ + (Required by social_core.storage.PartialMixin interface) + """ stmt = select(PSAPartial).where(PSAPartial.token == token).limit(1) return cls.sa_session.scalars(stmt).first() @classmethod def destroy(cls, token): + """ + (Required by social_core.storage.PartialMixin interface) + """ partial = cls.load(token) if partial: session = cls.sa_session