Commit
feat: Add token usage to Bedrock Claude + Migrated chain for this mod…
charles-marion authored Sep 16, 2024
1 parent 5d4ded0 commit 274d6db
Showing 15 changed files with 423 additions and 131 deletions.
8 changes: 8 additions & 0 deletions cli/magic-config.ts
@@ -201,6 +201,7 @@ const embeddingModels = [

// Advanced settings

options.advancedMonitoring = config.advancedMonitoring;
options.createVpcEndpoints = config.vpc?.createVpcEndpoints;
options.logRetention = config.logRetention;
options.privateWebsite = config.privateWebsite;
@@ -827,6 +828,12 @@ async function processCreateOptions(options: any): Promise<void> {
}
},
},
{
type: "confirm",
name: "advancedMonitoring",
message: "Do you want to enable custom metrics and advanced monitoring?",
initial: options.advancedMonitoring || false,
},
{
type: "confirm",
name: "createVpcEndpoints",
@@ -1106,6 +1113,7 @@ async function processCreateOptions(options: any): Promise<void> {
}
: undefined,
privateWebsite: advancedSettings.privateWebsite,
advancedMonitoring: advancedSettings.advancedMonitoring,
logRetention: advancedSettings.logRetention
? Number(advancedSettings.logRetention)
: undefined,
5 changes: 5 additions & 0 deletions integtests/chatbot-api/session_test.py
@@ -38,6 +38,7 @@ def test_create_session(client, default_model, default_provider, session_id):
break

assert found == True

assert sessionFound.get("title") == request.get("data").get("text")


@@ -48,6 +49,10 @@ def test_get_session(client, session_id, default_model):
assert len(session.get("history")) == 2
assert session.get("history")[0].get("type") == "human"
assert session.get("history")[1].get("type") == "ai"
assert session.get("history")[1].get("metadata") is not None
metadata = json.loads(session.get("history")[1].get("metadata"))
assert metadata.get("usage") is not None
assert metadata.get("usage").get("total_tokens") > 0


def test_delete_session(client, session_id):
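For readers of the test above: the AI message's metadata is stored as a JSON string, so client code has to decode it before reading the usage counters. A minimal sketch, illustrative only and not part of this commit; the session shape follows the assertions in test_get_session:

import json

def get_total_tokens(session: dict) -> int:
    # The second history entry is the AI reply; its metadata is a JSON string.
    ai_message = session["history"][1]
    metadata = json.loads(ai_message["metadata"])
    return metadata["usage"]["total_tokens"]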
1 change: 1 addition & 0 deletions lib/aws-genai-llm-chatbot-stack.ts
@@ -223,6 +223,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
const monitoringStack = new cdk.NestedStack(this, "MonitoringStack");
new Monitoring(monitoringStack, "Monitoring", {
prefix: props.config.prefix,
advancedMonitoring: props.config.advancedMonitoring === true,
appsycnApi: chatBotApi.graphqlApi,
appsyncResolversLogGroups: chatBotApi.resolvers.map((r) => {
return LogGroup.fromLogGroupName(
@@ -3,7 +3,11 @@
from enum import Enum
from aws_lambda_powertools import Logger
from langchain.callbacks.base import BaseCallbackHandler
-from langchain.chains import ConversationalRetrievalChain, ConversationChain
+from langchain.chains.conversation.base import ConversationChain
+from langchain.chains import ConversationalRetrievalChain
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.history_aware_retriever import create_history_aware_retriever
+from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.memory import ConversationBufferMemory
from langchain.prompts.prompt import PromptTemplate
from langchain.chains.conversational_retrieval.prompts import (
@@ -15,6 +19,12 @@
from genai_core.langchain import WorkspaceRetriever, DynamoDBChatMessageHistory
from genai_core.types import ChatbotMode

from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.outputs import LLMResult, ChatGeneration
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.human import HumanMessage
from langchain_aws import ChatBedrockConverse

logger = Logger()


@@ -24,13 +34,40 @@ class Mode(Enum):

class LLMStartHandler(BaseCallbackHandler):
prompts = []
usage = None

# Langchain callbacks
# https://python.langchain.com/v0.2/docs/concepts/#callbacks
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
logger.info(prompts)
self.prompts.append(prompts)

def on_llm_end(
self, response: LLMResult, *, run_id, parent_run_id, **kwargs: Any
) -> Any:
generation = response.generations[0][0] # only one llm request
if (
generation is not None
and isinstance(generation, ChatGeneration)
and isinstance(generation.message, AIMessage)
):
# In case of rag there could be 2 llm calls.
if self.usage is None:
self.usage = {
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
}
self.usage = {
"input_tokens": self.usage.get("input_tokens")
+ generation.message.usage_metadata.get("input_tokens"),
"output_tokens": self.usage.get("output_tokens")
+ generation.message.usage_metadata.get("output_tokens"),
"total_tokens": self.usage.get("total_tokens")
+ generation.message.usage_metadata.get("total_tokens"),
}


class ModelAdapter:
def __init__(
@@ -101,6 +138,115 @@ def get_condense_question_prompt(self):
def get_qa_prompt(self):
return QA_PROMPT

def run_with_chain_v2(self, user_prompt, workspace_id=None):
if not self.llm:
raise ValueError("llm must be set")

self.callback_handler.prompts = []
documents = []
retriever = None

if workspace_id:
retriever = WorkspaceRetriever(workspace_id=workspace_id)
# Only stream the last llm call (otherwise the internal
# llm response will be visible)
llm_without_streaming = self.get_llm({"streaming": False})
history_aware_retriever = create_history_aware_retriever(
llm_without_streaming,
retriever,
self.get_condense_question_prompt(),
)
question_answer_chain = create_stuff_documents_chain(
self.llm,
self.get_qa_prompt(),
)
chain = create_retrieval_chain(
history_aware_retriever, question_answer_chain
)
else:
chain = self.get_prompt() | self.llm

conversation = RunnableWithMessageHistory(
chain,
lambda session_id: self.chat_history,
history_messages_key="chat_history",
input_messages_key="input",
output_messages_key="output",
)

config = {"configurable": {"session_id": self.session_id}}
try:
if self.model_kwargs.get("streaming", False):
answer = ""
for chunk in conversation.stream(
input={"input": user_prompt}, config=config
):
logger.info("chunk", chunk=chunk)
if "answer" in chunk:
answer = answer + chunk["answer"]
elif isinstance(chunk, AIMessageChunk):
for c in chunk.content:
if "text" in c:
answer = answer + c.get("text")
else:
response = conversation.invoke(
input={"input": user_prompt}, config=config
)
if "answer" in response:
answer = response.get("answer") # Rag flow
else:
answer = response.content
except Exception as e:
logger.exception(e)
raise e

if workspace_id:
# In the RAG flow, the history is not updated automatically
self.chat_history.add_message(HumanMessage(user_prompt))
self.chat_history.add_message(AIMessage(answer))
if retriever is not None:
documents = [
{
"page_content": doc.page_content,
"metadata": doc.metadata,
}
for doc in retriever.get_last_search_documents()
]

metadata = {
"modelId": self.model_id,
"modelKwargs": self.model_kwargs,
"mode": self._mode,
"sessionId": self.session_id,
"userId": self.user_id,
"documents": documents,
"prompts": self.callback_handler.prompts,
"usage": self.callback_handler.usage,
}

self.chat_history.add_metadata(metadata)

if (
self.callback_handler.usage is not None
and "total_tokens" in self.callback_handler.usage
):
# Used by Cloudwatch filters to generate a metric of token usage.
logger.info(
"Usage Metric",
# Each unique value of model id will create a
# new cloudwatch metric (each one has a cost)
model=self.model_id,
metric_type="token_usage",
value=self.callback_handler.usage.get("total_tokens"),
)

return {
"sessionId": self.session_id,
"type": "text",
"content": answer,
"metadata": metadata,
}

def run_with_chain(self, user_prompt, workspace_id=None):
if not self.llm:
raise ValueError("llm must be set")
@@ -120,7 +266,7 @@ def run_with_chain(self, user_prompt, workspace_id=None):
callbacks=[self.callback_handler],
)
result = conversation({"question": user_prompt})
-logger.info(result["source_documents"])
+logger.debug(result["source_documents"])
documents = [
{
"page_content": doc.page_content,
@@ -184,6 +330,9 @@ def run(self, prompt, workspace_id=None, *args, **kwargs):
logger.debug(f"mode: {self._mode}")

if self._mode == ChatbotMode.CHAIN.value:
-return self.run_with_chain(prompt, workspace_id)
+if isinstance(self.llm, ChatBedrockConverse):
+    return self.run_with_chain_v2(prompt, workspace_id)
+else:
+    return self.run_with_chain(prompt, workspace_id)

raise ValueError(f"unknown mode {self._mode}")
@@ -14,6 +14,12 @@
)


from ..base import ModelAdapter
import genai_core.clients
from langchain_aws import ChatBedrockConverse
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder


def get_guardrails() -> dict:
if "BEDROCK_GUARDRAILS_ID" in os.environ:
return {
@@ -23,6 +29,82 @@ def get_guardrails() -> dict:
return {}


class BedrockChatAdapter(ModelAdapter):
def __init__(self, model_id, *args, **kwargs):
self.model_id = model_id

super().__init__(*args, **kwargs)

def get_qa_prompt(self):
system_prompt = (
"Use the following pieces of context to answer the question at the end."
" If you don't know the answer, just say that you don't know, "
"don't try to make up an answer. \n\n{context}"
)
return ChatPromptTemplate.from_messages(
[
("system", system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)

def get_prompt(self):
prompt_template = ChatPromptTemplate(
[
(
"system",
(
"The following is a friendly conversation between "
"a human and an AI."
"If the AI does not know the answer to a question, it "
"truthfully says it does not know."
),
),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{input}"),
]
)

return prompt_template

def get_condense_question_prompt(self):
contextualize_q_system_prompt = (
"Given the following conversation and a follow up"
" question, rephrase the follow up question to be a standalone question."
)
return ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)

def get_llm(self, model_kwargs={}, extra={}):
bedrock = genai_core.clients.get_bedrock_client()
params = {}
if "temperature" in model_kwargs:
params["temperature"] = model_kwargs["temperature"]
if "topP" in model_kwargs:
params["top_p"] = model_kwargs["topP"]
if "maxTokens" in model_kwargs:
params["max_tokens"] = model_kwargs["maxTokens"]

guardrails = get_guardrails()
if len(guardrails.keys()) > 0:
params["guardrails"] = guardrails

return ChatBedrockConverse(
client=bedrock,
model=self.model_id,
disable_streaming=model_kwargs.get("streaming", False) == False,
callbacks=[self.callback_handler],
**params,
**extra,
)


class LLMInputOutputAdapter:
"""Adapter class to prepare the inputs from Langchain to a format
that LLM model expects.
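The "Usage Metric" structured log emitted in run_with_chain_v2 is meant to be turned into a CloudWatch metric by the monitoring stack, which is not included in this excerpt. A hedged sketch of the kind of metric filter that could do this; the log group, namespace, and metric names here are hypothetical and not taken from the commit:

import boto3

logs = boto3.client("logs")
logs.put_metric_filter(
    logGroupName="/aws/lambda/request-handler",  # hypothetical log group
    filterName="TokenUsage",
    # Matches the JSON log line written by the aws_lambda_powertools Logger above.
    filterPattern='{ $.metric_type = "token_usage" }',
    metricTransformations=[
        {
            "metricName": "TokenUsage",          # hypothetical metric name
            "metricNamespace": "GenAIChatbot",   # hypothetical namespace
            "metricValue": "$.value",
            "dimensions": {"model": "$.model"},  # one series per model id (each has a cost)
        }
    ],
)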