Skip to content

Commit

Permalink
feat: supported content parts for Claude3
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Oct 2, 2024
1 parent da465f4 commit 7d91ef2
Show file tree
Hide file tree
Showing 20 changed files with 444 additions and 153 deletions.
2 changes: 1 addition & 1 deletion aidial_adapter_bedrock/dial_api/embedding_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
cast,
)

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.chat_completion import Attachment
from aidial_sdk.embeddings.request import EmbeddingsRequest

from aidial_adapter_bedrock.llm.errors import ValidationError
Expand Down
2 changes: 1 addition & 1 deletion aidial_adapter_bedrock/embedding/amazon/titan_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from typing import AsyncIterator, List, Self

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.chat_completion import Attachment
from aidial_sdk.embeddings import Response as EmbeddingsResponse
from aidial_sdk.embeddings import Usage
from aidial_sdk.embeddings.request import EmbeddingsRequest
Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/llm/chat_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _format_message(self, message: BaseMessage, idx: int) -> str:
else:
cue_prefix = cue + " "

return (cue_prefix + message.content.lstrip()).rstrip()
return (cue_prefix + message.text_content.lstrip()).rstrip()

def get_ai_cue(self) -> Optional[str]:
return self.cues["ai"]
Expand All @@ -69,7 +69,7 @@ def display(self, messages: List[BaseMessage]) -> Tuple[str, List[str]]:
and len(messages) == 1
and isinstance(messages[0], HumanRegularMessage)
):
return messages[0].content, []
return messages[0].text_content, []

ret: List[str] = []

Expand Down
6 changes: 4 additions & 2 deletions aidial_adapter_bedrock/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
)
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
from aidial_adapter_bedrock.utils.not_implemented import not_implemented
from aidial_adapter_bedrock.utils.request import (
get_message_content_text_content,
)


def _is_empty_system_message(msg: Message) -> bool:
return (
msg.role == Role.SYSTEM
and msg.content is not None
and msg.content.strip() == ""
and get_message_content_text_content(msg.content).strip() == ""
)


Expand Down
275 changes: 198 additions & 77 deletions aidial_adapter_bedrock/llm/message.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,248 @@
from typing import List, Optional, Union

from aidial_sdk.chat_completion import (
CustomContent,
FunctionCall,
Message,
Role,
ToolCall,
)
from abc import ABC, abstractmethod
from typing import List, Optional, Self, Union

from aidial_sdk.chat_completion import CustomContent, FunctionCall
from aidial_sdk.chat_completion import Message as DialMessage
from aidial_sdk.chat_completion import MessageContentPart, Role, ToolCall
from pydantic import BaseModel

from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.utils.request import (
get_message_content_text_content,
is_plain_text_content,
is_text_content_parts,
)


class SystemMessage(BaseModel):
content: str
class MessageABC(ABC, BaseModel):
@abstractmethod
def to_message(self) -> DialMessage: ...

def to_message(self) -> Message:
return Message(role=Role.SYSTEM, content=self.content)
@classmethod
@abstractmethod
def from_message(cls, message: DialMessage) -> Self | None: ...


class HumanRegularMessage(BaseModel):
content: str
class BaseMessageABC(MessageABC):
@property
@abstractmethod
def text_content(self) -> str: ...


class SystemMessage(BaseMessageABC):
content: str | List[MessageContentPart]

def to_message(self) -> DialMessage:
return DialMessage(
role=Role.SYSTEM,
content=self.content,
)

@classmethod
def from_message(cls, message: DialMessage) -> Self | None:
if message.role != Role.SYSTEM:
return None

content = message.content

if content is None:
raise ValidationError("System message is expected to have content")

if isinstance(content, str):
return cls(content=content)

if is_text_content_parts(content):
return cls(content=content) # type: ignore
else:
raise ValidationError(
"Unexpected non-text content parts in system message"
)

@property
def text_content(self) -> str:
return get_message_content_text_content(self.content)


class HumanRegularMessage(BaseMessageABC):
"""MM stands for multi-modal"""

content: str | List[MessageContentPart]
custom_content: Optional[CustomContent] = None

