Skip to content

Commit

Permalink
feat: supported content parts (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Oct 11, 2024
1 parent 4ae0fcf commit 5814648
Show file tree
Hide file tree
Showing 24 changed files with 768 additions and 268 deletions.
2 changes: 1 addition & 1 deletion aidial_adapter_bedrock/dial_api/embedding_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
cast,
)

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.chat_completion import Attachment
from aidial_sdk.embeddings.request import EmbeddingsRequest

from aidial_adapter_bedrock.llm.errors import ValidationError
Expand Down
67 changes: 66 additions & 1 deletion aidial_adapter_bedrock/dial_api/request.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
from typing import List, Optional
from typing import List, Optional, TypeGuard, assert_never

from aidial_sdk.chat_completion import (
MessageContentImagePart,
MessageContentPart,
MessageContentTextPart,
)
from aidial_sdk.chat_completion.request import ChatCompletionRequest
from pydantic import BaseModel

from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.tools.tools_config import (
ToolsConfig,
ToolsMode,
validate_messages,
)

# Content of a chat message: a plain string, a list of content parts, or None.
MessageContent = str | List[MessageContentPart] | None
# MessageContent extended with lists whose element type is narrowed to
# text parts only or image parts only.
MessageContentSpecialized = (
MessageContent
| List[MessageContentTextPart]
| List[MessageContentImagePart]
)


class ModelParameters(BaseModel):
temperature: Optional[float] = None
Expand Down Expand Up @@ -51,3 +64,55 @@ def tools_mode(self) -> ToolsMode | None:
if self.tool_config is not None:
return self.tool_config.tools_mode
return None


def collect_text_content(
    content: MessageContentSpecialized, delimiter: str = "\n\n"
) -> str:
    """Flatten message content into a single string.

    ``None`` yields an empty string and a plain string is returned as is.
    For a list of parts, the text of every part is joined with *delimiter*.

    Raises:
        ValidationError: if the list contains a part that is not a
            ``MessageContentTextPart``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content

    def _text_of(part) -> str:
        # Only text parts carry extractable text; anything else is rejected.
        if not isinstance(part, MessageContentTextPart):
            raise ValidationError(
                "Can't extract text from a multi-modal content part"
            )
        return part.text

    return delimiter.join([_text_of(part) for part in content])


def to_message_content(content: MessageContentSpecialized) -> MessageContent:
match content:
case None | str():
return content
case list():
return [*content]
case _:
assert_never(content)


def is_text_content(
content: MessageContent,
) -> TypeGuard[str | List[MessageContentTextPart]]:
match content:
case None:
return False
case str():
return True
case list():
return all(
isinstance(part, MessageContentTextPart) for part in content
)
case _:
assert_never(content)


def is_plain_text_content(content: MessageContent) -> TypeGuard[str | None]:
return content is None or isinstance(content, str)
181 changes: 181 additions & 0 deletions aidial_adapter_bedrock/dial_api/resource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import base64
import mimetypes
from abc import ABC, abstractmethod
from typing import List

from aidial_sdk.chat_completion import Attachment
from pydantic import BaseModel, Field, root_validator, validator

from aidial_adapter_bedrock.dial_api.storage import FileStorage, download_file
from aidial_adapter_bedrock.utils.resource import Resource
from aidial_adapter_bedrock.utils.text import truncate_string


class ValidationError(Exception):
    """Base error for invalid resources handled in this module.

    The text is stored in ``message`` and is also handed to ``Exception``,
    so ``str(error)`` yields the same text.
    """

    message: str

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message


class MissingContentType(ValidationError):
    """Raised when no content type could be derived for a resource."""


class UnsupportedContentType(ValidationError):
    """Raised when a resource's content type is not among the supported ones."""

    # The content type that was rejected.
    type: str
    # The content types that would have been accepted.
    supported_types: List[str]

    def __init__(self, *, message: str, type: str, supported_types: List[str]):
        super().__init__(message)
        self.supported_types = supported_types
        self.type = type


