Skip to content
This repository has been archived by the owner on Jan 5, 2025. It is now read-only.

Commit

Permalink
Merge pull request #394 from openchatai/feat/summarizer
Browse files Browse the repository at this point in the history
Adding a summarization column to be used when summarizing responses
  • Loading branch information
codebanesr authored Dec 12, 2023
2 parents 53e227b + 5793eea commit f6d030c
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 102 deletions.
10 changes: 5 additions & 5 deletions llm-server/custom_types/bot_message.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import List, Optional
from typing import List
from langchain.pydantic_v1 import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser


class BotMessage(BaseModel):
    """Structured planner reply, parsed from the LLM's JSON output.

    This commit removes the former ``missing_information`` field; clarifying
    questions now travel inside ``bot_message`` itself (see the updated
    planner prompt in process_conversation_step).
    """

    # Natural-language reply to surface to the end user.
    bot_message: str = Field(description="Message from the bot")
    # operationIds of the APIs the planner decided to call; may be empty
    # when the model can answer from conversation history alone.
    ids: List[str] = Field(description="List of IDs")




# Set up a parser + inject instructions into the prompt template.
bot_message_parser = PydanticOutputParser(pydantic_object=BotMessage)


def parse_bot_message(input: str) -> BotMessage:
    """Parse raw LLM output text into a :class:`BotMessage`.

    Propagates langchain's ``OutputParserException`` when ``input`` is not
    valid JSON matching the BotMessage schema (callers catch it — see
    process_conversation_step).

    NOTE: the parameter name shadows the builtin ``input``; kept as-is for
    backward compatibility with any keyword callers.
    """
    return bot_message_parser.parse(input)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Add system_summary_prompt to Chatbot

Adds a ``summary_prompt`` TEXT column to ``chatbots``, backfills a default
summarization prompt, then marks the column NOT NULL.

Revision ID: d845330c4432
Revises: 86c78095b920
Create Date: 2023-12-12 14:35:53.182454
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

# Revision identifiers, used by Alembic to order migrations.
revision: str = "d845330c4432"
down_revision: Union[str, None] = "86c78095b920"
# No named branch and no cross-branch dependency for this migration.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade():
    """Add ``chatbots.summary_prompt``, backfill it, and make it NOT NULL.

    The column-existence probe uses MySQL-specific ``SHOW COLUMNS``, so the
    migration is safe to re-run against a MySQL database where the column
    was already created.
    """
    # Guard: only add the column when it does not already exist.
    if (
        not op.get_bind()
        .execute(sa.text("SHOW COLUMNS FROM chatbots LIKE 'summary_prompt'"))
        .fetchone()
    ):
        op.add_column(
            "chatbots", sa.Column("summary_prompt", sa.Text(), nullable=True)
        )

    # Backfill only rows without a prompt so a re-run (or an operator's
    # customized prompt) is not clobbered. Single-quoted literal: under
    # MySQL's ANSI_QUOTES mode a double-quoted token is an identifier, not
    # a string, so 'single quotes' are the portable string form.
    op.execute(
        """
        UPDATE chatbots
        SET summary_prompt = 'Given a JSON response, summarize the key information in a concise manner. Include relevant details, references, and links if present. Format the summary in Markdown for clarity and readability.'
        WHERE summary_prompt IS NULL
        """
    )

    # Every row now has a value, so the NOT NULL constraint can be applied.
    op.alter_column(
        "chatbots", "summary_prompt", existing_type=sa.TEXT(), nullable=False
    )


def downgrade():
    """Revert :func:`upgrade` by dropping ``chatbots.summary_prompt``.

    Bug fix: the original dropped ``system_summary_prompt`` — a column this
    migration never created (only the revision docstring mentions that
    name) — so downgrade would fail and leave ``summary_prompt`` behind.
    """
    op.drop_column("chatbots", "summary_prompt")
12 changes: 10 additions & 2 deletions llm-server/routes/chat/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,13 @@ async def send_chat():
headers=headers_from_json,
server_base_url=server_base_url,
app=app_name,
summary_prompt=str(bot.summary_prompt),
)

