Skip to content

Commit

Permalink
feat: migrated latest fixes (#23)
Browse files Browse the repository at this point in the history
* feat: added integration tests for max_tokens and stop sequence
* fix: use number of bytes as token count estimator for AI21 and AWS Titan
* feat: allow empty messages in each language model
* feat: supported history truncation via max_prompt_tokens/discarded_messages parameters
* chore: bumped version of aidial-sdk to 0.1.2
* fix: removed 'Assistant' prefix occasionally generated by Titan
* feat: supported streaming for Titan and Claude
* fix: fixed AI21 temperature setting
  • Loading branch information
adubovik authored Nov 10, 2023
1 parent f9a3c77 commit eea277a
Show file tree
Hide file tree
Showing 46 changed files with 1,954 additions and 1,006 deletions.
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
},
"editor.tabSize": 4
},
"python.testing.pytestArgs": ["."],
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.typeCheckingMode": "basic"
Expand Down
129 changes: 0 additions & 129 deletions MODEL_CARD.md

This file was deleted.

9 changes: 5 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ PORT ?= 5001
IMAGE_NAME ?= ai-dial-adapter-bedrock
PLATFORM ?= linux/amd64
DEV_PYTHON ?= 3.11
DOCKER ?= docker
ARGS=

.PHONY: all install build serve clean lint format test integration_tests docker_build docker_run
Expand Down Expand Up @@ -35,12 +36,12 @@ integration_tests: install
poetry run nox -s integration_tests

docker_test:
docker build --platform $(PLATFORM) -f Dockerfile.test -t $(IMAGE_NAME):test .
docker run --platform $(PLATFORM) --rm $(IMAGE_NAME):test
$(DOCKER) build --platform $(PLATFORM) -f Dockerfile.test -t $(IMAGE_NAME):test .
$(DOCKER) run --platform $(PLATFORM) --rm $(IMAGE_NAME):test

docker_serve:
docker build --platform $(PLATFORM) -t $(IMAGE_NAME):dev .
docker run --platform $(PLATFORM) --env-file ./.env --rm -p $(PORT):5000 $(IMAGE_NAME):dev
$(DOCKER) build --platform $(PLATFORM) -t $(IMAGE_NAME):dev .
$(DOCKER) run --platform $(PLATFORM) --env-file ./.env --rm -p $(PORT):5000 $(IMAGE_NAME):dev

