diff --git a/aidial_adapter_bedrock/bedrock.py b/aidial_adapter_bedrock/bedrock.py new file mode 100644 index 0000000..766d3db --- /dev/null +++ b/aidial_adapter_bedrock/bedrock.py @@ -0,0 +1,66 @@ +import json +from typing import Any, AsyncIterator + +import boto3 +from botocore.eventstream import EventStream +from botocore.response import StreamingBody + +from aidial_adapter_bedrock.utils.concurrency import ( + make_async, + to_async_iterator, +) +from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log + + +class Bedrock: + client: Any + + def __init__(self, client: Any): + self.client = client + + @classmethod + async def acreate(cls, region: str) -> "Bedrock": + client = await make_async( + lambda: boto3.Session().client("bedrock-runtime", region) + ) + return cls(client) + + def _create_invoke_params(self, model: str, body: dict) -> dict: + return { + "modelId": model, + "body": json.dumps(body), + "accept": "application/json", + "contentType": "application/json", + } + + async def ainvoke_non_streaming(self, model: str, args: dict) -> dict: + params = self._create_invoke_params(model, args) + response = await make_async(lambda: self.client.invoke_model(**params)) + + log.debug(f"response: {response}") + + body: StreamingBody = response["body"] + body_dict = json.loads(await make_async(lambda: body.read())) + + log.debug(f"response['body']: {body_dict}") + + return body_dict + + async def ainvoke_streaming( + self, model: str, args: dict + ) -> AsyncIterator[dict]: + params = self._create_invoke_params(model, args) + response = await make_async( + lambda: self.client.invoke_model_with_response_stream(**params) + ) + + log.debug(f"response: {response}") + + body: EventStream = response["body"] + + async for event in to_async_iterator(iter(body)): + chunk = event.get("chunk") + if chunk: + chunk_dict = json.loads(chunk.get("bytes").decode()) + log.debug(f"chunk: {chunk_dict}") + yield chunk_dict diff --git a/aidial_adapter_bedrock/chat_completion.py b/aidial_adapter_bedrock/chat_completion.py index acd9b7b..94007fc 100644 --- a/aidial_adapter_bedrock/chat_completion.py +++ b/aidial_adapter_bedrock/chat_completion.py @@ -19,10 +19,10 @@ def __init__(self, region: str): @dial_exception_decorator async def chat_completion(self, request: Request, response: Response): - model_params = ModelParameters.create(request) + params = ModelParameters.create(request) model = await get_bedrock_adapter( region=self.region, - model_id=request.deployment_id, + model=request.deployment_id, ) async def generate_response( @@ -32,7 +32,7 @@ async def generate_response( ) -> None: with response.create_choice() as choice: consumer = ChoiceConsumer(choice) - await model.achat(consumer, model_params, request.messages) + await model.achat(consumer, params, request.messages) usage.accumulate(consumer.usage) discarded_messages_set.add(consumer.discarded_messages) diff --git a/aidial_adapter_bedrock/dial_api/request.py b/aidial_adapter_bedrock/dial_api/request.py index c2699cc..bae6577 100644 --- a/aidial_adapter_bedrock/dial_api/request.py +++ b/aidial_adapter_bedrock/dial_api/request.py @@ -1,4 +1,4 @@ -from typing import List, Mapping, Optional, Union +from typing import List, Optional from aidial_sdk.chat_completion import Request from pydantic import BaseModel @@ -8,38 +8,30 @@ class ModelParameters(BaseModel): temperature: Optional[float] = None top_p: Optional[float] = None n: Optional[int] = None - stop: Optional[Union[str, List[str]]] = None + stop: List[str] = [] max_tokens: Optional[int] 
= None max_prompt_tokens: Optional[int] = None - presence_penalty: Optional[float] = None - frequency_penalty: Optional[float] = None - logit_bias: Optional[Mapping[int, float]] = None stream: bool = False @classmethod def create(cls, request: Request) -> "ModelParameters": + stop: List[str] = [] + if request.stop is not None: + stop = ( + [request.stop] + if isinstance(request.stop, str) + else request.stop + ) + return cls( temperature=request.temperature, top_p=request.top_p, n=request.n, - stop=request.stop, + stop=stop, max_tokens=request.max_tokens, max_prompt_tokens=request.max_prompt_tokens, - presence_penalty=request.presence_penalty, - frequency_penalty=request.frequency_penalty, - logit_bias=request.logit_bias, stream=request.stream, ) def add_stop_sequences(self, stop: List[str]) -> "ModelParameters": - if len(stop) == 0: - return self - - self_stop: List[str] = [] - if self.stop is not None: - if isinstance(self.stop, str): - self_stop = [self.stop] - else: - self_stop = self.stop - - return self.copy(update={"stop": [*self_stop, *stop]}) + return self.copy(update={"stop": [*self.stop, *stop]}) diff --git a/aidial_adapter_bedrock/llm/chat_emulation/pseudo_chat.py b/aidial_adapter_bedrock/llm/chat_emulation/pseudo_chat.py index b2ea530..fbffabf 100644 --- a/aidial_adapter_bedrock/llm/chat_emulation/pseudo_chat.py +++ b/aidial_adapter_bedrock/llm/chat_emulation/pseudo_chat.py @@ -1,5 +1,6 @@ -from enum import Enum -from typing import Callable, List, Optional, Set, Tuple +from typing import Callable, List, Optional, Set, Tuple, TypedDict + +from pydantic import BaseModel from aidial_adapter_bedrock.llm.chat_emulation.history import ( FormattedMessage, @@ -11,39 +12,72 @@ AIMessage, BaseMessage, HumanMessage, - SystemMessage, ) from aidial_adapter_bedrock.utils.list import exclude_indices -class RolePrompt(str, Enum): - HUMAN = "\n\nHuman:" - ASSISTANT = "\n\nAssistant:" +class RoleMapping(TypedDict): + system: str + human: str + ai: str + + +class PseudoChatConf(BaseModel): + prelude_template: Optional[str] + annotate_first: bool + add_invitation: bool + mapping: RoleMapping + separator: str + + @property + def prelude(self) -> Optional[str]: + if self.prelude_template is None: + return None + return self.prelude_template.format(**self.mapping) + + @property + def stop_sequences(self) -> List[str]: + return [self.separator + self.mapping["human"]] + + def format_message(self, message: BaseMessage, is_first: bool) -> str: + role = self.mapping.get(message.type) + if role is None: + raise ValueError(f"Unknown message type: {message.type}") -STOP_SEQUENCES: List[str] = [RolePrompt.HUMAN] + role_prefix = role + " " + if is_first and not self.annotate_first: + role_prefix = "" + separator = self.separator + if is_first: + separator = "" -PRELUDE = f""" + return (separator + role_prefix + message.content.lstrip()).rstrip() + + +default_conf = PseudoChatConf( + prelude_template=""" You are a helpful assistant participating in a dialog with a user. -The messages from the user start with "{RolePrompt.HUMAN.strip()}". -The messages from you start with "{RolePrompt.ASSISTANT.strip()}". +The messages from the user start with "{human}". +The messages from you start with "{ai}". Reply to the last message from the user taking into account the preceding dialog history. 
==================== -""".strip() - - -def _format_message(message: BaseMessage) -> str: - role = ( - RolePrompt.HUMAN - if isinstance(message, (SystemMessage, HumanMessage)) - else RolePrompt.ASSISTANT - ) - return (role + " " + message.content.lstrip()).rstrip() +""".strip(), + annotate_first=True, + add_invitation=True, + mapping=RoleMapping( + system="Human:", + human="Human:", + ai="Assistant:", + ), + separator="\n\n", +) class PseudoChatHistory(History): stop_sequences: List[str] + pseudo_history_conf: PseudoChatConf def trim( self, count_tokens: Callable[[str], int], max_prompt_tokens: int @@ -77,7 +111,8 @@ def trim( self.messages, discarded_messages ) if message.source_message - ] + ], + conf=self.pseudo_history_conf, ), len(discarded_messages), ) @@ -88,12 +123,15 @@ def trim( source_messages_count - discarded_messages_count == 1 and isinstance(last_source_message, HumanMessage) ): - history = PseudoChatHistory.create([last_source_message]) + history = PseudoChatHistory.create( + messages=[last_source_message], + conf=self.pseudo_history_conf, + ) prompt_tokens = sum( count_tokens(message.text) for message in history.messages ) if prompt_tokens <= max_prompt_tokens: - return history, len(discarded_messages) + return history, discarded_messages_count raise ValidationError( f"The token size of system messages and the last user message ({prompt_tokens}) exceeds" @@ -105,32 +143,49 @@ def trim( ) @classmethod - def create(cls, messages: List[BaseMessage]) -> "PseudoChatHistory": + def create( + cls, messages: List[BaseMessage], conf: PseudoChatConf + ) -> "PseudoChatHistory": if len(messages) == 1 and isinstance(messages[0], HumanMessage): - single_message = messages[0] + message = messages[0] return cls( messages=[ FormattedMessage( - text=single_message.content, - source_message=single_message, + text=message.content, + source_message=message, ) ], stop_sequences=[], + pseudo_history_conf=conf, ) - formatted_messages = [FormattedMessage(text=PRELUDE)] + formatted_messages: List[FormattedMessage] = [] + + if conf.prelude is not None: + formatted_messages.append(FormattedMessage(text=conf.prelude)) - for index, message in enumerate(messages): + for idx, message in enumerate(messages): formatted_messages.append( FormattedMessage( - text=_format_message(message), + text=conf.format_message( + message, len(formatted_messages) == 0 + ), source_message=message, - is_important=is_important_message(messages, index), + is_important=is_important_message(messages, idx), ) ) - formatted_messages.append( - FormattedMessage(text=_format_message(AIMessage(content=""))) - ) + if conf.add_invitation: + formatted_messages.append( + FormattedMessage( + text=conf.format_message( + AIMessage(content=""), len(formatted_messages) == 0 + ) + ) + ) - return cls(messages=formatted_messages, stop_sequences=STOP_SEQUENCES) + return cls( + messages=formatted_messages, + stop_sequences=conf.stop_sequences, + pseudo_history_conf=conf, + ) diff --git a/aidial_adapter_bedrock/llm/chat_model.py b/aidial_adapter_bedrock/llm/chat_model.py index 88ec9b1..64f9f5c 100644 --- a/aidial_adapter_bedrock/llm/chat_model.py +++ b/aidial_adapter_bedrock/llm/chat_model.py @@ -1,11 +1,13 @@ from abc import ABC, abstractmethod -from typing import Callable, List, Optional +from typing import AsyncIterator, Callable, List, Optional from aidial_sdk.chat_completion import Message from pydantic import BaseModel +import aidial_adapter_bedrock.utils.stream as stream_utils from aidial_adapter_bedrock.dial_api.request import 
ModelParameters from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import ( + PseudoChatConf, PseudoChatHistory, ) from aidial_adapter_bedrock.llm.consumer import Consumer @@ -29,10 +31,10 @@ class ChatPrompt(BaseModel): class ChatModel(ABC): - model_id: str + model: str - def __init__(self, model_id: str): - self.model_id = model_id + def __init__(self, model: str): + self.model = model @abstractmethod def _prepare_prompt( @@ -42,7 +44,7 @@ def _prepare_prompt( @abstractmethod async def _apredict( - self, consumer: Consumer, model_params: ModelParameters, prompt: str + self, consumer: Consumer, params: ModelParameters, prompt: str ) -> None: pass @@ -62,40 +64,49 @@ def _validate_and_cleanup_messages( async def achat( self, consumer: Consumer, - model_params: ModelParameters, + params: ModelParameters, messages: List[Message], ): base_messages = list(map(parse_message, messages)) base_messages = self._validate_and_cleanup_messages(base_messages) chat_prompt = self._prepare_prompt( - base_messages, model_params.max_prompt_tokens + base_messages, params.max_prompt_tokens ) - model_params = model_params.add_stop_sequences( - chat_prompt.stop_sequences - ) + params = params.add_stop_sequences(chat_prompt.stop_sequences) log.debug( - f"model parameters:\n{model_params.json(indent=2, exclude_none=True)}" + f"model parameters:\n{params.json(indent=2, exclude_none=True)}" ) log.debug(f"prompt:\n{chat_prompt.text}") - await self._apredict(consumer, model_params, chat_prompt.text) + await self._apredict(consumer, params, chat_prompt.text) if chat_prompt.discarded_messages is not None: consumer.set_discarded_messages(chat_prompt.discarded_messages) class PseudoChatModel(ChatModel, ABC): - def __init__(self, model_id: str, count_tokens: Callable[[str], int]): - super().__init__(model_id) + pseudo_history_conf: PseudoChatConf + count_tokens: Callable[[str], int] + + def __init__( + self, + model: str, + count_tokens: Callable[[str], int], + pseudo_history_conf: PseudoChatConf, + ): + super().__init__(model) self.count_tokens = count_tokens + self.pseudo_history_conf = pseudo_history_conf def _prepare_prompt( self, messages: List[BaseMessage], max_prompt_tokens: Optional[int] ) -> ChatPrompt: - history = PseudoChatHistory.create(messages) + history = PseudoChatHistory.create( + messages=messages, conf=self.pseudo_history_conf + ) if max_prompt_tokens is None: return ChatPrompt( text=history.format(), stop_sequences=history.stop_sequences @@ -111,6 +122,31 @@ def _prepare_prompt( discarded_messages=discarded_messages_count, ) + @staticmethod + def post_process_stream( + stream: AsyncIterator[str], + params: ModelParameters, + pseudo_chat_conf: PseudoChatConf, + ) -> AsyncIterator[str]: + # Removing leading spaces + stream = stream_utils.lstrip(stream) + + # The model may occasionally start its response with the role prefix + stream = stream_utils.remove_prefix( + stream, + pseudo_chat_conf.mapping["ai"] + " ", + ) + + # If the model doesn't support stop sequences, apply them manually + if params.stop: + stream = stream_utils.stop_at(stream, params.stop) + + # After all the post processing, the stream may become empty. + # To avoid this, add a space to the stream. 
+ stream = stream_utils.ensure_not_empty(stream, " ") + + return stream + class Model(BaseModel): provider: str diff --git a/aidial_adapter_bedrock/llm/model/adapter.py b/aidial_adapter_bedrock/llm/model/adapter.py index 8c12e13..b81a8d8 100644 --- a/aidial_adapter_bedrock/llm/model/adapter.py +++ b/aidial_adapter_bedrock/llm/model/adapter.py @@ -1,11 +1,10 @@ -import boto3 - +from aidial_adapter_bedrock.bedrock import Bedrock +from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import default_conf from aidial_adapter_bedrock.llm.chat_model import ChatModel, Model from aidial_adapter_bedrock.llm.model.ai21 import AI21Adapter from aidial_adapter_bedrock.llm.model.amazon import AmazonAdapter from aidial_adapter_bedrock.llm.model.anthropic import AnthropicAdapter from aidial_adapter_bedrock.llm.model.stability import StabilityAdapter -from aidial_adapter_bedrock.utils.concurrency import make_async def count_tokens(string: str) -> int: @@ -22,19 +21,17 @@ def count_tokens(string: str) -> int: return len(string.encode("utf-8")) -async def get_bedrock_adapter(model_id: str, region: str) -> ChatModel: - bedrock = await make_async( - lambda _: boto3.Session().client("bedrock-runtime", region), () - ) - model_provider = Model.parse(model_id).provider - match model_provider: +async def get_bedrock_adapter(model: str, region: str) -> ChatModel: + client = await Bedrock.acreate(region) + provider = Model.parse(model).provider + match provider: case "anthropic": - return AnthropicAdapter(bedrock, model_id) + return AnthropicAdapter(client, model) case "ai21": - return AI21Adapter(bedrock, model_id, count_tokens) + return AI21Adapter(client, model, count_tokens, default_conf) case "stability": - return StabilityAdapter(bedrock, model_id) + return StabilityAdapter(client, model) case "amazon": - return AmazonAdapter(bedrock, model_id, count_tokens) + return AmazonAdapter(client, model, count_tokens, default_conf) case _: - raise ValueError(f"Unknown model provider: '{model_provider}'") + raise ValueError(f"Unknown model provider: '{provider}'") diff --git a/aidial_adapter_bedrock/llm/model/ai21.py b/aidial_adapter_bedrock/llm/model/ai21.py index e9f6040..a9cb623 100644 --- a/aidial_adapter_bedrock/llm/model/ai21.py +++ b/aidial_adapter_bedrock/llm/model/ai21.py @@ -1,14 +1,14 @@ -import json from typing import Any, Callable, Dict, List, Optional from pydantic import BaseModel +from aidial_adapter_bedrock.bedrock import Bedrock from aidial_adapter_bedrock.dial_api.request import ModelParameters from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage +from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import PseudoChatConf from aidial_adapter_bedrock.llm.chat_model import PseudoChatModel from aidial_adapter_bedrock.llm.consumer import Consumer from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AI21 -from aidial_adapter_bedrock.utils.concurrency import make_async class TextRange(BaseModel): @@ -65,70 +65,58 @@ def usage(self) -> TokenUsage: # NOTE: See https://docs.ai21.com/reference/j2-instruct-ref -def prepare_model_kwargs(model_params: ModelParameters) -> Dict[str, Any]: - model_kwargs = {} +def convert_params(params: ModelParameters) -> Dict[str, Any]: + ret = {} - if model_params.max_tokens is not None: - model_kwargs["maxTokens"] = model_params.max_tokens + if params.max_tokens is not None: + ret["maxTokens"] = params.max_tokens else: # The default for max tokens is 16, which is too small for most use cases. # Choosing reasonable default. 
- model_kwargs["maxTokens"] = DEFAULT_MAX_TOKENS_AI21 + ret["maxTokens"] = DEFAULT_MAX_TOKENS_AI21 - if model_params.temperature is not None: + if params.temperature is not None: # AI21 temperature ranges from 0.0 to 1.0 # OpenAI temperature ranges from 0.0 to 2.0 # Thus scaling down by 2x to match the AI21 range - model_kwargs["temperature"] = model_params.temperature / 2.0 + ret["temperature"] = params.temperature / 2.0 - if model_params.top_p is not None: - model_kwargs["topP"] = model_params.top_p + if params.top_p is not None: + ret["topP"] = params.top_p - if model_params.stop is not None: - model_kwargs["stopSequences"] = ( - [model_params.stop] - if isinstance(model_params.stop, str) - else model_params.stop - ) + if params.stop: + ret["stopSequences"] = params.stop # NOTE: AI21 has "numResults" parameter, however we emulate multiple result # via multiple calls to support all models uniformly. - return model_kwargs + return ret -def prepare_input(prompt: str, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: - return {"prompt": prompt, **model_kwargs} +def create_request(prompt: str, params: Dict[str, Any]) -> Dict[str, Any]: + return {"prompt": prompt, **params} class AI21Adapter(PseudoChatModel): + client: Bedrock + def __init__( - self, bedrock: Any, model_id: str, count_tokens: Callable[[str], int] + self, + client: Bedrock, + model: str, + count_tokens: Callable[[str], int], + pseudo_history_conf: PseudoChatConf, ): - super().__init__(model_id, count_tokens) - self.bedrock = bedrock + super().__init__(model, count_tokens, pseudo_history_conf) + self.client = client async def _apredict( - self, consumer: Consumer, model_params: ModelParameters, prompt: str + self, consumer: Consumer, params: ModelParameters, prompt: str ): - await make_async( - lambda args: self._call(*args), (consumer, model_params, prompt) - ) - - def _call( - self, consumer: Consumer, model_params: ModelParameters, prompt: str - ): - model_kwargs = prepare_model_kwargs(model_params) - - model_response = self.bedrock.invoke_model( - modelId=self.model_id, - accept="application/json", - contentType="application/json", - body=json.dumps(prepare_input(prompt, model_kwargs)), - ) + args = create_request(prompt, convert_params(params)) + response = await self.client.ainvoke_non_streaming(self.model, args) - body = json.loads(model_response["body"].read()) - resp = AI21Response.parse_obj(body) + resp = AI21Response.parse_obj(response) consumer.append_content(resp.content()) consumer.add_usage(resp.usage()) diff --git a/aidial_adapter_bedrock/llm/model/amazon.py b/aidial_adapter_bedrock/llm/model/amazon.py index 993c773..ff90f10 100644 --- a/aidial_adapter_bedrock/llm/model/amazon.py +++ b/aidial_adapter_bedrock/llm/model/amazon.py @@ -1,19 +1,16 @@ -import json -from typing import Any, Callable, Dict, Generator, List, Optional +from typing import Any, AsyncIterator, Callable, Dict, List, Optional from pydantic import BaseModel from typing_extensions import override -import aidial_adapter_bedrock.utils.stream as stream +from aidial_adapter_bedrock.bedrock import Bedrock from aidial_adapter_bedrock.dial_api.request import ModelParameters from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage -from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import RolePrompt +from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import PseudoChatConf from aidial_adapter_bedrock.llm.chat_model import PseudoChatModel from aidial_adapter_bedrock.llm.consumer import Consumer from 
aidial_adapter_bedrock.llm.message import BaseMessage from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AMAZON -from aidial_adapter_bedrock.utils.concurrency import make_async -from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log class AmazonResult(BaseModel): @@ -42,67 +39,53 @@ def usage(self) -> TokenUsage: ) -def prepare_model_kwargs(model_params: ModelParameters) -> Dict[str, Any]: - model_kwargs = {} +def convert_params(params: ModelParameters) -> Dict[str, Any]: + ret = {} - if model_params.temperature is not None: - model_kwargs["temperature"] = model_params.temperature + if params.temperature is not None: + ret["temperature"] = params.temperature - if model_params.top_p is not None: - model_kwargs["topP"] = model_params.top_p + if params.top_p is not None: + ret["topP"] = params.top_p - if model_params.max_tokens is not None: - model_kwargs["maxTokenCount"] = model_params.max_tokens + if params.max_tokens is not None: + ret["maxTokenCount"] = params.max_tokens else: # The default for max tokens is 128, which is too small for most use cases. # Choosing reasonable default. - model_kwargs["maxTokenCount"] = DEFAULT_MAX_TOKENS_AMAZON + ret["maxTokenCount"] = DEFAULT_MAX_TOKENS_AMAZON # NOTE: Amazon Titan (amazon.titan-tg1-large) currently only supports # stop sequences matching pattern "$\|+". - # if model_params.stop is not None: - # model_kwargs["stopSequences"] = model_params.stop + # if params.stop is not None: + # ret["stopSequences"] = params.stop - return model_kwargs + return ret -def prepare_input(prompt: str, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: - return { - "inputText": prompt, - "textGenerationConfig": model_kwargs, - } +def create_request(prompt: str, params: Dict[str, Any]) -> Dict[str, Any]: + return {"inputText": prompt, "textGenerationConfig": params} -def get_generator_for_streaming( - response: Any, - usage: TokenUsage, -) -> Generator[str, None, None]: - body = response["body"] - for event in body: - chunk = event.get("chunk") - if chunk: - chunk_obj = json.loads(chunk.get("bytes").decode()) - log.debug(f"chunk: {chunk_obj}") +async def chunks_to_stream( + chunks: AsyncIterator[dict], usage: TokenUsage +) -> AsyncIterator[str]: + async for chunk in chunks: + input_tokens = chunk.get("inputTextTokenCount") + if input_tokens is not None: + usage.prompt_tokens = input_tokens - input_tokens = chunk_obj.get("inputTextTokenCount") - if input_tokens is not None: - usage.prompt_tokens = input_tokens + output_tokens = chunk.get("totalOutputTextTokenCount") + if output_tokens is not None: + usage.completion_tokens = output_tokens - output_tokens = chunk_obj.get("totalOutputTextTokenCount") - if output_tokens is not None: - usage.completion_tokens = output_tokens + yield chunk["outputText"] - yield chunk_obj["outputText"] - -def get_generator_for_non_streaming( - response: Any, - usage: TokenUsage, -) -> Generator[str, None, None]: - body = json.loads(response["body"].read()) - log.debug(f"body: {body}") - - resp = AmazonResponse.parse_obj(body) +async def response_to_stream( + response: dict, usage: TokenUsage +) -> AsyncIterator[str]: + resp = AmazonResponse.parse_obj(response) token_usage = resp.usage() usage.completion_tokens = token_usage.completion_tokens @@ -111,38 +94,18 @@ def get_generator_for_non_streaming( yield resp.content() -def post_process_stream( - model_params: ModelParameters, content_stream: Generator[str, None, None] -) -> Generator[str, None, None]: - content_stream = stream.lstrip(content_stream) - - 
# Titan occasionally starts its response with the role prefix - content_stream = stream.remove_prefix( - content_stream, RolePrompt.ASSISTANT.lstrip() + " " - ) - - # Titan doesn't support stop sequences, so do it manually - if model_params.stop is not None: - stop_sequences = ( - [model_params.stop] - if isinstance(model_params.stop, str) - else model_params.stop - ) - content_stream = stream.stop_at(content_stream, stop_sequences) - - # After all the post processing, the stream may become empty. - # To avoid this, add a space to the stream. - content_stream = stream.ensure_not_empty(content_stream, " ") - - return content_stream - - class AmazonAdapter(PseudoChatModel): + client: Bedrock + def __init__( - self, bedrock: Any, model_id: str, count_tokens: Callable[[str], int] + self, + client: Bedrock, + model_id: str, + count_tokens: Callable[[str], int], + pseudo_history_conf: PseudoChatConf, ): - super().__init__(model_id, count_tokens) - self.bedrock = bedrock + super().__init__(model_id, count_tokens, pseudo_history_conf) + self.client = client @override def _validate_and_cleanup_messages( @@ -158,39 +121,24 @@ def _validate_and_cleanup_messages( return messages async def _apredict( - self, consumer: Consumer, model_params: ModelParameters, prompt: str + self, consumer: Consumer, params: ModelParameters, prompt: str ): - await make_async( - lambda args: self._call(*args), (consumer, model_params, prompt) - ) - - def _call( - self, consumer: Consumer, model_params: ModelParameters, prompt: str - ): - model_kwargs = prepare_model_kwargs(model_params) - - invoke_params = { - "modelId": self.model_id, - "accept": "application/json", - "contentType": "application/json", - "body": json.dumps(prepare_input(prompt, model_kwargs)), - } + args = create_request(prompt, convert_params(params)) usage = TokenUsage() - if not model_params.stream: - response = self.bedrock.invoke_model(**invoke_params) - content_stream = get_generator_for_non_streaming(response, usage) + if params.stream: + chunks = self.client.ainvoke_streaming(self.model, args) + stream = chunks_to_stream(chunks, usage) else: - response = self.bedrock.invoke_model_with_response_stream( - **invoke_params - ) - content_stream = get_generator_for_streaming(response, usage) + response = await self.client.ainvoke_non_streaming(self.model, args) + stream = response_to_stream(response, usage) - content_stream = post_process_stream(model_params, content_stream) + stream = self.post_process_stream( + stream, params, self.pseudo_history_conf + ) - for content in content_stream: - log.debug(f"content: {repr(content)}") + async for content in stream: consumer.append_content(content) consumer.add_usage(usage) diff --git a/aidial_adapter_bedrock/llm/model/anthropic.py b/aidial_adapter_bedrock/llm/model/anthropic.py index db87471..a2596fc 100644 --- a/aidial_adapter_bedrock/llm/model/anthropic.py +++ b/aidial_adapter_bedrock/llm/model/anthropic.py @@ -1,8 +1,9 @@ -import json -from typing import Any, Dict, Generator, List, Optional +from typing import Any, AsyncIterator, Dict, List, Optional from anthropic.tokenizer import count_tokens +import aidial_adapter_bedrock.utils.stream as stream_utils +from aidial_adapter_bedrock.bedrock import Bedrock from aidial_adapter_bedrock.dial_api.request import ModelParameters from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage from aidial_adapter_bedrock.llm.chat_emulation import claude_chat @@ -13,8 +14,6 @@ from aidial_adapter_bedrock.llm.consumer import Consumer from 
aidial_adapter_bedrock.llm.message import BaseMessage from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_ANTHROPIC -from aidial_adapter_bedrock.utils.concurrency import make_async -from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log def compute_usage(prompt: str, completion: str) -> TokenUsage: @@ -25,63 +24,49 @@ def compute_usage(prompt: str, completion: str) -> TokenUsage: # NOTE: See https://docs.anthropic.com/claude/reference/complete_post -def prepare_model_kwargs(model_params: ModelParameters) -> Dict[str, Any]: - model_kwargs = {} +def convert_params(params: ModelParameters) -> Dict[str, Any]: + ret = {} - if model_params.max_tokens is not None: - model_kwargs["max_tokens_to_sample"] = model_params.max_tokens + if params.max_tokens is not None: + ret["max_tokens_to_sample"] = params.max_tokens else: # The max tokens parameter is required for Anthropic models. # Choosing reasonable default. - model_kwargs["max_tokens_to_sample"] = DEFAULT_MAX_TOKENS_ANTHROPIC + ret["max_tokens_to_sample"] = DEFAULT_MAX_TOKENS_ANTHROPIC - if model_params.stop is not None: - model_kwargs["stop_sequences"] = ( - [model_params.stop] - if isinstance(model_params.stop, str) - else model_params.stop - ) - - if model_params.temperature is not None: - model_kwargs["temperature"] = model_params.temperature + if params.stop: + ret["stop_sequences"] = params.stop - if model_params.top_p is not None: - model_kwargs["top_p"] = model_params.top_p + if params.temperature is not None: + ret["temperature"] = params.temperature - return model_kwargs + if params.top_p is not None: + ret["top_p"] = params.top_p + return ret -def prepare_input(prompt: str, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: - return {"prompt": prompt, **model_kwargs} +def create_request(prompt: str, params: Dict[str, Any]) -> Dict[str, Any]: + return {"prompt": prompt, **params} -def get_generator_for_streaming(response: Any) -> Generator[str, None, None]: - body = response["body"] - for event in body: - chunk = event.get("chunk") - if chunk: - chunk_obj = json.loads(chunk.get("bytes").decode()) - log.debug(f"chunk: {chunk_obj}") - yield chunk_obj["completion"] +async def chunks_to_stream( + chunks: AsyncIterator[dict], +) -> AsyncIterator[str]: + async for chunk in chunks: + yield chunk["completion"] -def get_generator_for_non_streaming( - response: Any, -) -> Generator[str, None, None]: - body = json.loads(response["body"].read()) - log.debug(f"body: {body}") - yield body["completion"] +async def response_to_stream(response: dict) -> AsyncIterator[str]: + yield response["completion"] class AnthropicAdapter(ChatModel): - def __init__( - self, - bedrock: Any, - model_id: str, - ): - super().__init__(model_id) - self.bedrock = bedrock + client: Bedrock + + def __init__(self, client: Bedrock, model: str): + super().__init__(model) + self.client = client def _prepare_prompt( self, messages: List[BaseMessage], max_prompt_tokens: Optional[int] @@ -103,37 +88,21 @@ def _prepare_prompt( ) async def _apredict( - self, consumer: Consumer, model_params: ModelParameters, prompt: str - ): - return await make_async( - lambda args: self._predict(*args), (consumer, model_params, prompt) - ) - - def _predict( - self, consumer: Consumer, model_params: ModelParameters, prompt: str + self, consumer: Consumer, params: ModelParameters, prompt: str ): - model_kwargs = prepare_model_kwargs(model_params) - - invoke_params = { - "modelId": self.model_id, - "accept": "application/json", - "contentType": "application/json", - 
"body": json.dumps(prepare_input(prompt, model_kwargs)), - } - - if not model_params.stream: - response = self.bedrock.invoke_model(**invoke_params) - content_stream = get_generator_for_non_streaming(response) + args = create_request(prompt, convert_params(params)) + if params.stream: + chunks = self.client.ainvoke_streaming(self.model, args) + stream = chunks_to_stream(chunks) else: - response = self.bedrock.invoke_model_with_response_stream( - **invoke_params - ) - content_stream = get_generator_for_streaming(response) + response = await self.client.ainvoke_non_streaming(self.model, args) + stream = response_to_stream(response) - completion = "" + stream = stream_utils.lstrip(stream) - for content in content_stream: + completion = "" + async for content in stream: completion += content consumer.append_content(content) diff --git a/aidial_adapter_bedrock/llm/model/stability.py b/aidial_adapter_bedrock/llm/model/stability.py index 1183d1c..94b73ee 100644 --- a/aidial_adapter_bedrock/llm/model/stability.py +++ b/aidial_adapter_bedrock/llm/model/stability.py @@ -1,10 +1,10 @@ -import json import os from enum import Enum from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field +from aidial_adapter_bedrock.bedrock import Bedrock from aidial_adapter_bedrock.dial_api.request import ModelParameters from aidial_adapter_bedrock.dial_api.storage import ( FileStorage, @@ -17,7 +17,6 @@ from aidial_adapter_bedrock.llm.chat_model import ChatModel, ChatPrompt from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer from aidial_adapter_bedrock.llm.message import BaseMessage -from aidial_adapter_bedrock.utils.concurrency import make_async from aidial_adapter_bedrock.utils.env import get_env @@ -70,7 +69,7 @@ def _throw_if_error(self): raise Exception(self.error.message) # type: ignore -def prepare_input(prompt: str) -> Dict[str, Any]: +def create_request(prompt: str) -> Dict[str, Any]: return {"text_prompts": [{"text": prompt}]} @@ -104,12 +103,12 @@ async def save_to_storage( class StabilityAdapter(ChatModel): - bedrock: Any + client: Bedrock storage: Optional[FileStorage] - def __init__(self, bedrock: Any, model_id: str): - super().__init__(model_id) - self.bedrock = bedrock + def __init__(self, client: Bedrock, model: str): + super().__init__(model) + self.client = client self.storage = None if USE_DIAL_FILE_STORAGE: @@ -130,25 +129,16 @@ def _prepare_prompt( ) async def _apredict( - self, consumer: Consumer, model_params: ModelParameters, prompt: str + self, consumer: Consumer, params: ModelParameters, prompt: str ): - model_response = await make_async( - lambda args: self.bedrock.invoke_model( - accept="application/json", - contentType="application/json", - modelId=args[0], - body=args[1], - ), - (self.model_id, json.dumps(prepare_input(prompt))), - ) - - body = json.loads(model_response["body"].read()) - resp = StabilityResponse.parse_obj(body) + args = create_request(prompt) + response = await self.client.ainvoke_non_streaming(self.model, args) + resp = StabilityResponse.parse_obj(response) consumer.append_content(resp.content()) consumer.add_usage(resp.usage()) for attachment in resp.attachments(): - if self.storage is not None: + if self.storage: attachment = await save_to_storage(self.storage, attachment) consumer.add_attachment(attachment) diff --git a/aidial_adapter_bedrock/utils/concurrency.py b/aidial_adapter_bedrock/utils/concurrency.py index 78148a2..13b971f 100644 --- a/aidial_adapter_bedrock/utils/concurrency.py +++ 
b/aidial_adapter_bedrock/utils/concurrency.py @@ -1,12 +1,34 @@ import asyncio from concurrent.futures import ThreadPoolExecutor -from typing import Callable, TypeVar +from typing import ( + AsyncIterator, + Callable, + Iterator, + Optional, + Tuple, + TypeVar, + cast, +) T = TypeVar("T") -A = TypeVar("A") -async def make_async(func: Callable[[A], T], arg: A) -> T: - with ThreadPoolExecutor() as executor: +async def make_async(func: Callable[[], T]) -> T: + with ThreadPoolExecutor(max_workers=1) as executor: loop = asyncio.get_event_loop() - return await loop.run_in_executor(executor, func, arg) + return await loop.run_in_executor(executor, func) + + +async def to_async_iterator(iter: Iterator[T]) -> AsyncIterator[T]: + def _next() -> Tuple[bool, Optional[T]]: + try: + return False, next(iter) + except StopIteration: + return True, None + + while True: + is_end, item = await make_async(lambda: _next()) + if is_end: + break + else: + yield cast(T, item) diff --git a/aidial_adapter_bedrock/utils/stream.py b/aidial_adapter_bedrock/utils/stream.py index fb5b7da..f7258b6 100644 --- a/aidial_adapter_bedrock/utils/stream.py +++ b/aidial_adapter_bedrock/utils/stream.py @@ -1,11 +1,11 @@ -from typing import Generator, List +from typing import AsyncIterator, List import tests.utils.string as string -def lstrip(gen: Generator[str, None, None]) -> Generator[str, None, None]: +async def lstrip(stream: AsyncIterator[str]) -> AsyncIterator[str]: start = True - for chunk in gen: + async for chunk in stream: if start: chunk = chunk.lstrip() if chunk != "": @@ -15,13 +15,13 @@ def lstrip(gen: Generator[str, None, None]) -> Generator[str, None, None]: yield chunk -def remove_prefix( - gen: Generator[str, None, None], prefix: str -) -> Generator[str, None, None]: +async def remove_prefix( + stream: AsyncIterator[str], prefix: str +) -> AsyncIterator[str]: acc = "" start = True - for chunk in gen: + async for chunk in stream: if start: acc += chunk if len(acc) >= len(prefix): @@ -34,17 +34,18 @@ yield acc -def stop_at( - gen: Generator[str, None, None], stop_sequences: List[str] -) -> Generator[str, None, None]: +async def stop_at( + stream: AsyncIterator[str], stop_sequences: List[str] +) -> AsyncIterator[str]: if len(stop_sequences) == 0: - yield from gen + async for item in stream: + yield item return buffer_len = max(map(len, stop_sequences)) - 1 hold = "" - for chunk in gen: + async for chunk in stream: hold += chunk min_index = len(hold) @@ -66,13 +67,13 @@ yield hold -def ensure_not_empty( - gen: Generator[str, None, None], default: str -) -> Generator[str, None, None]: +async def ensure_not_empty( + gen: AsyncIterator[str], default: str +) -> AsyncIterator[str]: all_chunks_are_empty = True - for chunk in gen: + async for chunk in gen: all_chunks_are_empty = all_chunks_are_empty and chunk == "" yield chunk if all_chunks_are_empty: yield default diff --git a/client/client_bedrock.py b/client/client_bedrock.py index 8849389..d334bfe 100755 --- a/client/client_bedrock.py +++ b/client/client_bedrock.py @@ -20,10 +20,10 @@ async def main(): deployment = select_enum("Select the deployment", BedrockDeployment) - model_params = ModelParameters() + params = ModelParameters() model = await get_bedrock_adapter( - model_id=deployment.get_model_id(), + model=deployment.get_model_id(), region=location, ) @@ -39,7 +39,7 @@ async def main(): messages.append(Message(role=Role.USER, content=content)) response = CollectConsumer() - await model.achat(response, model_params, 
messages) + await model.achat(response, params, messages) print_info(response.usage.json(indent=2)) diff --git a/tests/integration_tests/test_chat_completion.py b/tests/integration_tests/test_chat_completion.py index 8725d51..8997044 100644 --- a/tests/integration_tests/test_chat_completion.py +++ b/tests/integration_tests/test_chat_completion.py @@ -109,7 +109,7 @@ def get_test_cases( max_tokens=None, stop=None, messages=[user(query)], - test=lambda s: "hello" in s.lower(), + test=lambda s: "hello" in s.lower() or "hi" in s.lower(), ) ) @@ -161,13 +161,18 @@ def get_test_cases( ) ) + # ai21 models do not support more than one stop word + stop = ["world", "World"] + if "ai21" in deployment.value: + stop = ["world"] + ret.append( TestCase( name="stop sequence", deployment=deployment, streaming=streaming, max_tokens=None, - stop=["world"], + stop=stop, messages=[user('Reply with "hello world"')], test=lambda s: "world" not in s.lower(), ) diff --git a/tests/unit_tests/chat_emulation/test_pseudo_chat_history.py b/tests/unit_tests/chat_emulation/test_pseudo_chat_history.py index e3328e7..3b955d4 100644 --- a/tests/unit_tests/chat_emulation/test_pseudo_chat_history.py +++ b/tests/unit_tests/chat_emulation/test_pseudo_chat_history.py @@ -4,8 +4,8 @@ from aidial_adapter_bedrock.llm.chat_emulation.history import FormattedMessage from aidial_adapter_bedrock.llm.chat_emulation.pseudo_chat import ( - PRELUDE, PseudoChatHistory, + default_conf, ) from aidial_adapter_bedrock.llm.exceptions import ValidationError from aidial_adapter_bedrock.llm.message import ( @@ -23,11 +23,13 @@ def test_construction(): AIMessage(content=" ai message1 "), HumanMessage(content=" human message2 "), ] - history = PseudoChatHistory.create(messages) + history = PseudoChatHistory.create(messages, conf=default_conf) + prelude = history.pseudo_history_conf.prelude + assert prelude is not None assert history.stop_sequences == ["\n\nHuman:"] assert history.messages == [ - FormattedMessage(text=PRELUDE), + FormattedMessage(text=prelude), FormattedMessage( text="\n\nHuman: system message1", source_message=messages[0] ), @@ -50,7 +52,7 @@ def test_construction(): def test_construction_with_single_user_message(): messages: List[BaseMessage] = [HumanMessage(content=" human message ")] - history = PseudoChatHistory.create(messages) + history = PseudoChatHistory.create(messages, conf=default_conf) assert history.stop_sequences == [] assert history.messages == [ @@ -64,7 +66,9 @@ def test_formatting(): FormattedMessage(text="text2"), FormattedMessage(text="text3"), ] - history = PseudoChatHistory(messages=messages, stop_sequences=[]) + history = PseudoChatHistory( + messages=messages, stop_sequences=[], pseudo_history_conf=default_conf + ) prompt = history.format() @@ -77,7 +81,9 @@ def test_no_trimming(): FormattedMessage(text="text2"), FormattedMessage(text="text3"), ] - history = PseudoChatHistory(messages=messages, stop_sequences=[]) + history = PseudoChatHistory( + messages=messages, stop_sequences=[], pseudo_history_conf=default_conf + ) trimmed_history, discarded_messages_count = history.trim(lambda _: 1, 3) @@ -104,16 +110,17 @@ def test_trimming(): FormattedMessage(text="\n\nAssistant:"), ] history = PseudoChatHistory( - messages=messages, - stop_sequences=[], + messages=messages, stop_sequences=[], pseudo_history_conf=default_conf ) trimmed_history, discarded_messages_count = history.trim(lambda _: 1, 4) + prelude = history.pseudo_history_conf.prelude + assert prelude is not None assert discarded_messages_count == 2 assert 
trimmed_history.stop_sequences == ["\n\nHuman:"] assert trimmed_history.messages == [ - FormattedMessage(text=PRELUDE), + FormattedMessage(text=prelude), FormattedMessage( text="\n\nHuman: system message1", source_message=SystemMessage(content="system message1"), @@ -143,8 +150,7 @@ def test_trimming_with_one_message_left(): ), ] history = PseudoChatHistory( - messages=messages, - stop_sequences=[], + messages=messages, stop_sequences=[], pseudo_history_conf=default_conf ) trimmed_history, discarded_messages_count = history.trim(lambda _: 1, 1) @@ -172,8 +178,7 @@ def test_trimming_with_one_message_accepted_after_second_check(): ), ] history = PseudoChatHistory( - messages=messages, - stop_sequences=[], + messages=messages, stop_sequences=[], pseudo_history_conf=default_conf ) trimmed_history, discarded_messages_count = history.trim( @@ -195,7 +200,9 @@ def test_prompt_is_too_big(): FormattedMessage(text="text2"), FormattedMessage(text="text3"), ] - history = PseudoChatHistory(messages=messages, stop_sequences=[]) + history = PseudoChatHistory( + messages=messages, stop_sequences=[], pseudo_history_conf=default_conf + ) with pytest.raises(ValidationError) as exc_info: history.trim(lambda _: 1, 2) @@ -212,7 +219,9 @@ def test_prompt_with_history_is_too_big(): FormattedMessage(text="text2", is_important=False), FormattedMessage(text="text3"), ] - history = PseudoChatHistory(messages=messages, stop_sequences=[]) + history = PseudoChatHistory( + messages=messages, stop_sequences=[], pseudo_history_conf=default_conf + ) with pytest.raises(ValidationError) as exc_info: history.trim(lambda _: 1, 1) diff --git a/tests/unit_tests/test_stream.py b/tests/unit_tests/test_stream.py index 1316e5c..3bb1d50 100644 --- a/tests/unit_tests/test_stream.py +++ b/tests/unit_tests/test_stream.py @@ -1,4 +1,4 @@ -from typing import Generator, List, Tuple +from typing import AsyncIterator, List, Tuple import pytest @@ -11,13 +11,16 @@ ) -def list_to_gen(xs: List[str]) -> Generator[str, None, None]: +async def list_to_stream(xs: List[str]) -> AsyncIterator[str]: for x in xs: yield x -def gen_to_string(gen: Generator[str, None, None]) -> str: - return "".join(x for x in gen) +async def stream_to_string(stream: AsyncIterator[str]) -> str: + ret = "" + async for chunk in stream: + ret += chunk + return ret lstrip_test_cases: List[Tuple[List[str]]] = [ @@ -35,16 +38,17 @@ def gen_to_string(gen: Generator[str, None, None]) -> str: ] +@pytest.mark.asyncio @pytest.mark.parametrize( "test", lstrip_test_cases, ids=lambda arg: f"{arg[0]}", ) -def test_lstrip(test): +async def test_lstrip(test): (xs,) = test - gen = lstrip(list_to_gen(xs)) - actual = gen_to_string(gen) - expected = "".join(xs).lstrip() + stream = lstrip(list_to_stream(xs)) + actual: str = await stream_to_string(stream) + expected: str = "".join(xs).lstrip() assert actual == expected @@ -70,16 +74,17 @@ def test_lstrip(test): ] +@pytest.mark.asyncio @pytest.mark.parametrize( "test", remove_prefix_test_cases, ids=lambda arg: f"{arg[0]}-{arg[1]}", ) -def test_remove_prefix(test): +async def test_remove_prefix(test): (prefix, xs) = test - gen = remove_prefix(list_to_gen(xs), prefix) - actual = gen_to_string(gen) - expected = string.remove_prefix(prefix, "".join(xs)) + stream = remove_prefix(list_to_stream(xs), prefix) + actual: str = await stream_to_string(stream) + expected: str = string.remove_prefix(prefix, "".join(xs)) assert actual == expected @@ -108,17 +113,18 @@ def test_remove_prefix(test): ] +@pytest.mark.asyncio @pytest.mark.parametrize( "test", 
stop_at_test_cases, ids=lambda arg: f"{arg[0]}-{arg[1]}", ) -def test_stop_at(test): +async def test_stop_at(test): (stop, xs) = test stop_sequences: List[str] = [stop] if isinstance(stop, str) else stop - gen = stop_at(list_to_gen(xs), stop_sequences) - actual = gen_to_string(gen) - expected = string.stop_at(stop_sequences, "".join(xs)) + stream = stop_at(list_to_stream(xs), stop_sequences) + actual: str = await stream_to_string(stream) + expected: str = string.stop_at(stop_sequences, "".join(xs)) assert actual == expected @@ -130,14 +136,15 @@ def test_stop_at(test): ] +@pytest.mark.asyncio @pytest.mark.parametrize( "test", ensure_not_empty_test_cases, ids=lambda arg: f"{arg[0]}-{arg[1]}", ) -def test_ensure_not_empty(test): +async def test_ensure_not_empty(test): (default, xs) = test - gen = ensure_not_empty(list_to_gen(xs), default) - actual = gen_to_string(gen) - expected = string.ensure_not_empty(default, "".join(xs)) + stream = ensure_not_empty(list_to_stream(xs), default) + actual: str = await stream_to_string(stream) + expected: str = string.ensure_not_empty(default, "".join(xs)) assert actual == expected
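
Illustrative sketch (not part of the patch): one way the new async Bedrock wrapper and the stream helpers introduced above might be combined outside the adapter classes. The region and model id below are placeholders, and the post-processing chain simply mirrors PseudoChatModel.post_process_stream; this is a usage sketch, not code from the repository.

import asyncio
from typing import AsyncIterator

import aidial_adapter_bedrock.utils.stream as stream_utils
from aidial_adapter_bedrock.bedrock import Bedrock


async def main() -> None:
    # Bedrock.acreate builds the boto3 "bedrock-runtime" client off the event loop.
    client = await Bedrock.acreate(region="us-east-1")  # placeholder region

    # Request body in the Anthropic completion format, matching the keys
    # produced by convert_params/create_request in llm/model/anthropic.py.
    args = {
        "prompt": "\n\nHuman: Say hello\n\nAssistant:",
        "max_tokens_to_sample": 100,
        "stop_sequences": ["\n\nHuman:"],
    }

    # ainvoke_streaming yields the decoded JSON chunks of the response stream.
    chunks = client.ainvoke_streaming("anthropic.claude-v2", args)  # placeholder model id

    async def completions(chunks: AsyncIterator[dict]) -> AsyncIterator[str]:
        # Each Anthropic chunk carries the next piece of text under "completion".
        async for chunk in chunks:
            yield chunk["completion"]

    # Same post-processing chain as PseudoChatModel.post_process_stream:
    # trim leading whitespace, drop an echoed role prefix, cut at stop
    # sequences and make sure the stream never ends up empty.
    stream = stream_utils.lstrip(completions(chunks))
    stream = stream_utils.remove_prefix(stream, "Assistant: ")
    stream = stream_utils.stop_at(stream, ["\n\nHuman:"])
    stream = stream_utils.ensure_not_empty(stream, " ")

    async for content in stream:
        print(content, end="")


if __name__ == "__main__":
    asyncio.run(main())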