class DialResource(ABC, BaseModel):
    """Abstract pydantic model of a resource that can be downloaded.

    Subclasses provide the download, the content-type guess and a display
    name; this base class adds the content-type validation on top.
    """

    # Label used for the resource in error messages. The subclasses in this
    # file fill in a default through their root validators.
    entity_name: str = Field(default=None)
    # Accepted content types; None disables the check in get_content_type.
    supported_types: List[str] | None = Field(default=None)

    @abstractmethod
    async def download(self, storage: FileStorage | None) -> Resource:
        """Return the resource as a ``Resource``."""
        ...

    @abstractmethod
    async def guess_content_type(self) -> str | None:
        """Return the best guess of the content type, or ``None``."""
        ...

    @abstractmethod
    async def get_resource_name(self, storage: FileStorage | None) -> str:
        """Return a name for the resource."""
        ...

    async def get_content_type(self) -> str:
        """Return the guessed content type after validating it.

        Raises:
            MissingContentType: if the guess is empty.
            UnsupportedContentType: if ``supported_types`` is set and the
                guessed type is not in it.
        """
        content_type = await self.guess_content_type()

        if not content_type:
            raise MissingContentType(
                f"Can't derive content type of the {self.entity_name}"
            )

        allowed = self.supported_types
        if allowed is None or content_type in allowed:
            return content_type

        raise UnsupportedContentType(
            message=f"The {self.entity_name} is not one of the supported types",
            type=content_type,
            supported_types=allowed,
        )


class URLResource(DialResource):
    """Resource addressed by a URL, which may also be a data URL."""

    url: str
    # Explicit content type; takes precedence over anything derived from the URL.
    content_type: str | None = None

    @root_validator
    def validator(cls, values):
        # Fall back to a generic label when no entity name was supplied.
        if not values.get("entity_name"):
            values["entity_name"] = "URL"
        return values

    async def download(self, storage: FileStorage | None) -> Resource:
        """Validate the content type, then fetch the bytes behind the URL."""
        content_type = await self.get_content_type()
        payload = await _download_url(storage, self.url)
        return Resource(type=content_type, data=payload)

    async def guess_content_type(self) -> str | None:
        """Guess the content type.

        Order of preference: the explicit ``content_type``, the type parsed
        from a data URL, then ``mimetypes.guess_type`` applied to the URL.
        """
        if self.content_type:
            return self.content_type
        if from_data_url := Resource.parse_data_url_content_type(self.url):
            return from_data_url
        return mimetypes.guess_type(self.url)[0]

    def is_data_url(self) -> bool:
        """Tell whether a data-URL content type can be parsed from the URL."""
        parsed_type = Resource.parse_data_url_content_type(self.url)
        return parsed_type is not None

    async def get_resource_name(self, storage: FileStorage | None) -> str:
        """Return a short display name for the URL.

        Data URLs are described by their content type. Other URLs are passed
        through ``storage.get_human_readable_name`` when a storage is given,
        and the result is truncated with ``truncate_string(..., n=50)``.
        """
        if self.is_data_url():
            return f"data URL ({await self.guess_content_type()})"

        if storage is None:
            name = self.url
        else:
            name = await storage.get_human_readable_name(self.url)

        return truncate_string(name, n=50)