if response_data["response"]:
upsert_analytics_record(chatbot_id=str(bot.id), successful_operations=1, total_operations=1)
upsert_analytics_record(
chatbot_id=str(bot.id), successful_operations=1, total_operations=1
)
create_chat_history(str(bot.id), session_id, True, message)
create_chat_history(
str(bot.id),
Expand All @@ -170,7 +173,12 @@ async def send_chat():
response_data["response"] or response_data["error"] or "",
)
elif response_data["error"]:
upsert_analytics_record(chatbot_id=str(bot.id), successful_operations=0, total_operations=1, logs=response_data["error"])
upsert_analytics_record(
chatbot_id=str(bot.id),
successful_operations=0,
total_operations=1,
logs=response_data["error"],
)

return jsonify(
{"type": "text", "response": {"text": response_data["response"]}}
Expand Down
27 changes: 14 additions & 13 deletions llm-server/routes/root_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

chat = get_chat_model()


def validate_steps(steps: List[str], swagger_doc: ResolvingParser):
try:
paths = swagger_doc.specification.get("paths", {})
Expand All @@ -67,15 +68,18 @@ def validate_steps(steps: List[str], swagger_doc: ResolvingParser):
if all(x in operationIds for x in steps):
return True
else:
logger.warn("Model has hallucinated, made up operation id", steps=steps, operationIds=operationIds)
logger.warn(
"Model has hallucinated, made up operation id",
steps=steps,
operationIds=operationIds,
)
return False

except Exception as e:
logger.error(f"An error occurred: {str(e)}")
return False



async def handle_request(
text: str,
swagger_url: str,
Expand All @@ -85,6 +89,7 @@ async def handle_request(
headers: Dict[str, str],
server_base_url: str,
app: Optional[str],
summary_prompt: str,
) -> ResponseDict:
log_user_request(text)
check_required_fields(base_prompt, text, swagger_url)
Expand Down Expand Up @@ -112,32 +117,26 @@ async def handle_request(
prev_conversations=prev_conversations,
flows=flows,
bot_id=bot_id,
base_prompt=base_prompt
base_prompt=base_prompt,
)

if step.missing_information is not None and len(step.missing_information) >= 10:
return {
"error": None,
"response": step.bot_message + "\n" + step.missing_information
}

if len(step.ids) > 0:
swagger_doc = get_swagger_doc(swagger_url)
fl = validate_steps(step.ids, swagger_doc)

if fl is False:
return {"error": None, "response": step.bot_message}

response = await handle_api_calls(
ids=step.ids,
swagger_doc=swagger_doc,
app=app,
bot_id=bot_id,
headers=headers,
server_base_url=server_base_url,
session_id=session_id,
text=text,
swagger_url=swagger_url,
summary_prompt=summary_prompt,
)

logger.info(
Expand Down Expand Up @@ -196,6 +195,7 @@ def get_swagger_doc(swagger_url: str) -> ResolvingParser:
else:
return ResolvingParser(spec_string=swagger_doc)


async def handle_api_calls(
ids: List[str],
swagger_doc: ResolvingParser,
Expand All @@ -204,8 +204,8 @@ async def handle_api_calls(
server_base_url: str,
swagger_url: Optional[str],
app: Optional[str],
session_id: str,
bot_id: str,
summary_prompt: str,
) -> ResponseDict:
_workflow = create_workflow_from_operation_ids(ids, swagger_doc, text)
output = await run_workflow(
Expand All @@ -214,6 +214,7 @@ async def handle_api_calls(
WorkflowData(text, headers, server_base_url, swagger_url, app),
app,
bot_id=bot_id,
summary_prompt=summary_prompt,
)

_workflow["swagger_url"] = swagger_url
Expand Down
24 changes: 8 additions & 16 deletions llm-server/routes/workflow/extractors/convert_json_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,10 @@ def convert_json_to_text(
api_response: Dict[str, Any],
api_request_data: Dict[str, Any],
bot_id: str,
summary_prompt: str,
) -> str:
chat = get_chat_model()

api_summarizer_template = None
system_message = SystemMessage(
content="You are an ai assistant that can summarize api responses"
)
prompt_templates = load_prompts(bot_id)
api_summarizer_template = (
prompt_templates.api_summarizer if prompt_templates else None
)

if api_summarizer_template is not None:
system_message = SystemMessage(content=api_summarizer_template)
system_message = SystemMessage(content=summary_prompt)

messages = [
system_message,
Expand All @@ -41,12 +31,14 @@ def convert_json_to_text(
HumanMessage(
content="Here is the response from the apis: {}".format(api_response)
),
# HumanMessage(
# content="Here is the api_request_data: {}".format(api_request_data)
# ),
]

result = chat(messages)
logger.info("Convert json to text", content=result.content, incident="convert_json_to_text", api_request_data=api_request_data)
logger.info(
"Convert json to text",
content=result.content,
incident="convert_json_to_text",
api_request_data=api_request_data,
)

return cast(str, result.content)
42 changes: 25 additions & 17 deletions llm-server/routes/workflow/utils/process_conversation_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,25 @@ def process_conversation_step(
prev_conversations: List[BaseMessage],
flows: List[WorkflowFlowType],
bot_id: str,
base_prompt: str
base_prompt: str,
):
logger.info("planner data", context=context, api_summaries=api_summaries, prev_conversations=prev_conversations, flows=flows)
logger.info(
"planner data",
context=context,
api_summaries=api_summaries,
prev_conversations=prev_conversations,
flows=flows,
)
if not session_id:
raise ValueError("Session id must be defined for chat conversations")
messages: List[BaseMessage] = []
messages.append(SystemMessage(content=base_prompt))

messages.append(SystemMessage(content="You will have access to a list of api's and some useful information, called context."))

messages.append(
SystemMessage(
content="You will have access to a list of api's and some useful information, called context."
)
)

if len(prev_conversations) > 0:
messages.extend(prev_conversations)
Expand Down Expand Up @@ -71,24 +81,22 @@ def process_conversation_step(

messages.append(
HumanMessage(
content="""Based on the information provided to you and the conversation history of this conversation, I want you to answer the questions that follow. Your should respond with a json that looks like the following, you must always use the operationIds provided in api summaries. Do not make up an operation id -
{
"ids": ["list", "of", "operationIds", "for apis to be called"],
"bot_message": "your response based on the instructions provided at the beginning",
"missing_information": "Optional Field; Incase of ambiguity ask clarifying questions. You should not worry about the api filters or query, that should be decided by a different agent."
}
Don't add operation ids if you can reply by merely looking in the conversation history.
"""
content="""Based on the information provided to you and the conversation history of this conversation, I want you to answer the questions that follow. You should respond with a json that looks like the following, you must always use the operationIds provided in api summaries. Do not make up an operation id -
{
"ids": ["list", "of", "operationIds", "for", "apis", "to", "be", "called"],
"bot_message": "your response based on the instructions provided at the beginning, this could also be clarification if the information provided by the user is not complete / accurate",
}
Don't add operation ids if you can reply by merely looking in the conversation history.
"""
)
)
messages.append(
HumanMessage(content="If you are unsure / confused, ask claryfying questions")
)

messages.append(HumanMessage(content=user_requirement))



logger.info("messages array", messages=messages)

content = cast(str, chat(messages=messages).content)
Expand All @@ -105,7 +113,7 @@ def process_conversation_step(

except OutputParserException as e:
logger.warn("Failed to parse json", data=content, err=str(e))
return BotMessage(bot_message=content, ids=[], missing_information=None)
return BotMessage(bot_message=content, ids=[])
except Exception as e:
logger.warn("unexpected error occured", err=str(e))
return BotMessage(ids=[], bot_message=str(e), missing_information=None)
return BotMessage(ids=[], bot_message=str(e))
Loading

0 comments on commit f6d030c

Please sign in to comment.