Skip to content

Commit

Permalink
fix: migrated to new file API
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Dec 18, 2023
1 parent 621d2b8 commit 5eaeaf7
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 43 deletions.
3 changes: 1 addition & 2 deletions aidial_adapter_bedrock/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from aidial_sdk.chat_completion import ChatCompletion, Request, Response

from aidial_adapter_bedrock.dial_api.auth import Auth
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer
Expand All @@ -24,7 +23,7 @@ async def chat_completion(self, request: Request, response: Response):
model = await get_bedrock_adapter(
region=self.region,
model=request.deployment_id,
file_api_auth=Auth.from_headers("authorization", request.headers),
headers=request.headers,
)

async def generate_response(
Expand Down
53 changes: 49 additions & 4 deletions aidial_adapter_bedrock/dial_api/storage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import base64
import hashlib
import io
from typing import TypedDict
from typing import Mapping, Optional, TypedDict

import aiohttp

from aidial_adapter_bedrock.dial_api.auth import Auth
from aidial_adapter_bedrock.utils.env import get_env, get_env_bool
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


Expand All @@ -21,10 +22,24 @@ class FileStorage:
base_url: str
auth: Auth

def __init__(self, dial_url: str, base_dir: str, auth: Auth):
self.base_url = f"{dial_url}/v1/files/{base_dir}"
def __init__(self, dial_url: str, base_dir: str, bucket: str, auth: Auth):
self.base_url = f"{dial_url}/v1/files/{bucket}/{base_dir}"
self.auth = auth

@classmethod
async def create(cls, dial_url: str, base_dir: str, auth: Auth):
bucket = await FileStorage._get_bucket(dial_url, auth)
return cls(dial_url, base_dir, bucket, auth)

@staticmethod
async def _get_bucket(dial_url: str, auth: Auth) -> str:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{dial_url}/v1/bucket", headers=auth.headers
) as response:
response.raise_for_status()
return await response.text()

@staticmethod
def to_form_data(
filename: str, content_type: str, content: bytes
Expand Down Expand Up @@ -60,9 +75,39 @@ def _hash_digest(string: str) -> str:
return hashlib.sha256(string.encode()).hexdigest()


async def upload_base64_file(
async def upload_file_as_base64(
storage: FileStorage, data: str, content_type: str
) -> FileMetadata:
filename = _hash_digest(data)
content: bytes = base64.b64decode(data)
return await storage.upload(filename, content_type, content)


DIAL_USE_FILE_STORAGE = get_env_bool("DIAL_USE_FILE_STORAGE", False)

DIAL_URL: Optional[str] = None
if DIAL_USE_FILE_STORAGE:
DIAL_URL = get_env(
"DIAL_URL", "DIAL_URL must be set to use the DIAL file storage"
)


async def create_file_storage(
base_dir: str, headers: Mapping[str, str]
) -> Optional[FileStorage]:
if not DIAL_USE_FILE_STORAGE or DIAL_URL is None:
return None

auth = Auth.from_headers("authorization", headers)
if auth is None:
log.warning(
"The request doesn't have required headers to use the DIAL file storage. "
"Fallback to base64 encoding of images."
)
return None

return await FileStorage.create(
dial_url=DIAL_URL,
auth=auth,
base_dir=base_dir,
)
7 changes: 3 additions & 4 deletions aidial_adapter_bedrock/llm/model/adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional
from typing import Mapping

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.auth import Auth
from aidial_adapter_bedrock.llm.chat_emulator import default_emulator
from aidial_adapter_bedrock.llm.chat_model import ChatModel, Model
from aidial_adapter_bedrock.llm.model.ai21 import AI21Adapter
Expand All @@ -14,7 +13,7 @@


async def get_bedrock_adapter(
model: str, region: str, file_api_auth: Optional[Auth]
model: str, region: str, headers: Mapping[str, str]
) -> ChatModel:
client = await Bedrock.acreate(region)
provider = Model.parse(model).provider
Expand All @@ -26,7 +25,7 @@ async def get_bedrock_adapter(
client, model, default_tokenize, default_emulator
)
case "stability":
return StabilityAdapter(client, model, file_api_auth)
return await StabilityAdapter.create(client, model, headers)
case "amazon":
return AmazonAdapter(
client, model, default_tokenize, default_emulator
Expand Down
47 changes: 15 additions & 32 deletions aidial_adapter_bedrock/llm/model/stability.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
import os
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Field

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.auth import Auth
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.storage import (
FileStorage,
upload_base64_file,
create_file_storage,
upload_file_as_base64,
)
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import ChatModel, ChatPrompt
from aidial_adapter_bedrock.llm.consumer import Attachment, Consumer
from aidial_adapter_bedrock.llm.exceptions import ValidationError
from aidial_adapter_bedrock.llm.message import BaseMessage
from aidial_adapter_bedrock.utils.env import get_env
from aidial_adapter_bedrock.utils.log_config import app_logger as log


class StabilityStatus(str, Enum):
Expand Down Expand Up @@ -81,7 +78,7 @@ async def save_to_storage(
and attachment.type.startswith("image/")
and attachment.data is not None
):
response = await upload_base64_file(
response = await upload_file_as_base64(
storage, attachment.data, attachment.type
)
return Attachment(
Expand All @@ -93,39 +90,25 @@ async def save_to_storage(
return attachment


DIAL_USE_FILE_STORAGE = (
os.getenv("DIAL_USE_FILE_STORAGE", "false").lower() == "true"
)

if DIAL_USE_FILE_STORAGE:
DIAL_URL = get_env(
"DIAL_URL", "DIAL_URL must be set to use the DIAL file storage"
)


class StabilityAdapter(ChatModel):
client: Bedrock
storage: Optional[FileStorage]

def __init__(
self, client: Bedrock, model: str, file_api_auth: Optional[Auth]
self, client: Bedrock, model: str, storage: Optional[FileStorage]
):
super().__init__(model)
self.client = client
self.storage = None

if DIAL_USE_FILE_STORAGE:
if file_api_auth is None:
log.warning(
"The request doesn't have required headers to use the DIAL file storage. "
"Fallback to base64 encoding of images."
)
else:
self.storage = FileStorage(
dial_url=DIAL_URL,
auth=file_api_auth,
base_dir="images/stable-diffusion",
)
self.storage = storage

@classmethod
async def create(
cls, client: Bedrock, model: str, headers: Mapping[str, str]
):
storage: Optional[FileStorage] = await create_file_storage(
"images/stable-diffusion", headers
)
return cls(client, model, storage)

def _prepare_prompt(
self, messages: List[BaseMessage], max_prompt_tokens: Optional[int]
Expand Down
4 changes: 4 additions & 0 deletions aidial_adapter_bedrock/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ def get_env(name: str, err_msg: Optional[str] = None) -> str:
return val

raise Exception(err_msg or f"{name} env variable is not set")


def get_env_bool(name: str, default: bool = False) -> bool:
return os.getenv(name, str(default)).lower() == "true"
2 changes: 1 addition & 1 deletion client/client_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def main():
model = await get_bedrock_adapter(
model=deployment.get_model_id(),
region=location,
file_api_auth=None,
headers={},
)

messages: List[Message] = []
Expand Down

0 comments on commit 5eaeaf7

Please sign in to comment.