Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/danswer-ai/danswer into fea…
Browse files Browse the repository at this point in the history
…ture/oauth
  • Loading branch information
rkuo-danswer committed Dec 6, 2024
2 parents 02d6cfa + 18e6388 commit 764c953
Show file tree
Hide file tree
Showing 145 changed files with 2,346 additions and 2,312 deletions.
36 changes: 36 additions & 0 deletions backend/alembic/versions/9f696734098f_combine_search_and_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Combine Search and Chat
Revision ID: 9f696734098f
Revises: a8c2065484e6
Create Date: 2024-11-27 15:32:19.694972
"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "9f696734098f"
down_revision = "a8c2065484e6"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.alter_column("chat_session", "description", nullable=True)
op.drop_column("chat_session", "one_shot")
op.drop_column("slack_channel_config", "response_type")


def downgrade() -> None:
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
op.alter_column("chat_session", "description", nullable=False)
op.add_column(
"chat_session",
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
)
op.add_column(
"slack_channel_config",
sa.Column(
"response_type", sa.String(), nullable=False, server_default="citations"
),
)
5 changes: 5 additions & 0 deletions backend/danswer/access/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ class ExternalAccess:

@dataclass(frozen=True)
class DocExternalAccess:
"""
This is just a class to wrap the external access and the document ID
together. It's used for syncing document permissions to Redis.
"""

external_access: ExternalAccess
# The document ID
doc_id: str
Expand Down
6 changes: 1 addition & 5 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.server.utils import BasicAuthenticationError
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
Expand All @@ -99,11 +100,6 @@
logger = setup_logger()


class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)


def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
Expand Down
5 changes: 4 additions & 1 deletion backend/danswer/background/celery/apps/primary.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT


logger = setup_logger()

celery_app = Celery(__name__)
Expand Down Expand Up @@ -117,9 +116,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.

# set thread_local=False since we don't control what thread the periodic task might
# reacquire the lock with
lock: RedisLock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
thread_local=False,
)

logger.info("Primary worker lock: Acquire starting.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
Expand Down Expand Up @@ -262,7 +263,12 @@ def connector_permission_sync_generator_task(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
self.app, lock, document_external_accesses, source_type
celery_app=self.app,
lock=lock,
new_permissions=document_external_accesses,
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
if tasks_generated is None:
return None
Expand Down Expand Up @@ -298,6 +304,8 @@ def update_external_document_permissions_task(
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
credential_id: int,
) -> bool:
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
Expand All @@ -306,18 +314,28 @@ def update_external_document_permissions_task(
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
# Then we build the update requests to update vespa
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
)
upsert_document_external_perms(
# Then we upsert the document's external permissions in postgres
created_new_doc = upsert_document_external_perms(
db_session=db_session,
doc_id=doc_id,
external_access=external_access,
source_type=DocumentSource(source_string),
)

if created_new_doc:
# If a new document was created, we associate it with the cc_pair
upsert_document_by_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
document_ids=[doc_id],
)

logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import (
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)

logger = setup_logger()

Expand Down Expand Up @@ -107,6 +111,22 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)

# We only want to sync one cc_pair per source type in
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session, source, only_sync=True
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
cc_pairs = [
cc_pair
for cc_pair in cc_pairs
if cc_pair.id != cc_pair_to_remove.id
]

for cc_pair in cc_pairs:
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
Expand Down
3 changes: 3 additions & 0 deletions backend/danswer/background/celery/tasks/indexing/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,9 +789,12 @@ def connector_indexing_task(
)
break

# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
thread_local=False,
)

acquired = lock.acquire(blocking=False)
Expand Down
6 changes: 5 additions & 1 deletion backend/danswer/background/celery/tasks/pruning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from danswer.background.celery.apps.app_base import task_logger
Expand Down Expand Up @@ -239,9 +240,12 @@ def connector_pruning_generator_task(

r = get_redis_client(tenant_id=tenant_id)

lock = r.lock(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)

acquired = lock.acquire(blocking=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,27 @@
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall

from danswer.chat.llm_response_handler import LLMResponseHandlerManager
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import default_build_system_message
from danswer.llm.answering.prompts.build import default_build_user_message
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.chat.prompt_builder.build import default_build_system_message
from danswer.chat.prompt_builder.build import default_build_user_message
from danswer.chat.prompt_builder.build import LLMCall
from danswer.chat.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
from danswer.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
QuotesResponseHandler,
)
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
from danswer.chat.stream_processing.utils import map_document_id_order
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
Expand Down Expand Up @@ -214,18 +208,23 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:

search_result = SearchTool.get_search_result(current_llm_call) or []

answer_handler: AnswerResponseHandler
if self.answer_style_config.citation_config:
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
)
elif self.answer_style_config.quotes_config:
answer_handler = QuotesResponseHandler(
context_docs=search_result,
)
else:
raise ValueError("No answer style config provided")
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
# if self.answer_style_config.citation_config:
# answer_handler = CitationResponseHandler(
# context_docs=search_result,
# doc_id_to_rank_map=map_document_id_order(search_result),
# )
# elif self.answer_style_config.quotes_config:
# answer_handler = QuotesResponseHandler(
# context_docs=search_result,
# )
# else:
# raise ValueError("No answer style config provided")
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
)

response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
Expand Down
Loading

0 comments on commit 764c953

Please sign in to comment.