Skip to content

Commit

Permalink
User Filter Polish
Browse files Browse the repository at this point in the history
  • Loading branch information
hagen-danswer committed Jan 15, 2025
1 parent eb70699 commit c126ee7
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 39 deletions.
47 changes: 28 additions & 19 deletions backend/onyx/db/connector_credential_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
Expand All @@ -32,8 +34,13 @@


def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")

# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
Expand Down Expand Up @@ -93,7 +100,7 @@ def _add_user_filters(

def get_connector_credential_pairs_for_user(
db_session: Session,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
Expand All @@ -104,6 +111,7 @@ def get_connector_credential_pairs_for_user(
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))

stmt = _add_user_filters(stmt, user, get_editable)

if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

Expand All @@ -114,12 +122,11 @@ def get_connector_credential_pairs(
db_session: Session,
ids: list[int] | None = None,
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair).distinct()

if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

return list(db_session.scalars(stmt).all())
return get_connector_credential_pairs_for_user(
db_session=db_session,
user=SYSTEM_USER,
ids=ids,
)


def add_deletion_failure_message(
Expand Down Expand Up @@ -154,7 +161,7 @@ def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,
credential_id: int,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
Expand All @@ -170,17 +177,18 @@ def get_connector_credential_pair(
connector_id: int,
credential_id: int,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
return get_connector_credential_pair_for_user(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
user=SYSTEM_USER,
)


def get_connector_credential_pair_from_id_for_user(
cc_pair_id: int,
db_session: Session,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
Expand All @@ -194,10 +202,11 @@ def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
return get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=SYSTEM_USER,
)


def get_last_successful_attempt_time(
Expand Down
10 changes: 10 additions & 0 deletions backend/onyx/db/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,11 @@
from typing import Final


SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"


class SystemUser:
"""Represents the system user for internal operations"""


SYSTEM_USER: Final = SystemUser()
32 changes: 21 additions & 11 deletions backend/onyx/db/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
Expand Down Expand Up @@ -42,11 +44,17 @@

def _add_user_filters(
stmt: Select,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> Select:
"""Attaches filters to the statement to ensure that the user can only
access the appropriate credentials"""

if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")

if user is None:
if not DISABLE_AUTH:
raise ValueError("Anonymous users are not allowed to access credentials")
Expand Down Expand Up @@ -151,7 +159,7 @@ def fetch_credentials_for_user(

def fetch_credential_by_id_for_user(
credential_id: int,
user: User | None,
user: User | None | SystemUser,
db_session: Session,
get_editable: bool = True,
) -> Credential | None:
Expand All @@ -171,16 +179,16 @@ def fetch_credential_by_id(
db_session: Session,
credential_id: int,
) -> Credential | None:
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
result = db_session.execute(stmt)
credential = result.scalar_one_or_none()
return credential
return fetch_credential_by_id_for_user(
credential_id=credential_id,
user=SYSTEM_USER,
db_session=db_session,
)


def fetch_credentials_by_source_for_user(
db_session: Session,
user: User | None,
user: User | None | SystemUser,
document_source: DocumentSource | None = None,
get_editable: bool = True,
) -> list[Credential]:
Expand All @@ -194,9 +202,11 @@ def fetch_credentials_by_source(
db_session: Session,
document_source: DocumentSource | None = None,
) -> list[Credential]:
base_query = select(Credential).where(Credential.source == document_source)
credentials = db_session.execute(base_query).scalars().all()
return list(credentials)
return fetch_credentials_by_source_for_user(
db_session=db_session,
user=SYSTEM_USER,
document_source=document_source,
)


def swap_credentials_connector(
Expand Down
20 changes: 18 additions & 2 deletions backend/onyx/db/document_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
Expand All @@ -35,8 +37,13 @@


def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")

# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
Expand Down Expand Up @@ -487,14 +494,23 @@ def fetch_document_sets(

def fetch_all_document_sets_for_user(
db_session: Session,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> Sequence[DocumentSetDBModel]:
stmt = select(DocumentSetDBModel).distinct()
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
return db_session.scalars(stmt).all()


def fetch_all_document_sets(
db_session: Session,
) -> Sequence[DocumentSetDBModel]:
return fetch_all_document_sets_for_user(
db_session=db_session,
user=SYSTEM_USER,
)


def fetch_documents_for_document_set_paginated(
document_set_id: int,
db_session: Session,
Expand Down
35 changes: 28 additions & 7 deletions backend/onyx/db/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import DocumentSet
from onyx.db.models import Persona
Expand All @@ -44,8 +46,13 @@


def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")

# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
Expand Down Expand Up @@ -111,7 +118,10 @@ def _add_user_filters(


def fetch_persona_by_id_for_user(
db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
db_session: Session,
persona_id: int,
user: User | None | SystemUser,
get_editable: bool = True,
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
Expand All @@ -124,6 +134,17 @@ def fetch_persona_by_id_for_user(
return persona


def fetch_persona_by_id(
db_session: Session,
persona_id: int,
) -> Persona:
return fetch_persona_by_id_for_user(
db_session=db_session,
persona_id=persona_id,
user=SYSTEM_USER,
)


def get_best_persona_id_for_user(
db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
Expand Down Expand Up @@ -285,7 +306,7 @@ def get_prompts(

def get_personas_for_user(
# if user is `None` assume the user is an admin or auth is disabled
user: User | None,
user: User | None | SystemUser,
db_session: Session,
get_editable: bool = True,
include_default: bool = True,
Expand Down Expand Up @@ -315,10 +336,10 @@ def get_personas_for_user(


def get_personas(db_session: Session) -> Sequence[Persona]:
stmt = select(Persona).distinct()
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
stmt = stmt.where(Persona.deleted.is_(False))
return db_session.execute(stmt).unique().scalars().all()
return get_personas_for_user(
user=SYSTEM_USER,
db_session=db_session,
)


def mark_persona_as_deleted(
Expand Down

0 comments on commit c126ee7

Please sign in to comment.