Skip to content

Commit

Permalink
make external permissioned user creation case insensitive
Browse files Browse the repository at this point in the history
  • Loading branch information
hagen-danswer committed Nov 21, 2024
1 parent deee237 commit a63bfcd
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions backend/danswer/db/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,6 @@ def list_users(
return db_session.scalars(stmt).unique().all()


def get_users_by_emails(
db_session: Session, emails: list[str]
) -> tuple[list[User], list[str]]:
# Use distinct to avoid duplicates
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
found_users_emails = [user.email for user in found_users]
missing_user_emails = [email for email in emails if email not in found_users_emails]
return found_users, missing_user_emails


def get_user_by_email(email: str, db_session: Session) -> User | None:
user = (
db_session.query(User)
Expand All @@ -128,7 +117,7 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
return db_session.query(User).filter(User.id == user_id).first() # type: ignore


def _generate_non_web_slack_user(email: str) -> User:
def _generate_slack_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
Expand All @@ -149,13 +138,29 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
db_session.commit()
return user

user = _generate_non_web_slack_user(email=email)
user = _generate_slack_user(email=email)
db_session.add(user)
db_session.commit()
return user


def _generate_non_web_permissioned_user(email: str) -> User:
def _get_users_by_emails(
db_session: Session, lower_emails: list[str]
) -> tuple[list[User], list[str]]:
stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list

# Extract found emails and convert to lowercase to avoid case sensitivity issues
found_users_emails = [user.email.lower() for user in found_users]

# Separate emails for users that were not found
missing_user_emails = [
email for email in lower_emails if email not in found_users_emails
]
return found_users, missing_user_emails


def _generate_ext_permissioned_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
Expand All @@ -169,12 +174,12 @@ def _generate_non_web_permissioned_user(email: str) -> User:
def batch_add_ext_perm_user_if_not_exists(
db_session: Session, emails: list[str]
) -> list[User]:
emails = [email.lower() for email in emails]
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
lower_emails = [email.lower() for email in emails]
found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails)

new_users: list[User] = []
for email in missing_user_emails:
new_users.append(_generate_non_web_permissioned_user(email=email))
for email in missing_lower_emails:
new_users.append(_generate_ext_permissioned_user(email=email))

db_session.add_all(new_users)
db_session.commit()
Expand Down

0 comments on commit a63bfcd

Please sign in to comment.