Skip to content

Commit

Permalink
feat: saving stability artifacts to DIAL file storage
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Nov 29, 2023
1 parent cd54a90 commit e489bc0
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 36 deletions.
2 changes: 1 addition & 1 deletion aidial_adapter_bedrock/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

default_region = get_env("DEFAULT_REGION")

app = DIALApp(description="AWS Bedrock adapter for RAIL API")
app = DIALApp(description="AWS Bedrock adapter for DIAL API")


@app.get("/healthcheck")
Expand Down
96 changes: 96 additions & 0 deletions aidial_adapter_bedrock/dial_api/storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import base64
import hashlib
import io
from typing import TypedDict

import aiohttp

from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


class FileMetadata(TypedDict):
name: str
type: str
path: str
contentLength: int
contentType: str


class FileStorage:
base_url: str
api_key: str

def __init__(self, dial_url: str, base_dir: str, api_key: str):
self.base_url = f"{dial_url}/v1/files/{base_dir}"
self.api_key = api_key

def auth_headers(self) -> dict[str, str]:
return {"api-key": self.api_key}

@staticmethod
def to_form_data(
filename: str, content_type: str, content: bytes
) -> aiohttp.FormData:
data = aiohttp.FormData()
data.add_field(
"file",
io.BytesIO(content),
filename=filename,
content_type=content_type,
)
return data

async def list(self) -> list[FileMetadata]:
async with aiohttp.ClientSession() as session:
url = f"{self.base_url}?purpose=metadata&path=relative"
async with session.get(
url, headers=self.auth_headers()
) as response:
response.raise_for_status()
ret = await response.json()
log.debug(f"Listed files at '{url}': {ret}")
return ret

async def delete(self, filename: str):
async with aiohttp.ClientSession() as session:
url = f"{self.base_url}/{filename}"
async with session.delete(
url, headers=self.auth_headers()
) as response:
response.raise_for_status()
ret = await response.text()
log.debug(f"Removed files at '{url}': {ret}")
return ret

async def upload(
self, filename: str, content_type: str, content: bytes
) -> FileMetadata:
async with aiohttp.ClientSession() as session:
data = FileStorage.to_form_data(filename, content_type, content)
async with session.post(
self.base_url,
data=data,
headers=self.auth_headers(),
) as response:
response.raise_for_status()
ret = await response.json()
log.debug(
f"Uploaded to '{self.base_url}' file '{filename}': {ret}"
)
return ret


def hash_digest(string: str) -> str:
return hashlib.sha256(string.encode()).hexdigest()


class ImageStorage:
storage: FileStorage

def __init__(self, dial_url: str, base_dir: str, api_key: str):
self.storage = FileStorage(dial_url, base_dir, api_key)

async def upload_base64_png_image(self, data: str) -> FileMetadata:
filename = hash_digest(data) + ".png"
content: bytes = base64.b64decode(data)
return await self.storage.upload(filename, "image/png", content)
88 changes: 53 additions & 35 deletions aidial_adapter_bedrock/llm/model/stability.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import os
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field

from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.storage import ImageStorage
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_emulation.zero_memory_chat import (
ZeroMemoryChatHistory,
Expand All @@ -13,12 +15,7 @@
from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer
from aidial_adapter_bedrock.llm.message import BaseMessage
from aidial_adapter_bedrock.utils.concurrency import make_async


class ResponseData(BaseModel):
mime_type: str
name: str
content: str
from aidial_adapter_bedrock.utils.env import get_env


class StabilityStatus(str, Enum):
Expand All @@ -40,25 +37,26 @@ class StabilityArtifact(BaseModel):

class StabilityResponse(BaseModel):
# TODO: Use tagged union artifacts/error
result: str
result: StabilityStatus
artifacts: Optional[list[StabilityArtifact]]
error: Optional[StabilityError]

def content(self) -> str:
self._throw_if_error()
return ""

def data(self) -> list[ResponseData]:
def attachments(self) -> list[Attachment]:
self._throw_if_error()
return [
ResponseData(
mime_type="image/png",
name="image",
content=self.artifacts[0].base64, # type: ignore
Attachment(
title="image",
type="image/png",
data=self.artifacts[0].base64, # type: ignore
)
]

def usage(self) -> TokenUsage:
self._throw_if_error()
return TokenUsage(
prompt_tokens=0,
completion_tokens=1,
Expand All @@ -73,14 +71,40 @@ def prepare_input(prompt: str) -> Dict[str, Any]:
return {"text_prompts": [{"text": prompt}]}


async def save_to_storage(
storage: ImageStorage, attachment: Attachment
) -> Attachment:
if attachment.type == "image/png" and attachment.data is not None:
response = await storage.upload_base64_png_image(attachment.data)
return Attachment(
title=attachment.title,
type=attachment.type,
url=response["path"] + "/" + response["name"],
)

return attachment


DIAL_BEDROCK_API_KEY = os.getenv("DIAL_BEDROCK_API_KEY")
if DIAL_BEDROCK_API_KEY is not None:
DIAL_URL = get_env("DIAL_URL")


class StabilityAdapter(ChatModel):
def __init__(
self,
bedrock: Any,
model_id: str,
):
bedrock: Any
storage: Optional[ImageStorage]

def __init__(self, bedrock: Any, model_id: str):
super().__init__(model_id)
self.bedrock = bedrock
self.storage = None

if DIAL_BEDROCK_API_KEY is not None:
self.storage = ImageStorage(
dial_url=DIAL_URL,
api_key=DIAL_BEDROCK_API_KEY,
base_dir="stability",
)

def _prepare_prompt(
self, messages: List[BaseMessage], max_prompt_tokens: Optional[int]
Expand All @@ -95,16 +119,14 @@ def _prepare_prompt(
async def _apredict(
self, consumer: Consumer, model_params: ModelParameters, prompt: str
):
return await make_async(
lambda args: self._call(*args), (consumer, prompt)
)

def _call(self, consumer: Consumer, prompt: str):
model_response = self.bedrock.invoke_model(
modelId=self.model_id,
accept="application/json",
contentType="application/json",
body=json.dumps(prepare_input(prompt)),
model_response = await make_async(
lambda args: self.bedrock.invoke_model(
accept="application/json",
contentType="application/json",
modelId=args[0],
body=args[1],
),
(self.model_id, json.dumps(prepare_input(prompt))),
)

body = json.loads(model_response["body"].read())
Expand All @@ -113,11 +135,7 @@ def _call(self, consumer: Consumer, prompt: str):
consumer.append_content(resp.content())
consumer.add_usage(resp.usage())

for data in resp.data():
consumer.add_attachment(
Attachment(
title=data.name,
data=data.content,
type=data.mime_type,
)
)
for attachment in resp.attachments():
if self.storage is not None:
attachment = await save_to_storage(self.storage, attachment)
consumer.add_attachment(attachment)

0 comments on commit e489bc0

Please sign in to comment.