feat: Add token usage to Bedrock Claude + Migrated chain for this model #564

Merged · 3 commits · Sep 16, 2024
Changes from all commits
8 changes: 8 additions & 0 deletions cli/magic-config.ts
@@ -201,6 +201,7 @@ const embeddingModels = [

// Advanced settings

options.advancedMonitoring = config.advancedMonitoring;
options.createVpcEndpoints = config.vpc?.createVpcEndpoints;
options.logRetention = config.logRetention;
options.privateWebsite = config.privateWebsite;
@@ -827,6 +828,12 @@ async function processCreateOptions(options: any): Promise<void> {
}
},
},
{
type: "confirm",
name: "advancedMonitoring",
message: "Do you want to enable custom metrics and advanced monitoring?",
initial: options.advancedMonitoring || false,
},
{
type: "confirm",
name: "createVpcEndpoints",
@@ -1106,6 +1113,7 @@ async function processCreateOptions(options: any): Promise<void> {
}
: undefined,
privateWebsite: advancedSettings.privateWebsite,
advancedMonitoring: advancedSettings.advancedMonitoring,
logRetention: advancedSettings.logRetention
? Number(advancedSettings.logRetention)
: undefined,
5 changes: 5 additions & 0 deletions integtests/chatbot-api/session_test.py
@@ -38,6 +38,7 @@ def test_create_session(client, default_model, default_provider, session_id):
break

assert found == True

assert sessionFound.get("title") == request.get("data").get("text")


@@ -48,6 +49,10 @@ def test_get_session(client, session_id, default_model):
assert len(session.get("history")) == 2
assert session.get("history")[0].get("type") == "human"
assert session.get("history")[1].get("type") == "ai"
assert session.get("history")[1].get("metadata") is not None
metadata = json.loads(session.get("history")[1].get("metadata"))
assert metadata.get("usage") is not None
assert metadata.get("usage").get("total_tokens") > 0


def test_delete_session(client, session_id):
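The new assertions in `test_get_session` expect the AI turn's `metadata` field to be a JSON string whose `usage` object carries token counts. A minimal sketch of the shape the test relies on (the token counts below are placeholders, not real model output):

```python
import json

# Illustrative metadata payload matching the assertions above;
# the token counts are placeholders, not real Bedrock output.
metadata_json = json.dumps(
    {"usage": {"input_tokens": 25, "output_tokens": 40, "total_tokens": 65}}
)

metadata = json.loads(metadata_json)
assert metadata.get("usage") is not None
assert metadata.get("usage").get("total_tokens") > 0
```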
1 change: 1 addition & 0 deletions lib/aws-genai-llm-chatbot-stack.ts
@@ -223,6 +223,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
const monitoringStack = new cdk.NestedStack(this, "MonitoringStack");
new Monitoring(monitoringStack, "Monitoring", {
prefix: props.config.prefix,
advancedMonitoring: props.config.advancedMonitoring === true,
appsycnApi: chatBotApi.graphqlApi,
appsyncResolversLogGroups: chatBotApi.resolvers.map((r) => {
return LogGroup.fromLogGroupName(
@@ -3,7 +3,11 @@
from enum import Enum
from aws_lambda_powertools import Logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain.chains.conversation.base import ConversationChain
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.memory import ConversationBufferMemory
from langchain.prompts.prompt import PromptTemplate
from langchain.chains.conversational_retrieval.prompts import (
@@ -15,6 +19,12 @@
from genai_core.langchain import WorkspaceRetriever, DynamoDBChatMessageHistory
from genai_core.types import ChatbotMode

from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.outputs import LLMResult, ChatGeneration
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.human import HumanMessage
from langchain_aws import ChatBedrockConverse

logger = Logger()


@@ -24,13 +34,40 @@ class Mode(Enum):

class LLMStartHandler(BaseCallbackHandler):
prompts = []
usage = None

# Langchain callbacks
# https://python.langchain.com/v0.2/docs/concepts/#callbacks
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
logger.info(prompts)
self.prompts.append(prompts)

def on_llm_end(
self, response: LLMResult, *, run_id, parent_run_id, **kwargs: Any
) -> Any:
generation = response.generations[0][0] # only one llm request
if (
generation is not None
and isinstance(generation, ChatGeneration)
and isinstance(generation.message, AIMessage)
):
# In case of RAG there could be 2 LLM calls.
if self.usage is None:
self.usage = {
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
}
self.usage = {
"input_tokens": self.usage.get("input_tokens")
+ generation.message.usage_metadata.get("input_tokens"),
"output_tokens": self.usage.get("output_tokens")
+ generation.message.usage_metadata.get("output_tokens"),
"total_tokens": self.usage.get("total_tokens")
+ generation.message.usage_metadata.get("total_tokens"),
}


class ModelAdapter:
def __init__(
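The `on_llm_end` handler above accumulates `usage_metadata` across generations because a RAG request can trigger two model calls (condense question, then answer). A small self-contained sketch of that accumulation logic, assuming each call ends with an `AIMessage` carrying `usage_metadata`:

```python
# Sketch of the accumulation done in on_llm_end: each LLM call contributes
# its usage_metadata, and a RAG request may trigger two calls, so totals
# are summed across calls. The per-call numbers below are illustrative.
def accumulate(usage, usage_metadata):
    if usage is None:
        usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
    return {
        "input_tokens": usage["input_tokens"] + usage_metadata.get("input_tokens", 0),
        "output_tokens": usage["output_tokens"] + usage_metadata.get("output_tokens", 0),
        "total_tokens": usage["total_tokens"] + usage_metadata.get("total_tokens", 0),
    }


usage = None
for call_usage in (
    {"input_tokens": 30, "output_tokens": 12, "total_tokens": 42},   # condense-question call
    {"input_tokens": 55, "output_tokens": 80, "total_tokens": 135},  # answer call
):
    usage = accumulate(usage, call_usage)

assert usage == {"input_tokens": 85, "output_tokens": 92, "total_tokens": 177}
```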
@@ -101,6 +138,115 @@ def get_condense_question_prompt(self):
def get_qa_prompt(self):
return QA_PROMPT

def run_with_chain_v2(self, user_prompt, workspace_id=None):
if not self.llm:
raise ValueError("llm must be set")

self.callback_handler.prompts = []
documents = []
retriever = None

if workspace_id:
retriever = WorkspaceRetriever(workspace_id=workspace_id)
# Only stream the last llm call (otherwise the internal
# llm response will be visible)
llm_without_streaming = self.get_llm({"streaming": False})
history_aware_retriever = create_history_aware_retriever(
llm_without_streaming,
retriever,
self.get_condense_question_prompt(),
)
question_answer_chain = create_stuff_documents_chain(
self.llm,
self.get_qa_prompt(),
)
chain = create_retrieval_chain(
history_aware_retriever, question_answer_chain
)
else:
chain = self.get_prompt() | self.llm

conversation = RunnableWithMessageHistory(
chain,
lambda session_id: self.chat_history,
history_messages_key="chat_history",
input_messages_key="input",
output_messages_key="output",
)

config = {"configurable": {"session_id": self.session_id}}
try:
if self.model_kwargs.get("streaming", False):
answer = ""
for chunk in conversation.stream(
input={"input": user_prompt}, config=config
):
logger.info("chunk", chunk=chunk)
if "answer" in chunk:
answer = answer + chunk["answer"]
elif isinstance(chunk, AIMessageChunk):
for c in chunk.content:
if "text" in c:
answer = answer + c.get("text")
else:
response = conversation.invoke(
input={"input": user_prompt}, config=config
)
if "answer" in response:
answer = response.get("answer") # Rag flow
else:
answer = response.content
except Exception as e:
logger.exception(e)
raise e

if workspace_id:
# In the RAG flow, the history is not updated automatically
self.chat_history.add_message(HumanMessage(user_prompt))
self.chat_history.add_message(AIMessage(answer))
if retriever is not None:
documents = [
{
"page_content": doc.page_content,
"metadata": doc.metadata,
}
for doc in retriever.get_last_search_documents()
]

metadata = {
"modelId": self.model_id,
"modelKwargs": self.model_kwargs,
"mode": self._mode,
"sessionId": self.session_id,
"userId": self.user_id,
"documents": documents,
"prompts": self.callback_handler.prompts,
"usage": self.callback_handler.usage,
}

self.chat_history.add_metadata(metadata)

if (
self.callback_handler.usage is not None
and "total_tokens" in self.callback_handler.usage
):
# Used by CloudWatch filters to generate a metric of token usage.
logger.info(
"Usage Metric",
# Each unique value of model id will create a
# new CloudWatch metric (each one has a cost)
model=self.model_id,
metric_type="token_usage",
value=self.callback_handler.usage.get("total_tokens"),
)

return {
"sessionId": self.session_id,
"type": "text",
"content": answer,
"metadata": metadata,
}

def run_with_chain(self, user_prompt, workspace_id=None):
if not self.llm:
raise ValueError("llm must be set")
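For context, a hedged sketch of how the non-RAG branch of `run_with_chain_v2` wires the prompt, model, and history together with `RunnableWithMessageHistory`. The in-memory history and model id are stand-ins (the adapter actually uses `DynamoDBChatMessageHistory` and the configured Bedrock model), and running it assumes AWS credentials with Bedrock access:

```python
from langchain_aws import ChatBedrockConverse
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory

# In-memory stand-in for the adapter's DynamoDB-backed chat history.
history = InMemoryChatMessageHistory()

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "The following is a friendly conversation between a human and an AI."),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ]
)

# Illustrative model id; assumes AWS credentials and Bedrock model access.
llm = ChatBedrockConverse(model="anthropic.claude-3-sonnet-20240229-v1:0")

conversation = RunnableWithMessageHistory(
    prompt | llm,
    lambda session_id: history,
    input_messages_key="input",
    history_messages_key="chat_history",
)

response = conversation.invoke(
    {"input": "Hello!"},
    config={"configurable": {"session_id": "example-session"}},
)
print(response.content)         # model answer
print(response.usage_metadata)  # token counts surfaced by ChatBedrockConverse
```

The `logger.info("Usage Metric", ...)` call at the end of the v2 flow is what the monitoring stack's CloudWatch filters match to emit the per-model token-usage metric.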
@@ -120,7 +266,7 @@ def run_with_chain(self, user_prompt, workspace_id=None):
callbacks=[self.callback_handler],
)
result = conversation({"question": user_prompt})
logger.info(result["source_documents"])
logger.debug(result["source_documents"])
documents = [
{
"page_content": doc.page_content,
@@ -184,6 +330,9 @@ def run(self, prompt, workspace_id=None, *args, **kwargs):
logger.debug(f"mode: {self._mode}")

if self._mode == ChatbotMode.CHAIN.value:
return self.run_with_chain(prompt, workspace_id)
if isinstance(self.llm, ChatBedrockConverse):
return self.run_with_chain_v2(prompt, workspace_id)
else:
return self.run_with_chain(prompt, workspace_id)

raise ValueError(f"unknown mode {self._mode}")
@@ -14,6 +14,12 @@
)


from ..base import ModelAdapter
import genai_core.clients
from langchain_aws import ChatBedrockConverse
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder


def get_guardrails() -> dict:
if "BEDROCK_GUARDRAILS_ID" in os.environ:
return {
@@ -23,6 +29,82 @@ def get_guardrails():
return {}


class BedrockChatAdapter(ModelAdapter):
def __init__(self, model_id, *args, **kwargs):
self.model_id = model_id

super().__init__(*args, **kwargs)

def get_qa_prompt(self):
system_prompt = (
"Use the following pieces of context to answer the question at the end."
" If you don't know the answer, just say that you don't know, "
"don't try to make up an answer. \n\n{context}"
)
return ChatPromptTemplate.from_messages(
[
("system", system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)

def get_prompt(self):
prompt_template = ChatPromptTemplate(
[
(
"system",
(
"The following is a friendly conversation between "
"a human and an AI."
"If the AI does not know the answer to a question, it "
"truthfully says it does not know."
),
),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{input}"),
]
)

return prompt_template

def get_condense_question_prompt(self):
contextualize_q_system_prompt = (
"Given the following conversation and a follow up"
" question, rephrase the follow up question to be a standalone question."
)
return ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)

def get_llm(self, model_kwargs={}, extra={}):
bedrock = genai_core.clients.get_bedrock_client()
params = {}
if "temperature" in model_kwargs:
params["temperature"] = model_kwargs["temperature"]
if "topP" in model_kwargs:
params["top_p"] = model_kwargs["topP"]
if "maxTokens" in model_kwargs:
params["max_tokens"] = model_kwargs["maxTokens"]

guardrails = get_guardrails()
if len(guardrails.keys()) > 0:
params["guardrails"] = guardrails

return ChatBedrockConverse(
client=bedrock,
model=self.model_id,
disable_streaming=model_kwargs.get("streaming", False) == False,
callbacks=[self.callback_handler],
**params,
**extra,
)


class LLMInputOutputAdapter:
"""Adapter class to prepare the inputs from Langchain to a format
that LLM model expects.
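As a quick reference, a sketch of the `model_kwargs` to `ChatBedrockConverse` parameter mapping performed by `get_llm` above; the kwarg values are illustrative:

```python
# Illustrative model_kwargs as they arrive from the chatbot request.
model_kwargs = {"temperature": 0.2, "topP": 0.9, "maxTokens": 512, "streaming": True}

# camelCase kwargs are translated to the snake_case parameters
# that ChatBedrockConverse expects.
params = {}
if "temperature" in model_kwargs:
    params["temperature"] = model_kwargs["temperature"]
if "topP" in model_kwargs:
    params["top_p"] = model_kwargs["topP"]
if "maxTokens" in model_kwargs:
    params["max_tokens"] = model_kwargs["maxTokens"]

# Streaming stays enabled on the Converse client only when requested.
disable_streaming = model_kwargs.get("streaming", False) is False

print(params)             # {'temperature': 0.2, 'top_p': 0.9, 'max_tokens': 512}
print(disable_streaming)  # False
```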