Skip to content

Commit

Permalink
Merge pull request #607 from fsatsuki/issue-592
Browse files Browse the repository at this point in the history
feat: control LLM model selection through configuration
  • Loading branch information
wadabee authored Dec 12, 2024
2 parents 062a1c9 + a514ca6 commit a3f0a75
Show file tree
Hide file tree
Showing 28 changed files with 1,515 additions and 701 deletions.
12 changes: 11 additions & 1 deletion backend/app/repositories/custom_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
decompose_bot_id,
)
from app.repositories.models.custom_bot import (
ActiveModelsModel,
AgentModel,
BotAliasModel,
BotMeta,
Expand Down Expand Up @@ -88,6 +89,7 @@ def store_bot(user_id: str, custom_bot: BotModel):
"ConversationQuickStarters": [
starter.model_dump() for starter in custom_bot.conversation_quick_starters
],
"ActiveModels": custom_bot.active_models.model_dump(), # type: ignore[attr-defined]
}
if custom_bot.bedrock_knowledge_base:
item["BedrockKnowledgeBase"] = custom_bot.bedrock_knowledge_base.model_dump()
Expand All @@ -110,6 +112,7 @@ def update_bot(
sync_status: type_sync_status,
sync_status_reason: str,
display_retrieved_chunks: bool,
active_models: ActiveModelsModel, # type: ignore
conversation_quick_starters: list[ConversationQuickStarterModel],
bedrock_knowledge_base: BedrockKnowledgeBaseModel | None = None,
bedrock_guardrails: BedrockGuardrailsModel | None = None,
Expand All @@ -130,7 +133,8 @@ def update_bot(
"SyncStatusReason = :sync_status_reason, "
"GenerationParams = :generation_params, "
"DisplayRetrievedChunks = :display_retrieved_chunks, "
"ConversationQuickStarters = :conversation_quick_starters"
"ConversationQuickStarters = :conversation_quick_starters, "
"ActiveModels = :active_models"
)

expression_attribute_values = {
Expand All @@ -146,6 +150,7 @@ def update_bot(
":conversation_quick_starters": [
starter.model_dump() for starter in conversation_quick_starters
],
":active_models": active_models.model_dump(), # type: ignore[attr-defined]
}
if bedrock_knowledge_base:
update_expression += ", BedrockKnowledgeBase = :bedrock_knowledge_base"
Expand Down Expand Up @@ -195,6 +200,7 @@ def store_alias(user_id: str, alias: BotAliasModel):
"ConversationQuickStarters": [
starter.model_dump() for starter in alias.conversation_quick_starters
],
"ActiveModels": alias.active_models.model_dump(), # type: ignore[attr-defined]
}

response = table.put_item(Item=item)
Expand Down Expand Up @@ -484,6 +490,7 @@ def find_private_bot_by_id(user_id: str, bot_id: str) -> BotModel:
if "GuardrailsParams" in item
else None
),
active_models=ActiveModelsModel.model_validate(item.get("ActiveModels", {})),
)

logger.info(f"Found bot: {bot}")
Expand All @@ -502,6 +509,7 @@ def find_public_bot_by_id(bot_id: str) -> BotModel:
raise RecordNotFoundError(f"Public bot with id {bot_id} not found")

item = response["Items"][0]

bot = BotModel(
id=decompose_bot_id(item["SK"]),
title=item["Title"],
Expand Down Expand Up @@ -560,6 +568,7 @@ def find_public_bot_by_id(bot_id: str) -> BotModel:
if "GuardrailsParams" in item
else None
),
active_models=ActiveModelsModel.model_validate(item.get("ActiveModels")),
)
logger.info(f"Found public bot: {bot}")
return bot
Expand Down Expand Up @@ -589,6 +598,7 @@ def find_alias_by_id(user_id: str, alias_id: str) -> BotAliasModel:
has_knowledge=item["HasKnowledge"],
has_agent=item.get("HasAgent", False),
conversation_quick_starters=item.get("ConversationQuickStarters", []),
active_models=ActiveModelsModel.model_validate(item.get("ActiveModels")),
)

logger.info(f"Found alias: {bot}")
Expand Down
7 changes: 6 additions & 1 deletion backend/app/repositories/models/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import base64
from decimal import Decimal
from typing import Annotated, Any, Dict, List, Type, get_args

from pydantic import BaseModel, ConfigDict
from pydantic.functional_serializers import PlainSerializer
from pydantic.functional_validators import PlainValidator
from typing import Annotated, Any

# Declare customized float type
Float = Annotated[
Expand Down Expand Up @@ -35,3 +36,7 @@ def decode_base64_string(value: Any) -> bytes:
return_type=str,
),
]


