Skip to content
This repository has been archived by the owner on Jan 5, 2025. It is now read-only.

Commit

Permalink
Merge pull request #365 from openchatai/fix/base_prompts
Browse files Browse the repository at this point in the history
Fix/base prompts
  • Loading branch information
codebanesr authored Dec 6, 2023
2 parents 5d01243 + a64974b commit 59d18ea
Show file tree
Hide file tree
Showing 16 changed files with 154 additions and 85 deletions.
Empty file added cc
Empty file.
2 changes: 1 addition & 1 deletion llm-server/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ CMD ["python", "-m", "debugpy", "--listen", "0.0.0.0:5678", "--wait-for-client",
# Production stage
FROM common AS production
EXPOSE 8002
CMD ["python", "-m", "flask", "run", "--host=0.0.0.0", "--port=8002"]
CMD ["python", "-m", "flask", "run", "--host=0.0.0.0", "--port=8002", "--reload"]
3 changes: 2 additions & 1 deletion llm-server/custom_types/bot_message.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import List
from typing import List, Optional
from langchain.pydantic_v1 import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser

class BotMessage(BaseModel):
bot_message: str = Field(description="Message from the bot")
ids: List[str] = Field(description="List of IDs")
missing_information: Optional[str] = Field(description="Incase of ambiguity ask user follow up question")


# Set up a parser + inject instructions into the prompt template.
Expand Down
1 change: 1 addition & 0 deletions llm-server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pydantic==2.4.2
pydantic_core==2.10.1
pymongo==4.5.0
PyMySQL==1.1.0
pypdfium2==4.24.0
PyPika==0.48.9
PySocks==1.7.1
python-dateutil==2.8.2
Expand Down
15 changes: 8 additions & 7 deletions llm-server/routes/chat/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,14 @@ async def send_chat():
app=app_name,
)

create_chat_history(str(bot.id), session_id, True, message)
create_chat_history(
str(bot.id),
session_id,
False,
response_data["response"] or response_data["error"],
)
if response_data["response"]:
create_chat_history(str(bot.id), session_id, True, message)
create_chat_history(
str(bot.id),
session_id,
False,
response_data["response"] or response_data["error"],
)

