Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User Filter Polish #3681

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions backend/onyx/db/connector_credential_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.enums import AccessType
Expand All @@ -33,8 +35,13 @@


def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")

# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
Expand Down Expand Up @@ -94,7 +101,7 @@ def _add_user_filters(

def get_connector_credential_pairs_for_user(
db_session: Session,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
Expand All @@ -105,6 +112,7 @@ def get_connector_credential_pairs_for_user(
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))

stmt = _add_user_filters(stmt, user, get_editable)

if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

Expand All @@ -115,12 +123,11 @@ def get_connector_credential_pairs(
db_session: Session,
ids: list[int] | None = None,
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair).distinct()

if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

return list(db_session.scalars(stmt).all())
return get_connector_credential_pairs_for_user(
db_session=db_session,
user=SYSTEM_USER,
ids=ids,
)


def add_deletion_failure_message(
Expand Down Expand Up @@ -155,7 +162,7 @@ def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,
credential_id: int,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
Expand All @@ -171,17 +178,18 @@ def get_connector_credential_pair(
connector_id: int,
credential_id: int,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
return get_connector_credential_pair_for_user(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
user=SYSTEM_USER,
)


def get_connector_credential_pair_from_id_for_user(
cc_pair_id: int,
db_session: Session,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
Expand All @@ -195,10 +203,11 @@ def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
return get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=SYSTEM_USER,
)


def get_last_successful_attempt_time(
Expand Down
10 changes: 10 additions & 0 deletions backend/onyx/db/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,11 @@
from typing import Final


SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"


class SystemUser:
"""Represents the system user for internal operations"""


SYSTEM_USER: Final = SystemUser()
32 changes: 21 additions & 11 deletions backend/onyx/db/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
Expand Down Expand Up @@ -42,11 +44,17 @@

def _add_user_filters(
stmt: Select,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> Select:
"""Attaches filters to the statement to ensure that the user can only
access the appropriate credentials"""

if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")

if user is None:
if not DISABLE_AUTH:
raise ValueError("Anonymous users are not allowed to access credentials")
Expand Down Expand Up @@ -151,7 +159,7 @@ def fetch_credentials_for_user(

def fetch_credential_by_id_for_user(
credential_id: int,
user: User | None,
user: User | None | SystemUser,
db_session: Session,
get_editable: bool = True,
) -> Credential | None:
Expand All @@ -171,16 +179,16 @@ def fetch_credential_by_id(
db_session: Session,
credential_id: int,
) -> Credential | None:
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
result = db_session.execute(stmt)
credential = result.scalar_one_or_none()
return credential
return fetch_credential_by_id_for_user(
credential_id=credential_id,
user=SYSTEM_USER,
db_session=db_session,
)


def fetch_credentials_by_source_for_user(
db_session: Session,
user: User | None,
user: User | None | SystemUser,
document_source: DocumentSource | None = None,
get_editable: bool = True,
) -> list[Credential]:
Expand All @@ -194,9 +202,11 @@ def fetch_credentials_by_source(
db_session: Session,
document_source: DocumentSource | None = None,
) -> list[Credential]:
base_query = select(Credential).where(Credential.source == document_source)
credentials = db_session.execute(base_query).scalars().all()
return list(credentials)
return fetch_credentials_by_source_for_user(
db_session=db_session,
user=SYSTEM_USER,
document_source=document_source,
)


def swap_credentials_connector(
Expand Down
20 changes: 18 additions & 2 deletions backend/onyx/db/document_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
Expand All @@ -35,8 +37,13 @@


def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")

# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
Expand Down Expand Up @@ -487,14 +494,23 @@ def fetch_document_sets(

def fetch_all_document_sets_for_user(
db_session: Session,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> Sequence[DocumentSetDBModel]:
stmt = select(DocumentSetDBModel).distinct()
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
return db_session.scalars(stmt).all()


def fetch_all_document_sets(
db_session: Session,
) -> Sequence[DocumentSetDBModel]:
return fetch_all_document_sets_for_user(
db_session=db_session,
user=SYSTEM_USER,
)


def fetch_documents_for_document_set_paginated(
document_set_id: int,
db_session: Session,
Expand Down
35 changes: 28 additions & 7 deletions backend/onyx/db/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import DocumentSet
from onyx.db.models import Persona
Expand All @@ -44,8 +46,13 @@


def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")

# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
Expand Down Expand Up @@ -111,7 +118,10 @@ def _add_user_filters(


def fetch_persona_by_id_for_user(
db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
db_session: Session,
persona_id: int,
user: User | None | SystemUser,
get_editable: bool = True,
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
Expand All @@ -124,6 +134,17 @@ def fetch_persona_by_id_for_user(
return persona


def fetch_persona_by_id(
db_session: Session,
persona_id: int,
) -> Persona:
return fetch_persona_by_id_for_user(
db_session=db_session,
persona_id=persona_id,
user=SYSTEM_USER,
)


def get_best_persona_id_for_user(
db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
Expand Down Expand Up @@ -285,7 +306,7 @@ def get_prompts(

def get_personas_for_user(
# if user is `None` assume the user is an admin or auth is disabled
user: User | None,
user: User | None | SystemUser,
db_session: Session,
get_editable: bool = True,
include_default: bool = True,
Expand Down Expand Up @@ -315,10 +336,10 @@ def get_personas_for_user(


def get_personas(db_session: Session) -> Sequence[Persona]:
stmt = select(Persona).distinct()
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
stmt = stmt.where(Persona.deleted.is_(False))
return db_session.execute(stmt).unique().scalars().all()
return get_personas_for_user(
user=SYSTEM_USER,
db_session=db_session,
)


def mark_persona_as_deleted(
Expand Down
Loading