class AttachmentResource(DialResource):
    """Resource backed by a DIAL ``Attachment`` (inline data or a URL)."""

    attachment: Attachment

    @validator("attachment", pre=True)
    def parse_attachment(cls, value):
        # Anything that is not a raw dict is left for pydantic to handle.
        if not isinstance(value, dict):
            return value

        parsed = Attachment.parse_obj(value)
        # Working around the issue of defaulting missing type to a markdown:
        # https://github.com/epam/ai-dial-sdk/blob/2835107e950c89645a2b619fecba2518fa2d7bb1/aidial_sdk/chat_completion/request.py#L22
        if "type" not in value:
            parsed.type = None
        return parsed

    @root_validator(pre=True)
    def validator(cls, values):
        # Fall back to a generic label when no entity name was supplied.
        if not values.get("entity_name"):
            values["entity_name"] = "attachment"
        return values

    async def download(self, storage: FileStorage | None) -> Resource:
        """Validate the content type, then obtain the attachment bytes.

        Inline ``data`` is base64-decoded; otherwise the ``url`` is
        downloaded.

        Raises:
            ValidationError: if the attachment has neither data nor URL.
        """
        content_type = await self.get_content_type()

        inline_data = self.attachment.data
        url = self.attachment.url
        if inline_data:
            payload = base64.b64decode(inline_data)
        elif url:
            payload = await _download_url(storage, url)
        else:
            raise ValidationError(f"Invalid {self.entity_name}")

        return Resource(type=content_type, data=payload)

    def create_url_resource(self, url: str) -> URLResource:
        """Build a ``URLResource`` for *url* that inherits this entity name."""
        return URLResource(
            url=url,
            content_type=self.informative_content_type,
            entity_name=self.entity_name,
        )

    @property
    def informative_content_type(self) -> str | None:
        """The attachment type, or ``None`` if it is unset or an octet-stream."""
        declared = self.attachment.type
        if declared is None or "octet-stream" in declared:
            return None
        return declared

    async def guess_content_type(self) -> str | None:
        """Guess the content type, preferring what the URL resource reports.

        Falls back to the raw attachment type when there is no URL or the
        URL-based guess is empty.
        """
        url = self.attachment.url
        if url:
            guessed = await self.create_url_resource(url).guess_content_type()
            if guessed:
                return guessed

        return self.attachment.type

    async def get_resource_name(self, storage: FileStorage | None) -> str:
        """Return a display name: the title, a data label, or the URL's name.

        Raises:
            ValidationError: if the attachment has no title, data or URL.
        """
        title = self.attachment.title
        if title:
            return title

        if self.attachment.data:
            return f"data {self.entity_name}"

        url = self.attachment.url
        if url:
            url_resource = self.create_url_resource(url)
            return await url_resource.get_resource_name(storage)

        raise ValidationError(f"Invalid {self.entity_name}")


async def _download_url(file_storage: FileStorage | None, url: str) -> bytes:
    """Return the bytes behind *url*.

    A data URL is decoded locally through ``Resource.from_data_url``.
    Otherwise the download goes through *file_storage* when one is given,
    or through the module-level ``download_file`` when not.
    """
    inline = Resource.from_data_url(url)
    if inline is not None:
        return inline.data

    if not file_storage:
        return await download_file(url)
    return await file_storage.download_file(url)
65 changes: 41 additions & 24 deletions aidial_adapter_bedrock/dial_api/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import mimetypes
import os
from typing import Mapping, Optional, TypedDict
from urllib.parse import urljoin
from urllib.parse import unquote, urljoin

import aiohttp
from pydantic import BaseModel

from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
from aidial_adapter_bedrock.utils.log_config import app_logger as log


class FileMetadata(TypedDict):
Expand All @@ -23,15 +24,10 @@ class Bucket(TypedDict):
appdata: str


class FileStorage:
class FileStorage(BaseModel):
dial_url: str
api_key: str
bucket: Optional[Bucket]

def __init__(self, dial_url: str, api_key: str):
self.dial_url = dial_url
self.api_key = api_key
self.bucket = None
bucket: Optional[Bucket] = None

@property
def auth_headers(self) -> Mapping[str, str]:
Expand All @@ -49,6 +45,15 @@ async def _get_bucket(self, session: aiohttp.ClientSession) -> Bucket:

return self.bucket

async def _get_user_bucket(self, session: aiohttp.ClientSession) -> str:
    """Return the first path segment of the bucket's ``appdata`` value.

    Raises:
        ValueError: if the bucket has no ``appdata`` entry.
    """
    appdata = (await self._get_bucket(session)).get("appdata")
    if appdata is not None:
        # Keep only what precedes the first "/".
        return appdata.split("/", 1)[0]
    raise ValueError(
        "Can't retrieve user bucket because appdata isn't available"
    )