help:
@echo '===================='
Expand Down
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,18 @@ The project implements [AI DIAL API](https://epam-rail.com/dial_api) for languag

Supported models:
* Amazon Titan
- amazon.titan-tg1-large
* AI21 J2
* Anthropic Claude V1, V2
- ai21.j2-grande-instruct
- ai21.j2-jumbo-instruct
- ai21.j2-mid
- ai21.j2-ultra
* Anthropic Claude
- anthropic.claude-instant-v1
- anthropic.claude-v1
- anthropic.claude-v2
* Stable Diffusion
- stability.stable-diffusion-xl

## Developer environment

Expand Down Expand Up @@ -54,6 +63,7 @@ Copy `.env.example` to `.env` and customize it for your environment:
|AWS_SECRET_ACCESS_KEY|NA|AWS credentials with access to Bedrock service|
|DEFAULT_REGION||AWS region e.g. "us-east-1"|
|LOG_LEVEL|INFO|Log level. Use DEBUG for dev purposes and INFO in prod|
|AIDIAL_LOG_LEVEL|WARNING|AI DIAL SDK log level|
|WEB_CONCURRENCY|1|Number of workers for the server|
|TEST_SERVER_URL|http://0.0.0.0:5001|Server URL used in the integration tests|

Expand Down
32 changes: 23 additions & 9 deletions aidial_adapter_bedrock/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,25 @@

import fastapi
from aidial_sdk import DIALApp
from aidial_sdk import HTTPException as DialException
from fastapi import Request
from fastapi.responses import JSONResponse

from aidial_adapter_bedrock.chat_completion import BedrockChatCompletion
from aidial_adapter_bedrock.llm.bedrock_adapter import BedrockModels
from aidial_adapter_bedrock.llm.bedrock_models import BedrockDeployment
from aidial_adapter_bedrock.llm.chat_emulation.types import ChatEmulationType
from aidial_adapter_bedrock.llm.model_listing import get_bedrock_models
from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator
from aidial_adapter_bedrock.universal_api.response import (
ModelObject,
ModelsResponse,
)
from aidial_adapter_bedrock.utils.env import get_env
from aidial_adapter_bedrock.utils.log_config import LogConfig
from aidial_adapter_bedrock.utils.log_config import app_logger as log

logging.config.dictConfig(LogConfig().dict())

default_region = get_env("DEFAULT_REGION")
default_chat_emulation_type = ChatEmulationType.META_CHAT


app = DIALApp(description="AWS Bedrock adapter for RAIL API")

Expand All @@ -32,16 +33,29 @@ def healthcheck():
@app.get("/openai/models")
@dial_exception_decorator
async def models():
bedrock_models = BedrockModels(region=default_region).models()
bedrock_models = get_bedrock_models(region=default_region)
models = [ModelObject(id=model["modelId"]) for model in bedrock_models]
return ModelsResponse(data=models)


for deployment in BedrockDeployment:
app.add_chat_completion(
deployment.get_model_id(),
BedrockChatCompletion(
region=default_region,
chat_emulation_type=default_chat_emulation_type,
),
BedrockChatCompletion(region=default_region),
)


@app.exception_handler(DialException)
async def exception_handler(request: Request, exc: DialException):
log.exception(f"Exception: {str(exc)}")
return JSONResponse(
status_code=exc.status_code,
content={
"error": {
"message": exc.message,
"type": exc.type,
"code": exc.code,
"param": exc.param,
}
},
)
61 changes: 34 additions & 27 deletions aidial_adapter_bedrock/chat_completion.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,58 @@
import asyncio
from typing import List
from typing import Optional, Set

from aidial_sdk.chat_completion import ChatCompletion, Request, Response

from aidial_adapter_bedrock.llm.bedrock_adapter import BedrockAdapter
from aidial_adapter_bedrock.llm.chat_emulation.types import ChatEmulationType
from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer
from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter
from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator
from aidial_adapter_bedrock.universal_api.request import ModelParameters
from aidial_adapter_bedrock.universal_api.token_usage import TokenUsage
from aidial_adapter_bedrock.utils.log_config import app_logger as log


class BedrockChatCompletion(ChatCompletion):
region: str
chat_emulation_type: ChatEmulationType

def __init__(self, region: str, chat_emulation_type: ChatEmulationType):
def __init__(self, region: str):
self.region = region
self.chat_emulation_type = chat_emulation_type

@dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
model = await BedrockAdapter.create(
model_params = ModelParameters.create(request)
model = await get_bedrock_adapter(
region=self.region,
model_id=request.deployment_id,
model_params=ModelParameters.create(request),
)

async def generate_response(idx: int) -> TokenUsage:
model_response = await model.achat(
self.chat_emulation_type, request.messages
)

async def generate_response(
usage: TokenUsage,
discarded_messages_set: Set[Optional[int]],
choice_idx: int,
) -> None:
with response.create_choice() as choice:
choice.append_content(model_response.content)

for data in model_response.data:
choice.add_attachment(
title=data.name,
data=data.content,
type=data.mime_type,
)

return model_response.usage

usages: List[TokenUsage] = await asyncio.gather(
*(generate_response(idx) for idx in range(request.n or 1))
consumer = ChoiceConsumer(choice)
await model.achat(consumer, model_params, request.messages)
usage.accumulate(consumer.usage)
discarded_messages_set.add(consumer.discarded_messages)

usage = TokenUsage()
discarded_messages_set: Set[Optional[int]] = set()

await asyncio.gather(
*(
generate_response(usage, discarded_messages_set, idx)
for idx in range(request.n or 1)
)
)

usage = sum(usages, TokenUsage())
log.debug(f"usage: {usage}")
response.set_usage(usage.prompt_tokens, usage.completion_tokens)

assert (
len(discarded_messages_set) == 1
), "Discarded messages count must be the same for each choice."

discarded_messages = next(iter(discarded_messages_set))
if discarded_messages is not None:
response.set_discarded_messages(discarded_messages)
Loading

0 comments on commit eea277a

Please sign in to comment.