Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: supported content parts #155

Merged
merged 17 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aidial_adapter_bedrock/dial_api/embedding_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
cast,
)

from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.chat_completion import Attachment
from aidial_sdk.embeddings.request import EmbeddingsRequest

from aidial_adapter_bedrock.llm.errors import ValidationError
Expand Down
67 changes: 66 additions & 1 deletion aidial_adapter_bedrock/dial_api/request.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
from typing import List, Optional
from typing import List, Optional, TypeGuard, assert_never

from aidial_sdk.chat_completion import (
MessageContentImagePart,
MessageContentPart,
MessageContentTextPart,
)
from aidial_sdk.chat_completion.request import ChatCompletionRequest
from pydantic import BaseModel

from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.tools.tools_config import (
ToolsConfig,
ToolsMode,
validate_messages,
)

MessageContent = str | List[MessageContentPart] | None
MessageContentSpecialized = (
MessageContent
| List[MessageContentTextPart]
| List[MessageContentImagePart]
)


class ModelParameters(BaseModel):
temperature: Optional[float] = None
Expand Down Expand Up @@ -51,3 +64,55 @@ def tools_mode(self) -> ToolsMode | None:
if self.tool_config is not None:
return self.tool_config.tools_mode
return None


def collect_text_content(
content: MessageContentSpecialized, delimiter: str = "\n\n"
) -> str:

if content is None:
return ""

if isinstance(content, str):
return content

texts: List[str] = []
for part in content:
if isinstance(part, MessageContentTextPart):
texts.append(part.text)
else:
raise ValidationError(
"Can't extract text from a multi-modal content part"
)

return delimiter.join(texts)


def to_message_content(content: MessageContentSpecialized) -> MessageContent:
match content:
case None | str():
return content
case list():
return [*content]
case _:
assert_never(content)


def is_text_content(
content: MessageContent,
) -> TypeGuard[str | List[MessageContentTextPart]]:
match content:
case None:
return False
case str():
return True
case list():
return all(
isinstance(part, MessageContentTextPart) for part in content
)
case _:
assert_never(content)


def is_plain_text_content(content: MessageContent) -> TypeGuard[str | None]:
return content is None or isinstance(content, str)
181 changes: 181 additions & 0 deletions aidial_adapter_bedrock/dial_api/resource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import base64
import mimetypes
from abc import ABC, abstractmethod
from typing import List

from aidial_sdk.chat_completion import Attachment
from pydantic import BaseModel, Field, root_validator, validator

from aidial_adapter_bedrock.dial_api.storage import FileStorage, download_file
from aidial_adapter_bedrock.utils.resource import Resource
from aidial_adapter_bedrock.utils.text import truncate_string


class ValidationError(Exception):
message: str

def __init__(self, message: str):
self.message = message
super().__init__(message)


class MissingContentType(ValidationError):
pass


class UnsupportedContentType(ValidationError):
type: str
supported_types: List[str]

def __init__(self, *, message: str, type: str, supported_types: List[str]):
self.type = type
self.supported_types = supported_types
super().__init__(message)


class DialResource(ABC, BaseModel):
entity_name: str = Field(default=None)
supported_types: List[str] | None = Field(default=None)

@abstractmethod
async def download(self, storage: FileStorage | None) -> Resource: ...

@abstractmethod
async def guess_content_type(self) -> str | None: ...

@abstractmethod
async def get_resource_name(self, storage: FileStorage | None) -> str: ...

async def get_content_type(self) -> str:
type = await self.guess_content_type()

if not type:
raise MissingContentType(
f"Can't derive content type of the {self.entity_name}"
)

if (
self.supported_types is not None
and type not in self.supported_types
):
raise UnsupportedContentType(
message=f"The {self.entity_name} is not one of the supported types",
type=type,
supported_types=self.supported_types,
)

return type


class URLResource(DialResource):
url: str
content_type: str | None = None

@root_validator
def validator(cls, values):
values["entity_name"] = values.get("entity_name") or "URL"
return values

async def download(self, storage: FileStorage | None) -> Resource:
type = await self.get_content_type()
data = await _download_url(storage, self.url)
return Resource(type=type, data=data)

async def guess_content_type(self) -> str | None:
return (
self.content_type
or Resource.parse_data_url_content_type(self.url)
or mimetypes.guess_type(self.url)[0]
)

def is_data_url(self) -> bool:
return Resource.parse_data_url_content_type(self.url) is not None

async def get_resource_name(self, storage: FileStorage | None) -> str:
if self.is_data_url():
return f"data URL ({await self.guess_content_type()})"

name = self.url
if storage is not None:
name = await storage.get_human_readable_name(self.url)

return truncate_string(name, n=50)


class AttachmentResource(DialResource):
attachment: Attachment

@validator("attachment", pre=True)
def parse_attachment(cls, value):
if isinstance(value, dict):
attachment = Attachment.parse_obj(value)
# Working around the issue of defaulting missing type to a markdown:
roman-romanov-o marked this conversation as resolved.
Show resolved Hide resolved
# https://github.com/epam/ai-dial-sdk/blob/2835107e950c89645a2b619fecba2518fa2d7bb1/aidial_sdk/chat_completion/request.py#L22
if "type" not in value:
attachment.type = None
return attachment
return value

