feat: added pseudo chat config, added wrapper for Bedrock client
adubovik committed Dec 4, 2023
1 parent 590b054 commit 53e75d9
Showing 16 changed files with 439 additions and 379 deletions.
64 changes: 64 additions & 0 deletions aidial_adapter_bedrock/bedrock.py
@@ -0,0 +1,64 @@
import asyncio
import json
from typing import Any, AsyncIterator

import boto3
from botocore.eventstream import EventStream
from botocore.response import StreamingBody

from aidial_adapter_bedrock.utils.concurrency import make_async
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


class Bedrock:
client: Any

def __init__(self, client: Any):
self.client = client

@classmethod
async def acreate(cls, region: str) -> "Bedrock":
client = await make_async(
lambda _: boto3.Session().client("bedrock-runtime", region), ()
)
return cls(client)

def _create_invoke_params(self, model: str, body: dict) -> dict:
return {
"modelId": model,
"body": json.dumps(body),
"accept": "application/json",
"contentType": "application/json",
}

async def ainvoke_non_streaming(self, model: str, args: dict) -> dict:
params = self._create_invoke_params(model, args)
response = await make_async(
lambda _: self.client.invoke_model(**params), ()
)

log.debug(f"response: {response}")

body: StreamingBody = response["body"]
body_dict = json.loads(body.read())

log.debug(f"response['body']: {body_dict}")

return body_dict

async def ainvoke_streaming(
self, model: str, args: dict
) -> AsyncIterator[dict]:
params = self._create_invoke_params(model, args)
response = self.client.invoke_model_with_response_stream(**params)

log.debug(f"response: {response}")

body: EventStream = response["body"]
for event in body:
chunk = event.get("chunk")
if chunk:
chunk_dict = json.loads(chunk.get("bytes").decode())
log.debug(f"chunk: {chunk_dict}")
yield chunk_dict
await asyncio.sleep(1e-6)
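
For context, a minimal usage sketch of the new wrapper (not part of this commit): the model id and the Anthropic-style request body below are illustrative assumptions, and AWS credentials are expected to be configured in the environment.

import asyncio

from aidial_adapter_bedrock.bedrock import Bedrock


async def main() -> None:
    client = await Bedrock.acreate("us-east-1")

    # Hypothetical Anthropic Claude payload; the exact body schema is model-specific.
    body = {
        "prompt": "\n\nHuman: Hello!\n\nAssistant:",
        "max_tokens_to_sample": 100,
    }

    # Non-streaming: returns the parsed JSON response body as a dict.
    print(await client.ainvoke_non_streaming("anthropic.claude-v2", body))

    # Streaming: yields one parsed JSON chunk per response-stream event.
    async for chunk in client.ainvoke_streaming("anthropic.claude-v2", body):
        print(chunk)


asyncio.run(main())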
6 changes: 3 additions & 3 deletions aidial_adapter_bedrock/chat_completion.py
@@ -19,10 +19,10 @@ def __init__(self, region: str):

@dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
model_params = ModelParameters.create(request)
params = ModelParameters.create(request)
model = await get_bedrock_adapter(
region=self.region,
model_id=request.deployment_id,
model=request.deployment_id,
)

async def generate_response(
@@ -32,7 +32,7 @@ async def generate_response(
) -> None:
with response.create_choice() as choice:
consumer = ChoiceConsumer(choice)
await model.achat(consumer, model_params, request.messages)
await model.achat(consumer, params, request.messages)
usage.accumulate(consumer.usage)
discarded_messages_set.add(consumer.discarded_messages)

32 changes: 12 additions & 20 deletions aidial_adapter_bedrock/dial_api/request.py
@@ -1,4 +1,4 @@
from typing import List, Mapping, Optional, Union
from typing import List, Optional

from aidial_sdk.chat_completion import Request
from pydantic import BaseModel
@@ -8,38 +8,30 @@ class ModelParameters(BaseModel):
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop: List[str] = []
max_tokens: Optional[int] = None
max_prompt_tokens: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[Mapping[int, float]] = None
stream: bool = False

@classmethod
def create(cls, request: Request) -> "ModelParameters":
stop: List[str] = []
if request.stop is not None:
stop = (
[request.stop]
if isinstance(request.stop, str)
else request.stop
)

return cls(
temperature=request.temperature,
top_p=request.top_p,
n=request.n,
stop=request.stop,
stop=stop,
max_tokens=request.max_tokens,
max_prompt_tokens=request.max_prompt_tokens,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
logit_bias=request.logit_bias,
stream=request.stream,
)

def add_stop_sequences(self, stop: List[str]) -> "ModelParameters":
if len(stop) == 0:
return self

self_stop: List[str] = []
if self.stop is not None:
if isinstance(self.stop, str):
self_stop = [self.stop]
else:
self_stop = self.stop

return self.copy(update={"stop": [*self_stop, *stop]})
return self.copy(update={"stop": [*self.stop, *stop]})
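
A short sketch of the reworked stop handling (the stop strings are illustrative): create now normalizes request.stop to a plain list, so add_stop_sequences can concatenate without the previous str-vs-list branching.

from aidial_adapter_bedrock.dial_api.request import ModelParameters

# stop is always a list now, never Optional[Union[str, List[str]]].
params = ModelParameters(stop=["\n\nHuman:"])

# Appending model-specific stop sequences returns an updated copy.
merged = params.add_stop_sequences(["</answer>"])
assert merged.stop == ["\n\nHuman:", "</answer>"]

# An empty argument is a no-op and returns the same instance.
assert params.add_stop_sequences([]) is params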
125 changes: 90 additions & 35 deletions aidial_adapter_bedrock/llm/chat_emulation/pseudo_chat.py
@@ -1,5 +1,6 @@
from enum import Enum
from typing import Callable, List, Optional, Set, Tuple
from typing import Callable, List, Optional, Set, Tuple, TypedDict

from pydantic import BaseModel

from aidial_adapter_bedrock.llm.chat_emulation.history import (
FormattedMessage,
@@ -11,39 +12,72 @@
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
from aidial_adapter_bedrock.utils.list import exclude_indices


class RolePrompt(str, Enum):
HUMAN = "\n\nHuman:"
ASSISTANT = "\n\nAssistant:"
class RoleMapping(TypedDict):
system: str
human: str
ai: str


class PseudoChatConf(BaseModel):
prelude_template: Optional[str]
annotate_first: bool
add_invitation: bool
mapping: RoleMapping
separator: str

@property
def prelude(self) -> Optional[str]:
if self.prelude_template is None:
return None
return self.prelude_template.format(**self.mapping)

@property
def stop_sequences(self) -> List[str]:
return [self.separator + self.mapping["human"]]

def format_message(self, message: BaseMessage, is_first: bool) -> str:
role = self.mapping.get(message.type)

if role is None:
raise ValueError(f"Unknown message type: {message.type}")

STOP_SEQUENCES: List[str] = [RolePrompt.HUMAN]
role_prefix = role + " "
if is_first and not self.annotate_first:
role_prefix = ""

separator = self.separator
if is_first:
separator = ""

PRELUDE = f"""
return (separator + role_prefix + message.content.lstrip()).rstrip()


default_conf = PseudoChatConf(
prelude_template="""
You are a helpful assistant participating in a dialog with a user.
The messages from the user start with "{RolePrompt.HUMAN.strip()}".
The messages from you start with "{RolePrompt.ASSISTANT.strip()}".
The messages from the user start with "{human}".
The messages from you start with "{ai}".
Reply to the last message from the user taking into account the preceding dialog history.
====================
""".strip()


def _format_message(message: BaseMessage) -> str:
role = (
RolePrompt.HUMAN
if isinstance(message, (SystemMessage, HumanMessage))
else RolePrompt.ASSISTANT
)
return (role + " " + message.content.lstrip()).rstrip()
""".strip(),
annotate_first=True,
add_invitation=True,
mapping=RoleMapping(
system="Human:",
human="Human:",
ai="Assistant:",
),
separator="\n\n",
)


class PseudoChatHistory(History):
stop_sequences: List[str]
pseudo_history_conf: PseudoChatConf

def trim(
self, count_tokens: Callable[[str], int], max_prompt_tokens: int
@@ -77,7 +111,8 @@ def trim(
self.messages, discarded_messages
)
if message.source_message
]
],
conf=self.pseudo_history_conf,
),
len(discarded_messages),
)
@@ -88,12 +123,15 @@ def trim(
source_messages_count - discarded_messages_count == 1
and isinstance(last_source_message, HumanMessage)
):
history = PseudoChatHistory.create([last_source_message])
history = PseudoChatHistory.create(
messages=[last_source_message],
conf=self.pseudo_history_conf,
)
prompt_tokens = sum(
count_tokens(message.text) for message in history.messages
)
if prompt_tokens <= max_prompt_tokens:
return history, len(discarded_messages)
return history, discarded_messages_count

raise ValidationError(
f"The token size of system messages and the last user message ({prompt_tokens}) exceeds"
@@ -105,32 +143,49 @@ def trim(
)

@classmethod
def create(cls, messages: List[BaseMessage]) -> "PseudoChatHistory":
def create(
cls, messages: List[BaseMessage], conf: PseudoChatConf
) -> "PseudoChatHistory":
if len(messages) == 1 and isinstance(messages[0], HumanMessage):
single_message = messages[0]
message = messages[0]
return cls(
messages=[
FormattedMessage(
text=single_message.content,
source_message=single_message,
text=message.content,
source_message=message,
)
],
stop_sequences=[],
pseudo_history_conf=conf,
)

formatted_messages = [FormattedMessage(text=PRELUDE)]
formatted_messages: List[FormattedMessage] = []

if conf.prelude is not None:
formatted_messages.append(FormattedMessage(text=conf.prelude))

for index, message in enumerate(messages):
for idx, message in enumerate(messages):
formatted_messages.append(
FormattedMessage(
text=_format_message(message),
text=conf.format_message(
message, len(formatted_messages) == 0
),
source_message=message,
is_important=is_important_message(messages, index),
is_important=is_important_message(messages, idx),
)
)

formatted_messages.append(
FormattedMessage(text=_format_message(AIMessage(content="")))
)
if conf.add_invitation:
formatted_messages.append(
FormattedMessage(
text=conf.format_message(
AIMessage(content=""), len(formatted_messages) == 0
)
)
)

return cls(messages=formatted_messages, stop_sequences=STOP_SEQUENCES)
return cls(
messages=formatted_messages,
stop_sequences=conf.stop_sequences,
pseudo_history_conf=conf,
)
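
To show what the new config produces, a rough sketch using default_conf directly (the dialog text is made up, and the chat message classes are not imported here because their module path is elided in this diff):

from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import default_conf

# The prelude is the template with role prefixes substituted from the mapping.
print(default_conf.prelude)

# Generation stops as soon as the model starts a new user turn.
assert default_conf.stop_sequences == ["\n\nHuman:"]

# Mirroring format_message by hand: every message after the prelude gets the
# "\n\n" separator plus its role prefix; the trailing "Assistant:" is the empty
# AI message appended when add_invitation=True.
prompt = (
    default_conf.prelude
    + "\n\nHuman: What is the capital of France?"
    + "\n\nAssistant: Paris."
    + "\n\nHuman: And of Germany?"
    + "\n\nAssistant:"
)
print(prompt)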