diff --git a/aidial_adapter_bedrock/app.py b/aidial_adapter_bedrock/app.py index 1890ee0..6c5627e 100644 --- a/aidial_adapter_bedrock/app.py +++ b/aidial_adapter_bedrock/app.py @@ -19,7 +19,7 @@ default_region = get_env("DEFAULT_REGION") -app = DIALApp(description="AWS Bedrock adapter for RAIL API") +app = DIALApp(description="AWS Bedrock adapter for DIAL API") @app.get("/healthcheck") diff --git a/aidial_adapter_bedrock/dial_api/storage.py b/aidial_adapter_bedrock/dial_api/storage.py new file mode 100644 index 0000000..04e7c39 --- /dev/null +++ b/aidial_adapter_bedrock/dial_api/storage.py @@ -0,0 +1,96 @@ +import base64 +import hashlib +import io +from typing import TypedDict + +import aiohttp + +from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log + + +class FileMetadata(TypedDict): + name: str + type: str + path: str + contentLength: int + contentType: str + + +class FileStorage: + base_url: str + api_key: str + + def __init__(self, dial_url: str, base_dir: str, api_key: str): + self.base_url = f"{dial_url}/v1/files/{base_dir}" + self.api_key = api_key + + def auth_headers(self) -> dict[str, str]: + return {"api-key": self.api_key} + + @staticmethod + def to_form_data( + filename: str, content_type: str, content: bytes + ) -> aiohttp.FormData: + data = aiohttp.FormData() + data.add_field( + "file", + io.BytesIO(content), + filename=filename, + content_type=content_type, + ) + return data + + async def list(self) -> list[FileMetadata]: + async with aiohttp.ClientSession() as session: + url = f"{self.base_url}?purpose=metadata&path=relative" + async with session.get( + url, headers=self.auth_headers() + ) as response: + response.raise_for_status() + ret = await response.json() + log.debug(f"Listed files at '{url}': {ret}") + return ret + + async def delete(self, filename: str): + async with aiohttp.ClientSession() as session: + url = f"{self.base_url}/{filename}" + async with session.delete( + url, headers=self.auth_headers() + ) as response: + response.raise_for_status() + ret = await response.text() + log.debug(f"Removed files at '{url}': {ret}") + return ret + + async def upload( + self, filename: str, content_type: str, content: bytes + ) -> FileMetadata: + async with aiohttp.ClientSession() as session: + data = FileStorage.to_form_data(filename, content_type, content) + async with session.post( + self.base_url, + data=data, + headers=self.auth_headers(), + ) as response: + response.raise_for_status() + ret = await response.json() + log.debug( + f"Uploaded to '{self.base_url}' file '{filename}': {ret}" + ) + return ret + + +def hash_digest(string: str) -> str: + return hashlib.sha256(string.encode()).hexdigest() + + +class ImageStorage: + storage: FileStorage + + def __init__(self, dial_url: str, base_dir: str, api_key: str): + self.storage = FileStorage(dial_url, base_dir, api_key) + + async def upload_base64_png_image(self, data: str) -> FileMetadata: + filename = hash_digest(data) + ".png" + content: bytes = base64.b64decode(data) + return await self.storage.upload(filename, "image/png", content) diff --git a/aidial_adapter_bedrock/llm/model/stability.py b/aidial_adapter_bedrock/llm/model/stability.py index c7f111b..57ce05d 100644 --- a/aidial_adapter_bedrock/llm/model/stability.py +++ b/aidial_adapter_bedrock/llm/model/stability.py @@ -1,10 +1,12 @@ import json +import os from enum import Enum from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field from aidial_adapter_bedrock.dial_api.request import ModelParameters +from aidial_adapter_bedrock.dial_api.storage import ImageStorage from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage from aidial_adapter_bedrock.llm.chat_emulation.zero_memory_chat import ( ZeroMemoryChatHistory, @@ -13,12 +15,7 @@ from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer from aidial_adapter_bedrock.llm.message import BaseMessage from aidial_adapter_bedrock.utils.concurrency import make_async - - -class ResponseData(BaseModel): - mime_type: str - name: str - content: str +from aidial_adapter_bedrock.utils.env import get_env class StabilityStatus(str, Enum): @@ -40,7 +37,7 @@ class StabilityArtifact(BaseModel): class StabilityResponse(BaseModel): # TODO: Use tagged union artifacts/error - result: str + result: StabilityStatus artifacts: Optional[list[StabilityArtifact]] error: Optional[StabilityError] @@ -48,17 +45,18 @@ def content(self) -> str: self._throw_if_error() return "" - def data(self) -> list[ResponseData]: + def attachments(self) -> list[Attachment]: self._throw_if_error() return [ - ResponseData( - mime_type="image/png", - name="image", - content=self.artifacts[0].base64, # type: ignore + Attachment( + title="image", + type="image/png", + data=self.artifacts[0].base64, # type: ignore ) ] def usage(self) -> TokenUsage: + self._throw_if_error() return TokenUsage( prompt_tokens=0, completion_tokens=1, @@ -73,14 +71,40 @@ def prepare_input(prompt: str) -> Dict[str, Any]: return {"text_prompts": [{"text": prompt}]} +async def save_to_storage( + storage: ImageStorage, attachment: Attachment +) -> Attachment: + if attachment.type == "image/png" and attachment.data is not None: + response = await storage.upload_base64_png_image(attachment.data) + return Attachment( + title=attachment.title, + type=attachment.type, + url=response["path"] + "/" + response["name"], + ) + + return attachment + + +DIAL_BEDROCK_API_KEY = os.getenv("DIAL_BEDROCK_API_KEY") +if DIAL_BEDROCK_API_KEY is not None: + DIAL_URL = get_env("DIAL_URL") + + class StabilityAdapter(ChatModel): - def __init__( - self, - bedrock: Any, - model_id: str, - ): + bedrock: Any + storage: Optional[ImageStorage] + + def __init__(self, bedrock: Any, model_id: str): super().__init__(model_id) self.bedrock = bedrock + self.storage = None + + if DIAL_BEDROCK_API_KEY is not None: + self.storage = ImageStorage( + dial_url=DIAL_URL, + api_key=DIAL_BEDROCK_API_KEY, + base_dir="stability", + ) def _prepare_prompt( self, messages: List[BaseMessage], max_prompt_tokens: Optional[int] @@ -95,16 +119,14 @@ def _prepare_prompt( async def _apredict( self, consumer: Consumer, model_params: ModelParameters, prompt: str ): - return await make_async( - lambda args: self._call(*args), (consumer, prompt) - ) - - def _call(self, consumer: Consumer, prompt: str): - model_response = self.bedrock.invoke_model( - modelId=self.model_id, - accept="application/json", - contentType="application/json", - body=json.dumps(prepare_input(prompt)), + model_response = await make_async( + lambda args: self.bedrock.invoke_model( + accept="application/json", + contentType="application/json", + modelId=args[0], + body=args[1], + ), + (self.model_id, json.dumps(prepare_input(prompt))), ) body = json.loads(model_response["body"].read()) @@ -113,11 +135,7 @@ def _call(self, consumer: Consumer, prompt: str): consumer.append_content(resp.content()) consumer.add_usage(resp.usage()) - for data in resp.data(): - consumer.add_attachment( - Attachment( - title=data.name, - data=data.content, - type=data.mime_type, - ) - ) + for attachment in resp.attachments(): + if self.storage is not None: + attachment = await save_to_storage(self.storage, attachment) + consumer.add_attachment(attachment)