Skip to content

Commit

Permalink
Fix UserAuthnzToken's social_core interface methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Feb 22, 2024
1 parent 036e469 commit b1ef5c1
Showing 1 changed file with 69 additions and 19 deletions.
88 changes: 69 additions & 19 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9695,30 +9695,63 @@ def save(self):
with transaction(self.sa_session):
self.sa_session.commit()

@classmethod
def username_max_length(cls):
# Note: This is the maximum field length set for the username column of the galaxy_user table.
# A better alternative is to retrieve this number from the table, instead of this const value.
return 255

@classmethod
def changed(cls, user):
"""
The given user instance is ready to be saved
(Required by social_core.storage.UserMixin interface)
"""
cls.sa_session.add(user)
with transaction(cls.sa_session):
cls.sa_session.commit()

@classmethod
def get_username(cls, user):
"""
Return the username for given user
(Required by social_core.storage.UserMixin interface)
"""
return getattr(user, "username", None)

@classmethod
def user_model(cls):
"""
Return the user model
(Required by social_core.storage.UserMixin interface)
"""
return User

@classmethod
def username_max_length(cls):
"""
Return the max length for username
(Required by social_core.storage.UserMixin interface)
"""
# Note: This is the maximum field length set for the username column of the galaxy_user table.
# A better alternative is to retrieve this number from the table, instead of this const value.
return 255

@classmethod
def user_exists(cls, *args, **kwargs):
"""
Return True/False if a User instance exists with the given arguments.
Arguments are directly passed to filter() manager method.
(Required by social_core.storage.UserMixin interface)
"""
stmt_user = select(User).filter_by(*args, **kwargs)
stmt_count = select(func.count()).select_from(stmt_user)
return cls.sa_session.scalar(stmt_count) > 0

@classmethod
def create_user(cls, *args, **kwargs):
"""
This is used by PSA authnz, do not use directly.
Prefer using the user manager.
(Required by social_core.storage.UserMixin interface)
"""
instance = User(*args, **kwargs)
if cls.email_exists(instance.email):
model = cls.user_model()
instance = model(*args, **kwargs)
if cls.get_users_by_email(instance.email):
raise Exception(f"User with this email '{instance.email}' already exists.")
instance.set_random_password()
cls.sa_session.add(instance)
Expand All @@ -9728,33 +9761,50 @@ def create_user(cls, *args, **kwargs):

@classmethod
def get_user(cls, pk):
return UserAuthnzToken.sa_session.get(User, pk)
"""
Return user instance for given id
(Required by social_core.storage.UserMixin interface)
"""
return cls.sa_session.get(User, pk)

@classmethod
def email_exists(cls, email):
stmt = select(User).where(func.lower(User.email) == email.lower()).limit(1)
return bool(cls.sa_session.scalars(stmt).first())
def get_users_by_email(cls, email):
"""
Return users instances for given email address
(Required by social_core.storage.UserMixin interface)
"""
stmt = select(User).where(func.lower(User.email) == email.lower())
return cls.sa_session.scalars(stmt).all()

@classmethod
def get_social_auth(cls, provider, uid):
"""
Return UserSocialAuth for given provider and uid
(Required by social_core.storage.UserMixin interface)
"""
uid = str(uid)
try:
stmt = select(UserAuthnzToken).filter_by(provider=provider, uid=uid).limit(1)
return cls.sa_session.scalars(stmt).first()
except IndexError:
return None
stmt = select(cls).filter_by(provider=provider, uid=uid).limit(1)
return cls.sa_session.scalars(stmt).first()

@classmethod
def get_social_auth_for_user(cls, user, provider=None, id=None):
stmt = select(UserAuthnzToken).filter_by(user_id=user.id)
"""
Return all the UserSocialAuth instances for given user
(Required by social_core.storage.UserMixin interface)
"""
stmt = select(cls).filter_by(user_id=user.id)
if provider:
stmt = stmt.filter_by(provider=provider)
if id:
stmt = stmt.filter_by(id=id)
return cls.sa_session.scalars(stmt)
return cls.sa_session.scalars(stmt).all()

@classmethod
def create_social_auth(cls, user, uid, provider):
"""
Create a UserSocialAuth instance for given user
(Required by social_core.storage.UserMixin interface)
"""
uid = str(uid)
instance = cls(user=user, uid=uid, provider=provider)
cls.sa_session.add(instance)
Expand Down

0 comments on commit b1ef5c1

Please sign in to comment.