@staticmethod
def _to_form_data(
filename: str, content_type: str, content: bytes
Expand Down Expand Up @@ -87,36 +92,48 @@ async def upload(
async def upload_file_as_base64(
self, upload_dir: str, data: str, content_type: str
) -> FileMetadata:
filename = f"{upload_dir}/{_compute_hash_digest(data)}"
filename = f"{upload_dir}/{compute_hash_digest(data)}"
content: bytes = base64.b64decode(data)
return await self.upload(filename, content_type, content)

async def download_file_as_base64(self, dial_path: str) -> str:
url = urljoin(f"{self.dial_url}/v1/", dial_path)
def attachment_link_to_url(self, link: str) -> str:
    """Resolve *link* against ``<dial_url>/v1/`` using ``urljoin``.

    A relative link is placed under that base; an absolute URL is returned
    unchanged, as ``urljoin`` does for absolute references.
    """
    base_url = f"{self.dial_url}/v1/"
    return urljoin(base_url, link)

def _url_to_attachment_link(self, url: str) -> str:
    """Strip the ``<dial_url>/v1/`` prefix from *url* when it is present."""
    api_prefix = f"{self.dial_url}/v1/"
    return url.removeprefix(api_prefix)

async def download_file(self, link: str) -> bytes:
    """Download the file behind *link* and return its bytes.

    *link* is resolved with ``attachment_link_to_url``. The auth headers are
    attached only when the resolved URL is ``dial_url`` itself or lies under
    ``dial_url`` followed by ``/``; any other URL is fetched without them.
    """
    url = self.attachment_link_to_url(link)

    # A bare prefix test would also accept "http://dial.evil.com/..." or
    # "http://dial@evil.com/..." for dial_url "http://dial" and send the API
    # key to a foreign host. Require the prefix to end on a path boundary.
    base = self.dial_url.rstrip("/").lower()
    lowered_url = url.lower()
    is_dial_url = lowered_url == base or lowered_url.startswith(f"{base}/")

    headers: Mapping[str, str] = self.auth_headers if is_dial_url else {}
    return await download_file(url, headers)

async def get_human_readable_name(self, link: str) -> str:
url = self.attachment_link_to_url(link)
link = self._url_to_attachment_link(url)

return await download_file_as_base64(url, headers)
link = link.removeprefix("files/")

if link.startswith("public/"):
bucket = "public"
else:
async with aiohttp.ClientSession() as session:
bucket = await self._get_user_bucket(session)

async def _download_file(
url: str, headers: Optional[Mapping[str, str]]
) -> bytes:
link = link.removeprefix(f"{bucket}/")
decoded_link = unquote(link)
return link if link == decoded_link else repr(decoded_link)


async def download_file(
    url: str, headers: Optional[Mapping[str, str]] = None
) -> bytes:
    """GET *url* and return the response body as bytes.

    Args:
        url: the URL to fetch.
        headers: optional request headers. Defaults to ``None`` rather than
            a shared mutable ``{}`` so no state can leak between calls.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers) as response:
            response.raise_for_status()
            return await response.read()


async def download_file_as_base64(
url: str, headers: Optional[Mapping[str, str]] = None
) -> str:
data = await _download_file(url, headers)
return base64.b64encode(data).decode("ascii")


def _compute_hash_digest(file_content: str) -> str:
def compute_hash_digest(file_content: str) -> str:
    """Return the hexadecimal SHA-256 digest of *file_content*.

    The string is converted to bytes with ``str.encode()`` before hashing.
    """
    hasher = hashlib.sha256()
    hasher.update(file_content.encode())
    return hasher.hexdigest()


Expand Down
Loading

0 comments on commit 5814648

Please sign in to comment.