@root_validator(pre=True)
def validator(cls, values):
values["entity_name"] = values.get("entity_name") or "attachment"
return values

async def download(self, storage: FileStorage | None) -> Resource:
type = await self.get_content_type()

if self.attachment.data:
data = base64.b64decode(self.attachment.data)
elif self.attachment.url:
data = await _download_url(storage, self.attachment.url)
else:
raise ValidationError(f"Invalid {self.entity_name}")

return Resource(type=type, data=data)

def create_url_resource(self, url: str) -> URLResource:
return URLResource(
url=url,
content_type=self.informative_content_type,
entity_name=self.entity_name,
)

@property
def informative_content_type(self) -> str | None:
if (
self.attachment.type is None
or "octet-stream" in self.attachment.type
):
return None
return self.attachment.type

async def guess_content_type(self) -> str | None:
if url := self.attachment.url:
type = await self.create_url_resource(url).guess_content_type()
if type:
return type

return self.attachment.type

async def get_resource_name(self, storage: FileStorage | None) -> str:
if title := self.attachment.title:
return title

if self.attachment.data:
return f"data {self.entity_name}"
elif url := self.attachment.url:
return await self.create_url_resource(url).get_resource_name(
storage
)
else:
raise ValidationError(f"Invalid {self.entity_name}")


async def _download_url(file_storage: FileStorage | None, url: str) -> bytes:
if (resource := Resource.from_data_url(url)) is not None:
return resource.data

if file_storage:
return await file_storage.download_file(url)
else:
return await download_file(url)
65 changes: 41 additions & 24 deletions aidial_adapter_bedrock/dial_api/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import mimetypes
import os
from typing import Mapping, Optional, TypedDict
from urllib.parse import urljoin
from urllib.parse import unquote, urljoin

import aiohttp
from pydantic import BaseModel

from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
from aidial_adapter_bedrock.utils.log_config import app_logger as log


class FileMetadata(TypedDict):
Expand All @@ -23,15 +24,10 @@ class Bucket(TypedDict):
appdata: str


class FileStorage:
class FileStorage(BaseModel):
dial_url: str
api_key: str
bucket: Optional[Bucket]

def __init__(self, dial_url: str, api_key: str):
self.dial_url = dial_url
self.api_key = api_key
self.bucket = None
bucket: Optional[Bucket] = None

@property
def auth_headers(self) -> Mapping[str, str]:
Expand All @@ -49,6 +45,15 @@ async def _get_bucket(self, session: aiohttp.ClientSession) -> Bucket:

return self.bucket

async def _get_user_bucket(self, session: aiohttp.ClientSession) -> str:
bucket = await self._get_bucket(session)
appdata = bucket.get("appdata")
if appdata is None:
raise ValueError(
"Can't retrieve user bucket because appdata isn't available"
)
return appdata.split("/", 1)[0]

@staticmethod
def _to_form_data(
filename: str, content_type: str, content: bytes
Expand Down Expand Up @@ -87,36 +92,48 @@ async def upload(
async def upload_file_as_base64(
self, upload_dir: str, data: str, content_type: str
) -> FileMetadata:
filename = f"{upload_dir}/{_compute_hash_digest(data)}"
filename = f"{upload_dir}/{compute_hash_digest(data)}"
content: bytes = base64.b64decode(data)
return await self.upload(filename, content_type, content)

async def download_file_as_base64(self, dial_path: str) -> str:
url = urljoin(f"{self.dial_url}/v1/", dial_path)
def attachment_link_to_url(self, link: str) -> str:
return urljoin(f"{self.dial_url}/v1/", link)

def _url_to_attachment_link(self, url: str) -> str:
return url.removeprefix(f"{self.dial_url}/v1/")

async def download_file(self, link: str) -> bytes:
url = self.attachment_link_to_url(link)
headers: Mapping[str, str] = {}
if url.lower().startswith(self.dial_url.lower()):
headers = self.auth_headers
return await download_file(url, headers)

async def get_human_readable_name(self, link: str) -> str:
url = self.attachment_link_to_url(link)
link = self._url_to_attachment_link(url)

return await download_file_as_base64(url, headers)
link = link.removeprefix("files/")

if link.startswith("public/"):
bucket = "public"
else:
async with aiohttp.ClientSession() as session:
bucket = await self._get_user_bucket(session)

async def _download_file(
url: str, headers: Optional[Mapping[str, str]]
) -> bytes:
link = link.removeprefix(f"{bucket}/")
decoded_link = unquote(link)
return link if link == decoded_link else repr(decoded_link)


async def download_file(url: str, headers: Mapping[str, str] = {}) -> bytes:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
response.raise_for_status()
return await response.read()


async def download_file_as_base64(
url: str, headers: Optional[Mapping[str, str]] = None
) -> str:
data = await _download_file(url, headers)
return base64.b64encode(data).decode("ascii")


def _compute_hash_digest(file_content: str) -> str:
def compute_hash_digest(file_content: str) -> str:
return hashlib.sha256(file_content.encode()).hexdigest()


Expand Down
Loading