class DynamicBaseModel(BaseModel):
model_config = ConfigDict(extra="allow")
42 changes: 19 additions & 23 deletions backend/app/repositories/models/conversation.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,40 @@
from __future__ import annotations

from typing import Literal, Any, Annotated, Self, TypedDict, TypeGuard
from pathlib import Path
import re
from pathlib import Path
from typing import Annotated, Any, Literal, Self, TypedDict, TypeGuard
from urllib.parse import urlparse

from app.repositories.models.common import Base64EncodedBytes
from app.routes.schemas.conversation import (
SimpleMessage,
MessageInput,
type_model_name,
AttachmentContent,
Content,
TextContent,
DocumentToolResult,
ImageContent,
AttachmentContent,
ToolUseContent,
ToolUseContentBody,
ToolResult,
TextToolResult,
JsonToolResult,
ImageToolResult,
DocumentToolResult,
ToolResultContentBody,
ToolResultContent,
JsonToolResult,
MessageInput,
RelatedDocument,
SimpleMessage,
TextContent,
TextToolResult,
ToolResult,
ToolResultContent,
ToolResultContentBody,
ToolUseContent,
ToolUseContentBody,
type_model_name,
)
from app.utils import generate_presigned_url

from pydantic import BaseModel, Field, field_validator, Discriminator, JsonValue
from mypy_boto3_bedrock_runtime.literals import DocumentFormatType, ImageFormatType
from mypy_boto3_bedrock_runtime.type_defs import (
ContentBlockTypeDef,
ToolUseBlockTypeDef,
ToolUseBlockOutputTypeDef,
ToolResultBlockTypeDef,
ToolResultContentBlockOutputTypeDef,
ToolUseBlockOutputTypeDef,
ToolUseBlockTypeDef,
)
from mypy_boto3_bedrock_runtime.literals import (
DocumentFormatType,
ImageFormatType,
)
from pydantic import BaseModel, Discriminator, Field, JsonValue, field_validator


class TextContentModel(BaseModel):
Expand Down
26 changes: 23 additions & 3 deletions backend/app/repositories/models/custom_bot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
from app.repositories.models.common import Float
from typing import Any, Dict, List, Literal, Type, get_args

from app.repositories.models.common import DynamicBaseModel, Float
from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel
from app.repositories.models.custom_bot_kb import BedrockKnowledgeBaseModel
from app.routes.schemas.bot import type_sync_status
from pydantic import BaseModel
from app.routes.schemas.conversation import type_model_name
from pydantic import BaseModel, ConfigDict, create_model


def _create_model_activate_model(model_names: List[str]) -> Type[DynamicBaseModel]:
fields: Dict[str, Any] = {
name.replace("-", "_").replace(".", "_"): (bool, True) for name in model_names
}
return create_model("ActiveModelsModel", __base__=DynamicBaseModel, **fields)


ActiveModelsModel: Type[BaseModel] = _create_model_activate_model(
list(get_args(type_model_name))
)


class KnowledgeModel(BaseModel):
Expand Down Expand Up @@ -78,6 +93,7 @@ class BotModel(BaseModel):
conversation_quick_starters: list[ConversationQuickStarterModel]
bedrock_knowledge_base: BedrockKnowledgeBaseModel | None
bedrock_guardrails: BedrockGuardrailsModel | None
active_models: ActiveModelsModel # type: ignore

def has_knowledge(self) -> bool:
return (
Expand All @@ -91,7 +107,10 @@ def is_agent_enabled(self) -> bool:
return len(self.agent.tools) > 0

def has_bedrock_knowledge_base(self) -> bool:
return self.bedrock_knowledge_base is not None
return (
self.bedrock_knowledge_base is not None
and self.bedrock_knowledge_base.knowledge_base_id is not None
)


class BotAliasModel(BaseModel):
Expand All @@ -106,6 +125,7 @@ class BotAliasModel(BaseModel):
has_knowledge: bool
has_agent: bool
conversation_quick_starters: list[ConversationQuickStarterModel]
active_models: ActiveModelsModel # type: ignore


class BotMeta(BaseModel):
Expand Down
5 changes: 4 additions & 1 deletion backend/app/routes/bot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Any, Dict, Literal

from app.dependencies import check_creating_bot_allowed
from app.repositories.custom_bot import (
Expand All @@ -7,6 +7,7 @@
update_bot_visibility,
)
from app.routes.schemas.bot import (
ActiveModelsOutput,
Agent,
AgentTool,
BedrockGuardrailsOutput,
Expand All @@ -23,6 +24,7 @@
GenerationParams,
Knowledge,
)
from app.routes.schemas.conversation import type_model_name
from app.usecases.bot import (
create_new_bot,
fetch_all_bots,
Expand Down Expand Up @@ -152,6 +154,7 @@ def get_private_bot(request: Request, bot_id: str):
if bot.bedrock_guardrails
else None
),
active_models=ActiveModelsOutput.model_validate(dict(bot.active_models)),
)
return output

