Commit

adds config option for preferred message role for attachments; adds inline option for attachment content in Explorer (#245)
bkrabach authored Nov 14, 2024
1 parent 29e5e78 commit a41da7f
Showing 5 changed files with 115 additions and 39 deletions.
91 changes: 79 additions & 12 deletions assistants/explorer-assistant/assistant/chat.py
@@ -8,15 +8,21 @@
import logging
import re
import time
from typing import Any, Awaitable, Callable
from typing import Any, Awaitable, Callable, Sequence

import deepmerge
import openai_client
from assistant_extensions.artifacts import ArtifactsExtension
from assistant_extensions.artifacts._model import ArtifactsConfigModel
from assistant_extensions.attachments import AttachmentsExtension
from content_safety.evaluators import CombinedContentSafetyEvaluator
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam, ParsedChatCompletion
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
ParsedChatCompletion,
)
from semantic_workbench_api_model.workbench_model import (
AssistantStateEvent,
ConversationEvent,
@@ -227,19 +233,19 @@ async def respond_to_conversation(
}
]

# calculate the token count for the messages so far
token_count = openai_client.num_tokens_from_messages(
messages=completion_messages, model=config.request_config.openai_model
)

# generate the attachment messages from the attachment agent
attachment_messages = await attachments_extension.get_completion_messages_for_attachments(
context, config=config.extensions_config.attachments
context,
config=config.extensions_config.attachments,
)
token_count += openai_client.num_tokens_from_messages(
messages=attachment_messages, model=config.request_config.openai_model
)

# add the attachment messages to the completion messages
completion_messages.extend(attachment_messages)

# calculate the token count for the messages so far
token_count = sum([
openai_client.num_tokens_from_message(model=config.request_config.openai_model, message=completion_message)
for completion_message in completion_messages
])

# calculate the total available tokens for the response generation
available_tokens = config.request_config.max_tokens - config.request_config.response_tokens
@@ -252,6 +258,14 @@
token_limit=available_tokens - token_count,
)

# add the attachment messages to the completion messages, either inline or as separate messages
if config.use_inline_attachments:
# inject the attachment messages inline into the history messages
history_messages = _inject_attachments_inline(history_messages, attachment_messages)
else:
# add the attachment messages to the completion messages before the history messages
completion_messages.extend(attachment_messages)

# add the history messages to the completion messages
completion_messages.extend(history_messages)

@@ -606,6 +620,59 @@ async def _get_history_messages(
return history


def _inject_attachments_inline(
history_messages: list[ChatCompletionMessageParam],
attachment_messages: Sequence[ChatCompletionSystemMessageParam | ChatCompletionUserMessageParam],
) -> list[ChatCompletionMessageParam]:
"""
Inject the attachment messages inline into the history messages.
"""

# iterate over the history messages and for every message that contains an attachment,
# find the related attachment message and replace the attachment message with the inline attachment content
for index, history_message in enumerate(history_messages):
# skip history messages whose content is missing or not a string
content = history_message.get("content")
if not content or not isinstance(content, str):
continue

# get the attachment filenames string from the history message content
attachment_filenames_string = re.findall(r"Attachment\(s\): (.+)", content)

# if the history message does not contain an attachment filenames string, skip
if not attachment_filenames_string:
continue

# split the attachment filenames string into a list of attachment filenames
attachment_filenames = [filename.strip() for filename in attachment_filenames_string[0].split(",")]

# initialize a list to store the replacement messages
replacement_messages = []

# iterate over the attachment filenames and find the related attachment message
for attachment_filename in attachment_filenames:
# find the related attachment message
attachment_message = next(
(
attachment_message
for attachment_message in attachment_messages
if f"<ATTACHMENT><FILENAME>{attachment_filename}</FILENAME>"
in str(attachment_message.get("content"))
),
None,
)

if attachment_message:
# replace the attachment message with the inline attachment content
replacement_messages.append(attachment_message)

# if there are replacement messages, replace the history message with the replacement messages
if len(replacement_messages) > 0:
history_messages[index : index + 1] = replacement_messages

return history_messages


def _get_response_duration_message(response_duration: float) -> str:
"""
Generate a display friendly message for the response duration, to be added to the footer items.
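Note: a minimal sketch of how the new _inject_attachments_inline path rewrites the history, assuming the history formatter appends an "Attachment(s): ..." line the way the regex above expects; the filename and content are illustrative only.

# hypothetical history entry whose conversation message carried an uploaded file
history_messages = [
    {"role": "user", "content": "Please summarize this.\n\nAttachment(s): report.pdf"},
]

# hypothetical attachment message produced by get_completion_messages_for_attachments
attachment_messages = [
    {
        "role": "system",
        "content": "<ATTACHMENT><FILENAME>report.pdf</FILENAME><CONTENT>...</CONTENT></ATTACHMENT>",
    },
]

# with use_inline_attachments enabled, the marker-bearing history message is replaced
# in place by the matching attachment message(s)
history_messages = _inject_attachments_inline(history_messages, attachment_messages)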
8 changes: 8 additions & 0 deletions assistants/explorer-assistant/assistant/config.py
@@ -207,6 +207,14 @@ class AssistantConfigModel(BaseModel):
UISchema(widget="radio"),
] = CombinedContentSafetyEvaluatorConfig()

use_inline_attachments: Annotated[
bool,
Field(
title="Use Inline Attachments",
description="Experimental: place attachment content where it was uploaded in the conversation history.",
),
] = False

extensions_config: Annotated[
ExtensionsConfigModel,
Field(
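For context, a minimal sketch of opting into the new flag, assuming the remaining AssistantConfigModel fields keep their defaults; respond_to_conversation then takes the inline-injection branch shown in the chat.py diff above instead of prepending the attachment messages.

config = AssistantConfigModel()
config.use_inline_attachments = True  # experimental: place attachment content where it was uploaded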
@@ -834,9 +834,7 @@ async def _gc_attachment_check(
)

# update artifact
filenames = await self._attachments_extension.get_attachment_filenames(
context, config=config.agents_config.attachment_agent
)
filenames = await self._attachments_extension.get_attachment_filenames(context)
filenames_str = ", ".join(filenames)

artifact_dict = guided_conversation.get_artifact_dict()
@@ -1028,9 +1026,7 @@ async def _gc_outline_feedback(
case _:
conversation_status_str = "user_returned"

filenames = await self._attachments_extension.get_attachment_filenames(
context, config=config.agents_config.attachment_agent
)
filenames = await self._attachments_extension.get_attachment_filenames(context)
filenames_str = ", ".join(filenames)

outline_str: str = ""
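The guided-conversation call sites now pass only the context; per the updated extension signature in the attachments diff below, the optional filename filters remain available. A hypothetical example (the excluded filename is illustrative):

# all attachment filenames for the conversation
filenames = await self._attachments_extension.get_attachment_filenames(context)

# or, excluding files that have already been handled
filenames = await self._attachments_extension.get_attachment_filenames(
    context, exclude_filenames=["draft_outline.md"]
)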
@@ -140,9 +140,6 @@ async def get_completion_messages_for_attachments(
A list of messages for the chat completion.
"""

if not config.include_in_response_generation:
return []

# get attachments, filtered by include_filenames and exclude_filenames
attachments = await _get_attachments(
context,
@@ -155,10 +152,7 @@
return []

messages: list[chat.ChatCompletionSystemMessageParam | chat.ChatCompletionUserMessageParam] = [
{
"role": "system",
"content": config.context_description,
}
_create_message(config, config.context_description)
]

# process each attachment
@@ -189,24 +183,16 @@

error_element = f"<{error_tag}>{attachment.error}</{error_tag}>" if attachment.error else ""
content = f"<{attachment_tag}><{filename_tag}>{attachment.filename}</{filename_tag}>{error_element}<{content_tag}>{attachment.content}</{content_tag}></{attachment_tag}>"
messages.append({
# role of system seems to get better results in the chat completion
"role": "system",
"content": content,
})
messages.append(_create_message(config, content))

return messages

async def get_attachment_filenames(
self,
context: ConversationContext,
config: AttachmentsConfigModel,
include_filenames: list[str] | None = None,
exclude_filenames: list[str] = [],
) -> list[str]:
if not config.include_in_response_generation:
return []

# get attachments, filtered by include_filenames and exclude_filenames
attachments = await _get_attachments(
context,
@@ -225,6 +211,24 @@ async def get_attachment_filenames(
return filenames


def _create_message(
config: AttachmentsConfigModel, content: str
) -> chat.ChatCompletionSystemMessageParam | chat.ChatCompletionUserMessageParam:
match config.preferred_message_role:
case "system":
return {
"role": "system",
"content": content,
}
case "user":
return {
"role": "user",
"content": content,
}
case _:
raise ValueError(f"unsupported preferred_message_role: {config.preferred_message_role}")


async def _get_attachments(
context: ConversationContext,
error_handler: AttachmentProcessingErrorHandler,
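A quick sketch of what the new _create_message helper yields, assuming the other AttachmentsConfigModel fields keep their defaults; the attachment markup is illustrative only.

config = AttachmentsConfigModel(preferred_message_role="user")
message = _create_message(config, "<ATTACHMENT>...</ATTACHMENT>")
# message == {"role": "user", "content": "<ATTACHMENT>...</ATTACHMENT>"}
# with the default "system" role, the same call returns a system-role message instead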
@@ -1,5 +1,5 @@
import datetime
from typing import Annotated, Any
from typing import Annotated, Any, Literal

from pydantic import BaseModel, Field
from semantic_workbench_assistant.config import UISchema
@@ -17,14 +17,15 @@ class AttachmentsConfigModel(BaseModel):
" provided for why they were included."
)

include_in_response_generation: Annotated[
bool,
preferred_message_role: Annotated[
Literal["system", "user"],
Field(
description=(
"Whether to include the contents of attachments in the context for general response generation."
"The preferred role for attachment messages. Early testing suggests that the system role works best,"
" but you can experiment with the other roles. Image attachments will always use the user role."
),
),
] = True
] = "system"


class Attachment(BaseModel):
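Because preferred_message_role is typed as Literal["system", "user"], pydantic rejects any other value at validation time, so the ValueError branch in _create_message is effectively a safety net. A quick illustration, again assuming the remaining fields have defaults:

AttachmentsConfigModel(preferred_message_role="user")       # accepted
AttachmentsConfigModel(preferred_message_role="assistant")  # raises pydantic.ValidationError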