return jsonify(
{"type": "text", "response": {"text": response_data["response"]}}
Expand Down
43 changes: 42 additions & 1 deletion llm-server/routes/root_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,34 @@

chat = get_chat_model(CHAT_MODELS.gpt_3_5_turbo_16k)

def validate_steps(steps: List[str], swagger_doc: ResolvingParser):
try:
paths = swagger_doc.specification.get("paths", {})
operationIds: List[str] = []

for path in paths:
operations = paths[path]
for method in operations:
operation = operations[method]
operationId = operation.get("operationId")
if operationId:
operationIds.append(operationId)

if not operationIds:
logger.warn("No operationIds found in the Swagger document.")
return False

if all(x in operationIds for x in steps):
return True
else:
logger.warn("Model has hallucinated, made up operation id", steps=steps, operationIds=operationIds)
return False

except Exception as e:
logger.error(f"An error occurred: {str(e)}")
return False



async def handle_request(
text: str,
Expand Down Expand Up @@ -84,12 +112,25 @@ async def handle_request(
prev_conversations=prev_conversations,
flows=flows,
bot_id=bot_id,
base_prompt=base_prompt
)

if step.missing_information is not None and len(step.missing_information) >= 10:
return {
"error": None,
"response": step.missing_information
}

if len(step.ids) > 0:
swagger_doc = get_swagger_doc(swagger_url)
fl = validate_steps(step.ids, swagger_doc)

if fl is False:
return {"error": None, "response": step.bot_message}

response = await handle_api_calls(
ids=step.ids,
swagger_doc=get_swagger_doc(swagger_url),
swagger_doc=swagger_doc,
app=app,
bot_id=bot_id,
headers=headers,
Expand Down
14 changes: 7 additions & 7 deletions llm-server/routes/workflow/extractors/convert_json_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@

def convert_json_to_text(
user_input: str,
api_response: str,
api_response: Dict[str, Any],
api_request_data: Dict[str, Any],
bot_id: str,
) -> str:
chat = get_chat_model(CHAT_MODELS.gpt_3_5_turbo_16k)

api_summarizer_template = None
system_message = SystemMessage(
content="You are a chatbot that can understand API responses"
content="You are an ai assistant that can summarize api responses"
)
prompt_templates = load_prompts(bot_id)
api_summarizer_template = (
Expand All @@ -35,15 +35,15 @@ def convert_json_to_text(
messages = [
system_message,
HumanMessage(
content="You'll receive user input and server responses obtained by making calls to various APIs. You will also recieve a dictionary that specifies, the body, param and query param used to make those api calls. Your task is to transform the JSON response into a response that in an answer to the user input. You should inform the user about the filters that were used to make these api calls"
content="You'll receive user input and server responses obtained by making calls to various APIs. Your task is to summarize the api response that is an answer to the user input. Try to be concise and accurate, and also include references if present."
),
HumanMessage(content="Here is the user input: {}.".format(user_input)),
HumanMessage(content=user_input),
HumanMessage(
content="Here is the response from the apis: {}".format(api_response)
),
HumanMessage(
content="Here is the api_request_data: {}".format(api_request_data)
),
# HumanMessage(
# content="Here is the api_request_data: {}".format(api_request_data)
# ),
]

result = chat(messages)
Expand Down
32 changes: 13 additions & 19 deletions llm-server/routes/workflow/utils/process_conversation_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,15 @@ def process_conversation_step(
prev_conversations: List[BaseMessage],
flows: List[WorkflowFlowType],
bot_id: str,
base_prompt: str
):
logger.info("planner data", context=context, api_summaries=api_summaries, prev_conversations=prev_conversations, flows=flows)
if not session_id:
raise ValueError("Session id must be defined for chat conversations")
prompt_templates = load_prompts(bot_id)
system_message_classifier = SystemMessage(
content="You are a helpful ai assistant. User will give you two things, a list of api's and some useful information, called context."
)
if app and prompt_templates.system_message is not None:
system_message_classifier = SystemMessage(
content=prompt_templates.system_message
)
logger.debug(
"System message classification",
incident="system_message_classifier",
app=app,
context=context,
)
messages: List[BaseMessage] = []
messages.append(system_message_classifier)
messages.append(SystemMessage(content=base_prompt))

messages.append(SystemMessage(content="You will have access to a list of api's and some useful information, called context."))

if len(prev_conversations) > 0:
messages.extend(prev_conversations)
Expand Down Expand Up @@ -81,10 +71,11 @@ def process_conversation_step(

messages.append(
HumanMessage(
content="""Based on the information provided to you I want you to answer the questions that follow. Your should respond with a json that looks like the following -
content="""Based on the information provided to you I want you to answer the questions that follow. Your should respond with a json that looks like the following, you must always use the operationIds provided in api summaries. Do not make up an operation id -
{{
"ids": ["list", "of", "operationIds", "for apis to be called"],
"bot_message": "your response based on the instructions provided at the beginning"
"bot_message": "your response based on the instructions provided at the beginning",
"missing_information": "Optional Field; Incase of ambiguity where user input is not sufficient to make the api call, ask follow up questions. Followup question should only be asked once per user input"
}}
"""
)
Expand All @@ -94,6 +85,9 @@ def process_conversation_step(
)

messages.append(HumanMessage(content=user_requirement))


logger.info("messages array", messages=messages)

content = cast(str, chat(messages=messages).content)

Expand All @@ -110,7 +104,7 @@ def process_conversation_step(
except OutputParserException as e:
logger.error("Failed to parse json", data=content)
logger.error("Failed to parse json", err=str(e))
return BotMessage(bot_message=content, ids=[])
return BotMessage(bot_message=content, ids=[], missing_information=None)
except Exception as e:
logger.error("unexpected error occured", err=str(e))
return BotMessage(ids=[], bot_message=str(e))
return BotMessage(ids=[], bot_message=str(e), missing_information=None)
48 changes: 24 additions & 24 deletions llm-server/routes/workflow/utils/run_openapi_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def run_openapi_operations(
api_request_data[operation_id] = api_payload.__dict__
api_response = None
try:
logger.info("Making API call", incident="make_api_call", payload=json.dumps(api_payload.body_schema))
logger.info("Making API call", incident="make_api_call", body=json.dumps(api_payload.body_schema), params=api_payload.query_params)

api_response = make_api_request(
headers=headers, **api_payload.__dict__
Expand All @@ -60,28 +60,34 @@ async def run_openapi_operations(

except Exception as e:
logger.error("Error occurred while making API call", incident="make_api_call_failed", error=str(e))
return {}
raise e

logger.info("Got the following api response", text = api_response.text)
# if a custom transformer function is defined for this operationId use that, otherwise forward it to the llm
# so we don't necessarily have to defined mappers for all api endpoints
partial_json = load_json_config(app, operation_id)
logger.info("Loading JSON configuration", incident="load_json_config", json_config=json.dumps(partial_json))
if not partial_json:
logger.error(
"Failed to find a config map. Consider adding a config map for this operation id",
incident="load_json_config",
error="Failed to find a config map, consider adding a config map for this operation id",
logger.warn(
"Config map is not defined for this operationId",
incident="config_map_undefined",
operation_id=operation_id,
app=app
)
record_info[operation_id] = transform_api_response_from_schema(
api_payload.endpoint or "", api_response.text
)
record_info[operation_id] = api_response.text

# Removed this because this slows down the bot response instead of speeding it
# record_info[operation_id] = transform_api_response_from_schema(
# api_payload.endpoint or "", api_response.text
# )

pass
else:
logger.info(
"API Response",
incident="api_response",
text=api_response.text,
action="Truncate unnecessary info using json_config provided",
incident="log_api_response",
api_response=api_response.text,
json_config_used=partial_json,
next_action="summarize_with_partial_json",
)
api_json = json.loads(api_response.text)
record_info[operation_id] = json.dumps(
Expand All @@ -91,19 +97,13 @@ async def run_openapi_operations(
)

except Exception as e:
payload = json.dumps(
{
"text": text,
"headers": headers,
"server_base_url": server_base_url,
"app": app,
}
)

logger.error(
"Error occurred during workflow check in store",
incident="check_workflow_in_store",
payload=payload,
incident="check_workflow_in_store",
text= text,
headers= headers,
server_base_url= server_base_url,
app= app,
error=str(e),
)
return convert_json_to_text(text, record_info, api_request_data, bot_id=bot_id)
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ class EmbeddingProvider(Enum):
BARD = "bard"
azure = "azure"
llama2 = "llama2"
openchat = "openchat"

15 changes: 13 additions & 2 deletions llm-server/shared/utils/opencopilot_utils/get_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from functools import lru_cache
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.ollama import OllamaEmbeddings
from .embedding_type import EmbeddingProvider
import os
from langchain.embeddings.base import Embeddings

from utils.get_logger import CustomLogger
import os, warnings


logger = CustomLogger(module_name=__name__)

LOCAL_IP = os.getenv("LOCAL_IP", "host.docker.internal")

def get_embedding_provider():
"""Gets the chosen embedding provider from environment variables."""
return os.environ.get("EMBEDDING_PROVIDER")
Expand All @@ -31,7 +37,7 @@ def get_openai_embedding():
"""Gets embeddings using the OpenAI embedding provider."""
openai_api_key = os.environ.get("OPENAI_API_KEY")

return OpenAIEmbeddings(openai_api_key=openai_api_key, chunk_size=1)
return OpenAIEmbeddings(openai_api_key=openai_api_key)

def choose_embedding_provider():
"""Chooses and returns the appropriate embedding provider instance."""
Expand All @@ -40,6 +46,10 @@ def choose_embedding_provider():
if embedding_provider == EmbeddingProvider.azure.value:
return get_azure_embedding()

elif embedding_provider == EmbeddingProvider.openchat.value:
logger.info("Got ollama embedding provider", provider=embedding_provider)
return OllamaEmbeddings(base_url=f"{LOCAL_IP}:11434", model="openchat")

elif embedding_provider == EmbeddingProvider.OPENAI.value or embedding_provider is None:
if embedding_provider is None:
warnings.warn("No embedding provider specified. Defaulting to OpenAI.")
Expand All @@ -53,6 +63,7 @@ def choose_embedding_provider():
)

# Main function to get embeddings
@lru_cache(maxsize=1)
def get_embeddings() -> Embeddings:
"""Gets embeddings using the chosen embedding provider."""
return choose_embedding_provider()
4 changes: 2 additions & 2 deletions llm-server/shared/utils/opencopilot_utils/get_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def get_vector_store(options: StoreOptions) -> VectorStore:
vector_store = Qdrant(
client, collection_name=options.namespace, embeddings=embedding
)

# vector_store = Qdrant.from_documents([], embedding, url='http://localhost:6333', collection=options.namespace)

else:
raise ValueError("Invalid STORE environment variable value")

return vector_store
return vector_store
1 change: 1 addition & 0 deletions llm-server/utils/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class ChatModels(NamedTuple):
nous_hermes = "nous-hermes"
llama2: str = "llama2"
xwinlm = "xwinlm"
openchat = "openchat"


CHAT_MODELS: ChatModels = ChatModels()
Loading

0 comments on commit 59d18ea

Please sign in to comment.