Expand Down
30 changes: 28 additions & 2 deletions backend/app/routes/schemas/bot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type, get_args

from app.routes.schemas.base import BaseSchema
from app.routes.schemas.bot_guardrails import (
Expand All @@ -11,7 +11,8 @@
BedrockKnowledgeBaseInput,
BedrockKnowledgeBaseOutput,
)
from pydantic import Field, root_validator, validator
from app.routes.schemas.conversation import type_model_name
from pydantic import Field, create_model, validator

if TYPE_CHECKING:
from app.repositories.models.custom_bot import BotModel
Expand All @@ -28,6 +29,26 @@
]


def _create_model_activate_input(model_names: List[str]) -> Type[BaseSchema]:
fields: Dict[str, Any] = {
name.replace("-", "_").replace(".", "_"): (bool, True) for name in model_names
}
return create_model("ActiveModelsInput", **fields, __base__=BaseSchema)


ActiveModelsInput = _create_model_activate_input(list(get_args(type_model_name)))


def create_model_activate_output(model_names: List[str]) -> Type[BaseSchema]:
fields: Dict[str, Any] = {
name.replace("-", "_").replace(".", "_"): (bool, True) for name in model_names
}
return create_model("ActiveModelsOutput", **fields, __base__=BaseSchema)


ActiveModelsOutput = create_model_activate_output(list(get_args(type_model_name)))


class GenerationParams(BaseSchema):
max_tokens: int
top_k: int
Expand Down Expand Up @@ -102,6 +123,7 @@ class BotInput(BaseSchema):
conversation_quick_starters: list[ConversationQuickStarter] | None
bedrock_knowledge_base: BedrockKnowledgeBaseInput | None = None
bedrock_guardrails: BedrockGuardrailsInput | None = None
active_models: ActiveModelsInput # type: ignore


class BotModifyInput(BaseSchema):
Expand All @@ -115,6 +137,7 @@ class BotModifyInput(BaseSchema):
conversation_quick_starters: list[ConversationQuickStarter] | None
bedrock_knowledge_base: BedrockKnowledgeBaseInput | None = None
bedrock_guardrails: BedrockGuardrailsInput | None = None
active_models: ActiveModelsInput # type: ignore

def _has_update_files(self) -> bool:
return self.knowledge is not None and (
Expand Down Expand Up @@ -228,6 +251,7 @@ class BotModifyOutput(BaseSchema):
conversation_quick_starters: list[ConversationQuickStarter]
bedrock_knowledge_base: BedrockKnowledgeBaseOutput | None
bedrock_guardrails: BedrockGuardrailsOutput | None
active_models: ActiveModelsOutput # type: ignore


class BotOutput(BaseSchema):
Expand All @@ -251,6 +275,7 @@ class BotOutput(BaseSchema):
conversation_quick_starters: list[ConversationQuickStarter]
bedrock_knowledge_base: BedrockKnowledgeBaseOutput | None
bedrock_guardrails: BedrockGuardrailsOutput | None
active_models: ActiveModelsOutput # type: ignore


class BotMetaOutput(BaseSchema):
Expand Down Expand Up @@ -281,6 +306,7 @@ class BotSummaryOutput(BaseSchema):
sync_status: type_sync_status
has_knowledge: bool
conversation_quick_starters: list[ConversationQuickStarter]
active_models: ActiveModelsOutput # type: ignore


class BotSwitchVisibilityInput(BaseSchema):
Expand Down
12 changes: 4 additions & 8 deletions backend/app/routes/schemas/conversation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from typing import Literal, Annotated
from typing import Annotated, Literal

from app.routes.schemas.base import BaseSchema
from app.repositories.models.common import Base64EncodedBytes
from pydantic import Field, Discriminator, JsonValue, root_validator

from mypy_boto3_bedrock_runtime.literals import (
DocumentFormatType,
ImageFormatType,
)
from app.routes.schemas.base import BaseSchema
from mypy_boto3_bedrock_runtime.literals import DocumentFormatType, ImageFormatType
from pydantic import Discriminator, Field, JsonValue, root_validator

type_model_name = Literal[
"claude-instant-v1",
Expand Down
Loading

0 comments on commit a3f0a75

Please sign in to comment.