-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: supported content parts (#155)
- Loading branch information
Showing
24 changed files
with
768 additions
and
268 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
import base64 | ||
import mimetypes | ||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
from aidial_sdk.chat_completion import Attachment | ||
from pydantic import BaseModel, Field, root_validator, validator | ||
|
||
from aidial_adapter_bedrock.dial_api.storage import FileStorage, download_file | ||
from aidial_adapter_bedrock.utils.resource import Resource | ||
from aidial_adapter_bedrock.utils.text import truncate_string | ||
|
||
|
||
class ValidationError(Exception): | ||
message: str | ||
|
||
def __init__(self, message: str): | ||
self.message = message | ||
super().__init__(message) | ||
|
||
|
||
class MissingContentType(ValidationError): | ||
pass | ||
|
||
|
||
class UnsupportedContentType(ValidationError): | ||
type: str | ||
supported_types: List[str] | ||
|
||
def __init__(self, *, message: str, type: str, supported_types: List[str]): | ||
self.type = type | ||
self.supported_types = supported_types | ||
super().__init__(message) | ||
|
||
|
||
class DialResource(ABC, BaseModel): | ||
entity_name: str = Field(default=None) | ||
supported_types: List[str] | None = Field(default=None) | ||
|
||
@abstractmethod | ||
async def download(self, storage: FileStorage | None) -> Resource: ... | ||
|
||
@abstractmethod | ||
async def guess_content_type(self) -> str | None: ... | ||
|
||
@abstractmethod | ||
async def get_resource_name(self, storage: FileStorage | None) -> str: ... | ||
|
||
async def get_content_type(self) -> str: | ||
type = await self.guess_content_type() | ||
|
||
if not type: | ||
raise MissingContentType( | ||
f"Can't derive content type of the {self.entity_name}" | ||
) | ||
|
||
if ( | ||
self.supported_types is not None | ||
and type not in self.supported_types | ||
): | ||
raise UnsupportedContentType( | ||
message=f"The {self.entity_name} is not one of the supported types", | ||
type=type, | ||
supported_types=self.supported_types, | ||
) | ||
|
||
return type | ||
|
||
|
||
class URLResource(DialResource): | ||
url: str | ||
content_type: str | None = None | ||
|
||
@root_validator | ||
def validator(cls, values): | ||
values["entity_name"] = values.get("entity_name") or "URL" | ||
return values | ||
|
||
async def download(self, storage: FileStorage | None) -> Resource: | ||
type = await self.get_content_type() | ||
data = await _download_url(storage, self.url) | ||
return Resource(type=type, data=data) | ||
|
||
async def guess_content_type(self) -> str | None: | ||
return ( | ||
self.content_type | ||
or Resource.parse_data_url_content_type(self.url) | ||
or mimetypes.guess_type(self.url)[0] | ||
) | ||
|
||
def is_data_url(self) -> bool: | ||
return Resource.parse_data_url_content_type(self.url) is not None | ||
|
||
async def get_resource_name(self, storage: FileStorage | None) -> str: | ||
if self.is_data_url(): | ||
return f"data URL ({await self.guess_content_type()})" | ||
|
||
name = self.url | ||
if storage is not None: | ||
name = await storage.get_human_readable_name(self.url) | ||
|
||
return truncate_string(name, n=50) | ||
|
||
|
||
class AttachmentResource(DialResource): | ||
attachment: Attachment | ||
|
||
@validator("attachment", pre=True) | ||
def parse_attachment(cls, value): | ||
if isinstance(value, dict): | ||
attachment = Attachment.parse_obj(value) | ||
# Working around the issue of defaulting missing type to a markdown: | ||
# https://github.com/epam/ai-dial-sdk/blob/2835107e950c89645a2b619fecba2518fa2d7bb1/aidial_sdk/chat_completion/request.py#L22 | ||
if "type" not in value: | ||
attachment.type = None | ||
return attachment | ||
return value | ||
|
||
@root_validator(pre=True) | ||
def validator(cls, values): | ||
values["entity_name"] = values.get("entity_name") or "attachment" | ||
return values | ||
|
||
async def download(self, storage: FileStorage | None) -> Resource: | ||
type = await self.get_content_type() | ||
|
||
if self.attachment.data: | ||
data = base64.b64decode(self.attachment.data) | ||
elif self.attachment.url: | ||
data = await _download_url(storage, self.attachment.url) | ||
else: | ||
raise ValidationError(f"Invalid {self.entity_name}") | ||
|
||
return Resource(type=type, data=data) | ||
|
||
def create_url_resource(self, url: str) -> URLResource: | ||
return URLResource( | ||
url=url, | ||
content_type=self.informative_content_type, | ||
entity_name=self.entity_name, | ||
) | ||
|
||
@property | ||
def informative_content_type(self) -> str | None: | ||
if ( | ||
self.attachment.type is None | ||
or "octet-stream" in self.attachment.type | ||
): | ||
return None | ||
return self.attachment.type | ||
|
||
async def guess_content_type(self) -> str | None: | ||
if url := self.attachment.url: | ||
type = await self.create_url_resource(url).guess_content_type() | ||
if type: | ||
return type | ||
|
||
return self.attachment.type | ||
|
||
async def get_resource_name(self, storage: FileStorage | None) -> str: | ||
if title := self.attachment.title: | ||
return title | ||
|
||
if self.attachment.data: | ||
return f"data {self.entity_name}" | ||
elif url := self.attachment.url: | ||
return await self.create_url_resource(url).get_resource_name( | ||
storage | ||
) | ||
else: | ||
raise ValidationError(f"Invalid {self.entity_name}") | ||
|
||
|
||
async def _download_url(file_storage: FileStorage | None, url: str) -> bytes: | ||
if (resource := Resource.from_data_url(url)) is not None: | ||
return resource.data | ||
|
||
if file_storage: | ||
return await file_storage.download_file(url) | ||
else: | ||
return await download_file(url) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.