feat: added pseudo chat config; added wrapper for Bedrock client #36

Merged · 6 commits · Dec 5, 2023
aidial_adapter_bedrock/bedrock.py (66 additions, 0 deletions)
@@ -0,0 +1,66 @@
import json
from typing import Any, AsyncIterator

import boto3
from botocore.eventstream import EventStream
from botocore.response import StreamingBody

from aidial_adapter_bedrock.utils.concurrency import (
    make_async,
    to_async_iterator,
)
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


class Bedrock:
    client: Any

    def __init__(self, client: Any):
        self.client = client

    @classmethod
    async def acreate(cls, region: str) -> "Bedrock":
        client = await make_async(
            lambda: boto3.Session().client("bedrock-runtime", region)
        )
        return cls(client)

    def _create_invoke_params(self, model: str, body: dict) -> dict:
        return {
            "modelId": model,
            "body": json.dumps(body),
            "accept": "application/json",
            "contentType": "application/json",
        }

    async def ainvoke_non_streaming(self, model: str, args: dict) -> dict:
        params = self._create_invoke_params(model, args)
        response = await make_async(lambda: self.client.invoke_model(**params))

        log.debug(f"response: {response}")

        body: StreamingBody = response["body"]
        body_dict = json.loads(await make_async(lambda: body.read()))

        log.debug(f"response['body']: {body_dict}")

        return body_dict

    async def ainvoke_streaming(
        self, model: str, args: dict
    ) -> AsyncIterator[dict]:
        params = self._create_invoke_params(model, args)
        response = await make_async(
            lambda: self.client.invoke_model_with_response_stream(**params)
        )

        log.debug(f"response: {response}")

        body: EventStream = response["body"]

        async for event in to_async_iterator(iter(body)):
            chunk = event.get("chunk")
            if chunk:
                chunk_dict = json.loads(chunk.get("bytes").decode())
                log.debug(f"chunk: {chunk_dict}")
                yield chunk_dict
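For reviewers, a minimal usage sketch of the wrapper above. It is not part of the PR: the region, model ID, and request body are illustrative placeholders (the body shape depends on the target Bedrock model), and it assumes valid AWS credentials are configured in the environment.

import asyncio

from aidial_adapter_bedrock.bedrock import Bedrock


async def main() -> None:
    # Hypothetical values; any Bedrock region and model available to the account would do.
    client = await Bedrock.acreate(region="us-east-1")
    body = {
        "prompt": "\n\nHuman: Hello!\n\nAssistant:",
        "max_tokens_to_sample": 100,
    }

    # Non-streaming: the whole JSON response body is returned at once.
    result = await client.ainvoke_non_streaming("anthropic.claude-v2", body)
    print(result)

    # Streaming: parsed JSON chunks are yielded as Bedrock emits them.
    async for chunk in client.ainvoke_streaming("anthropic.claude-v2", body):
        print(chunk)


asyncio.run(main())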
aidial_adapter_bedrock/chat_completion.py (3 additions, 3 deletions)
@@ -19,10 +19,10 @@ def __init__(self, region: str):

    @dial_exception_decorator
    async def chat_completion(self, request: Request, response: Response):
-        model_params = ModelParameters.create(request)
+        params = ModelParameters.create(request)
        model = await get_bedrock_adapter(
            region=self.region,
-            model_id=request.deployment_id,
+            model=request.deployment_id,
        )

        async def generate_response(
@@ -32,7 +32,7 @@ async def generate_response(
        ) -> None:
            with response.create_choice() as choice:
                consumer = ChoiceConsumer(choice)
-                await model.achat(consumer, model_params, request.messages)
+                await model.achat(consumer, params, request.messages)
                usage.accumulate(consumer.usage)
                discarded_messages_set.add(consumer.discarded_messages)

aidial_adapter_bedrock/dial_api/request.py (12 additions, 20 deletions)
@@ -1,4 +1,4 @@
-from typing import List, Mapping, Optional, Union
+from typing import List, Optional

from aidial_sdk.chat_completion import Request
from pydantic import BaseModel
@@ -8,38 +8,30 @@ class ModelParameters(BaseModel):
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
+    stop: List[str] = []
    max_tokens: Optional[int] = None
    max_prompt_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
-    logit_bias: Optional[Mapping[int, float]] = None
    stream: bool = False

    @classmethod
    def create(cls, request: Request) -> "ModelParameters":
+        stop: List[str] = []
+        if request.stop is not None:
+            stop = (
+                [request.stop]
+                if isinstance(request.stop, str)
+                else request.stop
+            )
+
        return cls(
            temperature=request.temperature,
            top_p=request.top_p,
            n=request.n,
-            stop=request.stop,
+            stop=stop,
            max_tokens=request.max_tokens,
            max_prompt_tokens=request.max_prompt_tokens,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
-            logit_bias=request.logit_bias,
            stream=request.stream,
        )

    def add_stop_sequences(self, stop: List[str]) -> "ModelParameters":
        if len(stop) == 0:
            return self

-        self_stop: List[str] = []
-        if self.stop is not None:
-            if isinstance(self.stop, str):
-                self_stop = [self.stop]
-            else:
-                self_stop = self.stop
-
-        return self.copy(update={"stop": [*self_stop, *stop]})
+        return self.copy(update={"stop": [*self.stop, *stop]})
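A small sketch of the resulting behavior (not part of the diff): `stop` is normalized to a list when the parameters are created, so merging extra stop sequences becomes plain list concatenation. The literal values are illustrative only.

from aidial_adapter_bedrock.dial_api.request import ModelParameters

# Constructed directly for illustration; in the adapter this object comes from
# ModelParameters.create(request), which turns a string `stop` into a list.
params = ModelParameters(temperature=0.0, stop=["\n\nHuman:"])

merged = params.add_stop_sequences(["===="])
print(merged.stop)  # ['\n\nHuman:', '====']

# With nothing to add, the original object is returned unchanged.
assert params.add_stop_sequences([]) is params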
aidial_adapter_bedrock/llm/chat_emulation/pseudo_chat.py (90 additions, 35 deletions)
@@ -1,5 +1,6 @@
-from enum import Enum
-from typing import Callable, List, Optional, Set, Tuple
+from typing import Callable, List, Optional, Set, Tuple, TypedDict
+
+from pydantic import BaseModel

from aidial_adapter_bedrock.llm.chat_emulation.history import (
    FormattedMessage,
@@ -11,39 +12,72 @@
    AIMessage,
    BaseMessage,
    HumanMessage,
-    SystemMessage,
)
from aidial_adapter_bedrock.utils.list import exclude_indices


-class RolePrompt(str, Enum):
-    HUMAN = "\n\nHuman:"
-    ASSISTANT = "\n\nAssistant:"

-STOP_SEQUENCES: List[str] = [RolePrompt.HUMAN]

-PRELUDE = f"""
+class RoleMapping(TypedDict):
+    system: str
+    human: str
+    ai: str


+class PseudoChatConf(BaseModel):
+    prelude_template: Optional[str]
+    annotate_first: bool
+    add_invitation: bool
+    mapping: RoleMapping
+    separator: str

+    @property
+    def prelude(self) -> Optional[str]:
+        if self.prelude_template is None:
+            return None
+        return self.prelude_template.format(**self.mapping)

+    @property
+    def stop_sequences(self) -> List[str]:
+        return [self.separator + self.mapping["human"]]

+    def format_message(self, message: BaseMessage, is_first: bool) -> str:
+        role = self.mapping.get(message.type)

+        if role is None:
+            raise ValueError(f"Unknown message type: {message.type}")

+        role_prefix = role + " "
+        if is_first and not self.annotate_first:
+            role_prefix = ""

+        separator = self.separator
+        if is_first:
+            separator = ""

+        return (separator + role_prefix + message.content.lstrip()).rstrip()


+default_conf = PseudoChatConf(
+    prelude_template="""
You are a helpful assistant participating in a dialog with a user.
-The messages from the user start with "{RolePrompt.HUMAN.strip()}".
-The messages from you start with "{RolePrompt.ASSISTANT.strip()}".
+The messages from the user start with "{ai}".
+The messages from you start with "{human}".
Reply to the last message from the user taking into account the preceding dialog history.
====================
-""".strip()


-def _format_message(message: BaseMessage) -> str:
-    role = (
-        RolePrompt.HUMAN
-        if isinstance(message, (SystemMessage, HumanMessage))
-        else RolePrompt.ASSISTANT
-    )
-    return (role + " " + message.content.lstrip()).rstrip()
+""".strip(),
+    annotate_first=True,
+    add_invitation=True,
+    mapping=RoleMapping(
+        system="Human:",
+        human="Human:",
+        ai="Assistant:",
+    ),
+    separator="\n\n",
+)


class PseudoChatHistory(History):
    stop_sequences: List[str]
+    pseudo_history_conf: PseudoChatConf

    def trim(
        self, count_tokens: Callable[[str], int], max_prompt_tokens: int
@@ -77,7 +111,8 @@ def trim(
                        self.messages, discarded_messages
                    )
                    if message.source_message
-                ]
+                ],
+                conf=self.pseudo_history_conf,
            ),
            len(discarded_messages),
        )
@@ -88,12 +123,15 @@
            source_messages_count - discarded_messages_count == 1
            and isinstance(last_source_message, HumanMessage)
        ):
-            history = PseudoChatHistory.create([last_source_message])
+            history = PseudoChatHistory.create(
+                messages=[last_source_message],
+                conf=self.pseudo_history_conf,
+            )
            prompt_tokens = sum(
                count_tokens(message.text) for message in history.messages
            )
            if prompt_tokens <= max_prompt_tokens:
-                return history, len(discarded_messages)
+                return history, discarded_messages_count

        raise ValidationError(
            f"The token size of system messages and the last user message ({prompt_tokens}) exceeds"
@@ -105,32 +143,49 @@ def trim(
        )

    @classmethod
-    def create(cls, messages: List[BaseMessage]) -> "PseudoChatHistory":
+    def create(
+        cls, messages: List[BaseMessage], conf: PseudoChatConf
+    ) -> "PseudoChatHistory":
        if len(messages) == 1 and isinstance(messages[0], HumanMessage):
-            single_message = messages[0]
+            message = messages[0]
            return cls(
                messages=[
                    FormattedMessage(
-                        text=single_message.content,
-                        source_message=single_message,
+                        text=message.content,
+                        source_message=message,
                    )
                ],
                stop_sequences=[],
+                pseudo_history_conf=conf,
            )

-        formatted_messages = [FormattedMessage(text=PRELUDE)]
+        formatted_messages: List[FormattedMessage] = []

+        if conf.prelude is not None:
+            formatted_messages.append(FormattedMessage(text=conf.prelude))

-        for index, message in enumerate(messages):
+        for idx, message in enumerate(messages):
            formatted_messages.append(
                FormattedMessage(
-                    text=_format_message(message),
+                    text=conf.format_message(
+                        message, len(formatted_messages) == 0
+                    ),
                    source_message=message,
-                    is_important=is_important_message(messages, index),
+                    is_important=is_important_message(messages, idx),
                )
            )

-        formatted_messages.append(
-            FormattedMessage(text=_format_message(AIMessage(content="")))
-        )
+        if conf.add_invitation:
+            formatted_messages.append(
+                FormattedMessage(
+                    text=conf.format_message(
+                        AIMessage(content=""), len(formatted_messages) == 0
+                    )
+                )
+            )

-        return cls(messages=formatted_messages, stop_sequences=STOP_SEQUENCES)
+        return cls(
+            messages=formatted_messages,
+            stop_sequences=conf.stop_sequences,
+            pseudo_history_conf=conf,
+        )
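To illustrate the new configuration (not part of the diff), the sketch below formats a short dialog with `default_conf`, roughly what `PseudoChatHistory.create` does after prepending `conf.prelude`. It assumes `HumanMessage` and `AIMessage` accept a `content` keyword, as in the `AIMessage(content="")` call above, and that `message.type` resolves to the `RoleMapping` keys `human`/`ai`.

from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import default_conf
from aidial_adapter_bedrock.llm.message import AIMessage, HumanMessage

dialog = [
    HumanMessage(content="What is the capital of France?"),
    AIMessage(content="Paris."),
    HumanMessage(content="And of Germany?"),
]

parts = [
    default_conf.format_message(message, i == 0)
    for i, message in enumerate(dialog)
]
# An empty assistant message at the end invites the model to reply.
parts.append(default_conf.format_message(AIMessage(content=""), False))

print("".join(parts))
# Human: What is the capital of France?
#
# Assistant: Paris.
#
# Human: And of Germany?
#
# Assistant:

# The stop sequence is derived from the separator and the "human" role prefix.
print(default_conf.stop_sequences)  # ['\n\nHuman:']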