def to_message(self) -> Message:
return Message(
def to_message(self) -> DialMessage:
return DialMessage(
role=Role.USER,
content=self.content,
custom_content=self.custom_content,
)

@classmethod
def from_message(cls, message: DialMessage) -> Self | None:
if message.role != Role.USER:
return None

content = message.content
if content is None:
raise ValidationError(
"User message is expected to have content field"
)

return cls(content=content, custom_content=message.custom_content)

@property
def text_content(self) -> str:
return get_message_content_text_content(self.content)

class HumanToolResultMessage(BaseModel):

class HumanToolResultMessage(MessageABC):
id: str
content: str

def to_message(self) -> Message:
return Message(
def to_message(self) -> DialMessage:
return DialMessage(
role=Role.TOOL,
tool_call_id=self.id,
content=self.content,
)

@classmethod
def from_message(cls, message: DialMessage) -> Self | None:
if message.role != Role.TOOL:
return None

if not is_plain_text_content(message.content):
raise ValidationError(
"The tool message shouldn't contain content parts"
)

if message.content is None or message.tool_call_id is None:
raise ValidationError(
"The tool message is expected to have content and tool_call_id fields"
)

return cls(id=message.tool_call_id, content=message.content)


class HumanFunctionResultMessage(BaseModel):
class HumanFunctionResultMessage(MessageABC):
name: str
content: str

def to_message(self) -> Message:
return Message(
def to_message(self) -> DialMessage:
return DialMessage(
role=Role.FUNCTION,
name=self.name,
content=self.content,
)

@classmethod
def from_message(cls, message: DialMessage) -> Self | None:
if message.role != Role.FUNCTION:
return None

class AIRegularMessage(BaseModel):
if not is_plain_text_content(message.content):
raise ValidationError(
"The function message shouldn't contain content parts"
)

if message.content is None or message.name is None:
raise ValidationError(
"The function message is expected to have content and name fields"
)

return cls(name=message.name, content=message.content)


class AIRegularMessage(BaseMessageABC):
content: str
custom_content: Optional[CustomContent] = None

def to_message(self) -> Message:
return Message(
def to_message(self) -> DialMessage:
return DialMessage(
role=Role.ASSISTANT,
content=self.content,
custom_content=self.custom_content,
)

@classmethod
def from_message(cls, message: DialMessage) -> Self | None:
if message.role != Role.ASSISTANT:
return None

if message.function_call is not None or message.tool_calls is not None:
return None

if not is_plain_text_content(message.content):
raise ValidationError(
"The assistant message shouldn't contain content parts"
)

if message.content is None:
raise ValidationError(
"The assistant message is expected to have content"
)

return cls(
content=message.content, custom_content=message.custom_content
)

@property
def text_content(self) -> str:
return self.content


class AIToolCallMessage(BaseModel):
class AIToolCallMessage(MessageABC):
calls: List[ToolCall]
content: Optional[str] = None

def to_message(self) -> Message:
return Message(
def to_message(self) -> DialMessage:
return DialMessage(
role=Role.ASSISTANT,
content=self.content,
tool_calls=self.calls,
)

@classmethod
def from_message(cls, message: DialMessage) -> Self | None:
if message.role != Role.ASSISTANT:
return None

class AIFunctionCallMessage(BaseModel):
if message.tool_calls is None or message.function_call is not None:
return None

if not is_plain_text_content(message.content):
raise ValidationError(
"The assistant message with tool calls shouldn't contain content parts"
)

return cls(calls=message.tool_calls, content=message.content)


class AIFunctionCallMessage(MessageABC):
call: FunctionCall
content: Optional[str] = None

def to_message(self) -> Message:
return Message(
def to_message(self) -> DialMessage:
return DialMessage(
role=Role.ASSISTANT,
content=self.content,
function_call=self.call,
)

@classmethod
def from_message(cls, message: DialMessage) -> Self | None:
if message.role != Role.ASSISTANT:
return None

if message.function_call is None or message.tool_calls is not None:
return None

if not is_plain_text_content(message.content):
raise ValidationError(
"The assistant message with function call shouldn't contain content parts"
)

return cls(call=message.function_call, content=message.content)


BaseMessage = Union[SystemMessage, HumanRegularMessage, AIRegularMessage]

Expand All @@ -101,51 +254,19 @@ def to_message(self) -> Message:
]


def _parse_assistant_message(
content: Optional[str],
function_call: Optional[FunctionCall],
tool_calls: Optional[List[ToolCall]],
custom_content: Optional[CustomContent],
) -> BaseMessage | ToolMessage:
if content is not None and function_call is None and tool_calls is None:
return AIRegularMessage(content=content, custom_content=custom_content)

if function_call is not None and tool_calls is None:
return AIFunctionCallMessage(call=function_call, content=content)

if function_call is None and tool_calls is not None:
return AIToolCallMessage(calls=tool_calls, content=content)
def parse_dial_message(msg: DialMessage) -> BaseMessage | ToolMessage:

raise ValidationError("Unknown type of assistant message")
message = (
SystemMessage.from_message(msg)
or HumanRegularMessage.from_message(msg)
or HumanToolResultMessage.from_message(msg)
or HumanFunctionResultMessage.from_message(msg)
or AIRegularMessage.from_message(msg)
or AIToolCallMessage.from_message(msg)
or AIFunctionCallMessage.from_message(msg)
)

if message is None:
raise ValidationError("Unknown message type or invalid message")

def parse_dial_message(msg: Message) -> BaseMessage | ToolMessage:
match msg:
case Message(role=Role.SYSTEM, content=content) if content is not None:
return SystemMessage(content=content)
case Message(
role=Role.USER, content=content, custom_content=custom_content
) if content is not None:
return HumanRegularMessage(
content=content, custom_content=custom_content
)
case Message(
role=Role.ASSISTANT,
content=content,
function_call=function_call,
tool_calls=tool_calls,
custom_content=custom_content,
):
return _parse_assistant_message(
content, function_call, tool_calls, custom_content
)
case Message(
role=Role.FUNCTION, name=name, content=content
) if content is not None and name is not None:
return HumanFunctionResultMessage(name=name, content=content)
case Message(
role=Role.TOOL, tool_call_id=id, content=content
) if content is not None and id is not None:
return HumanToolResultMessage(id=id, content=content)
case _:
raise ValidationError("Unknown message type or invalid message")
return message
Loading

0 comments on commit 7d91ef2

Please sign in to comment.