Skip to content

Commit

Permalink
Merge pull request #17530 from jdavcs/23.2_social_core
Browse files Browse the repository at this point in the history
[23.2] Fix social_core methods
  • Loading branch information
mvdbeek authored Feb 23, 2024
2 parents 036e469 + e50c615 commit 51ee56e
Showing 1 changed file with 95 additions and 19 deletions.
114 changes: 95 additions & 19 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9526,6 +9526,11 @@ def save(self):

@classmethod
def store(cls, server_url, association):
"""
Create an Association instance
(Required by social_core.storage.AssociationMixin interface)
"""

def get_or_create():
stmt = select(PSAAssociation).filter_by(server_url=server_url, handle=association.handle).limit(1)
assoc = cls.sa_session.scalars(stmt).first()
Expand All @@ -9542,11 +9547,19 @@ def get_or_create():

@classmethod
def get(cls, *args, **kwargs):
"""
Get an Association instance
(Required by social_core.storage.AssociationMixin interface)
"""
stmt = select(PSAAssociation).filter_by(*args, **kwargs)
return cls.sa_session.scalars(stmt).all()

@classmethod
def remove(cls, ids_to_delete):
"""
Remove an Association instance
(Required by social_core.storage.AssociationMixin interface)
"""
stmt = (
delete(PSAAssociation)
.where(PSAAssociation.id.in_(ids_to_delete))
Expand Down Expand Up @@ -9577,6 +9590,9 @@ def save(self):

@classmethod
def get_code(cls, code):
"""
(Required by social_core.storage.CodeMixin interface)
"""
stmt = select(PSACode).where(PSACode.code == code).limit(1)
return cls.sa_session.scalars(stmt).first()

Expand Down Expand Up @@ -9604,6 +9620,10 @@ def save(self):

@classmethod
def use(cls, server_url, timestamp, salt):
"""
Create a Nonce instance
(Required by social_core.storage.NonceMixin interface)
"""
try:
stmt = select(PSANonce).where(server_url=server_url, timestamp=timestamp, salt=salt).limit(1)
return cls.sa_session.scalars(stmt).first()
Expand Down Expand Up @@ -9640,11 +9660,17 @@ def save(self):

@classmethod
def load(cls, token):
"""
(Required by social_core.storage.PartialMixin interface)
"""
stmt = select(PSAPartial).where(PSAPartial.token == token).limit(1)
return cls.sa_session.scalars(stmt).first()

@classmethod
def destroy(cls, token):
"""
(Required by social_core.storage.PartialMixin interface)
"""
partial = cls.load(token)
if partial:
session = cls.sa_session
Expand Down Expand Up @@ -9695,30 +9721,63 @@ def save(self):
with transaction(self.sa_session):
self.sa_session.commit()

@classmethod
def username_max_length(cls):
# Note: This is the maximum field length set for the username column of the galaxy_user table.
# A better alternative is to retrieve this number from the table, instead of this const value.
return 255

@classmethod
def changed(cls, user):
"""
The given user instance is ready to be saved
(Required by social_core.storage.UserMixin interface)
"""
cls.sa_session.add(user)
with transaction(cls.sa_session):
cls.sa_session.commit()

@classmethod
def get_username(cls, user):
"""
Return the username for given user
(Required by social_core.storage.UserMixin interface)
"""
return getattr(user, "username", None)

@classmethod
def user_model(cls):
"""
Return the user model
(Required by social_core.storage.UserMixin interface)
"""
return User

@classmethod
def username_max_length(cls):
"""
Return the max length for username
(Required by social_core.storage.UserMixin interface)
"""
# Note: This is the maximum field length set for the username column of the galaxy_user table.
# A better alternative is to retrieve this number from the table, instead of this const value.
return 255

@classmethod
def user_exists(cls, *args, **kwargs):
"""
Return True/False if a User instance exists with the given arguments.
Arguments are directly passed to filter() manager method.
(Required by social_core.storage.UserMixin interface)
"""
stmt_user = select(User).filter_by(*args, **kwargs)
stmt_count = select(func.count()).select_from(stmt_user)
return cls.sa_session.scalar(stmt_count) > 0

@classmethod
def create_user(cls, *args, **kwargs):
"""
This is used by PSA authnz, do not use directly.
Prefer using the user manager.
(Required by social_core.storage.UserMixin interface)
"""
instance = User(*args, **kwargs)
if cls.email_exists(instance.email):
model = cls.user_model()
instance = model(*args, **kwargs)
if cls.get_users_by_email(instance.email):
raise Exception(f"User with this email '{instance.email}' already exists.")
instance.set_random_password()
cls.sa_session.add(instance)
Expand All @@ -9728,33 +9787,50 @@ def create_user(cls, *args, **kwargs):

@classmethod
def get_user(cls, pk):
return UserAuthnzToken.sa_session.get(User, pk)
"""
Return user instance for given id
(Required by social_core.storage.UserMixin interface)
"""
return cls.sa_session.get(User, pk)

@classmethod
def email_exists(cls, email):
stmt = select(User).where(func.lower(User.email) == email.lower()).limit(1)
return bool(cls.sa_session.scalars(stmt).first())
def get_users_by_email(cls, email):
"""
Return users instances for given email address
(Required by social_core.storage.UserMixin interface)
"""
stmt = select(User).where(func.lower(User.email) == email.lower())
return cls.sa_session.scalars(stmt).all()

@classmethod
def get_social_auth(cls, provider, uid):
"""
Return UserSocialAuth for given provider and uid
(Required by social_core.storage.UserMixin interface)
"""
uid = str(uid)
try:
stmt = select(UserAuthnzToken).filter_by(provider=provider, uid=uid).limit(1)
return cls.sa_session.scalars(stmt).first()
except IndexError:
return None
stmt = select(cls).filter_by(provider=provider, uid=uid).limit(1)
return cls.sa_session.scalars(stmt).first()

@classmethod
def get_social_auth_for_user(cls, user, provider=None, id=None):
stmt = select(UserAuthnzToken).filter_by(user_id=user.id)
"""
Return all the UserSocialAuth instances for given user
(Required by social_core.storage.UserMixin interface)
"""
stmt = select(cls).filter_by(user_id=user.id)
if provider:
stmt = stmt.filter_by(provider=provider)
if id:
stmt = stmt.filter_by(id=id)
return cls.sa_session.scalars(stmt)
return cls.sa_session.scalars(stmt).all()

@classmethod
def create_social_auth(cls, user, uid, provider):
"""
Create a UserSocialAuth instance for given user
(Required by social_core.storage.UserMixin interface)
"""
uid = str(uid)
instance = cls(user=user, uid=uid, provider=provider)
cls.sa_session.add(instance)
Expand Down

0 comments on commit 51ee56e

Please sign in to comment.