Skip to content

Commit

Permalink
Merge pull request #5 from microsoft/bge-emb
Browse files Browse the repository at this point in the history
enhance fuzzy match with BGE model
  • Loading branch information
XuHwang authored Oct 23, 2023
2 parents 770c7b6 + 0b0a7bc commit b84aeec
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 16 deletions.
10 changes: 5 additions & 5 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,11 @@ def user(user_message, history):
"""
with gr.Blocks(css=css, elem_id="chatbot") as demo:
with gr.Row(visible=True) as btn_raws:
with gr.Column(scale=0.5):
with gr.Column(scale=5):
mode = gr.Radio(
["diversity", "accuracy"], value="accuracy", label="Recommendation Mode"
)
with gr.Column(scale=0.5):
with gr.Column(scale=5):
style = gr.Radio(
["concise", "detailed"], value=getattr(bot, 'reply_style', 'concise'), label="Reply Style"
)
Expand All @@ -265,13 +265,13 @@ def user(user_message, history):
)
state = gr.State([])
with gr.Row(visible=True) as input_raws:
with gr.Column(scale=0.6):
with gr.Column(scale=4):
txt = gr.Textbox(
show_label=False, placeholder="Enter text and press enter", container=False
)
with gr.Column(scale=0.15, min_width=0):
with gr.Column(scale=1, min_width=0):
send = gr.Button(value="Send", elem_id="send", variant="primary")
with gr.Column(scale=0.15, min_width=0):
with gr.Column(scale=1, min_width=0):
clear = gr.ClearButton(value="Clear")

state.value = [default_chat_value]
Expand Down
4 changes: 3 additions & 1 deletion llm4crs/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import *

import openai
from langchain import OpenAI
from langchain.llms import OpenAI
from langchain.agents import (AgentExecutor, LLMSingleActionAgent, Tool)
from langchain.callbacks import get_openai_callback
from langchain.chains import LLMChain
Expand Down Expand Up @@ -128,6 +128,7 @@ def setup_tools(self, tools: List[Callable]) -> List[Tool]:

def setup_prompts(self, tools: List[Tool]):
prompt = CRSChatPrompt(
table_info=self.item_corups.info(),
intermediate_steps="",
template=SYSTEM_PROMPT.format(table_info=self.item_corups.info(), **self._tool_names, **self._domain_map),
tools=tools,
Expand Down Expand Up @@ -156,6 +157,7 @@ def run(self, input: Dict[str, str], chat_history: str=None):
self.prompt.memory = self.memory.buffer
if self.selector:
self.prompt.examples = self.selector(input['input'])
self.prompt.table_info = self.item_corups.info(query=input['input'])

try:
response = self.agent_exe.run(input)
Expand Down
3 changes: 2 additions & 1 deletion llm4crs/agent_plan_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ def setup_tools(self, tools: List[Callable[..., Any]]) -> List[Tool]:
def setup_prompts(self, tools: List[Tool]):
tools_desc = "\n".join([f"{tool.name}: {tool.desc}" for tool in self._tools])
tool_names = "[" + ", ".join([f"{tool.name}" for tool in self._tools]) + "]"
template = SYSTEM_PROMPT_PLAN_FIRST.format(table_info=self.item_corups.info(), tools_desc=tools_desc,
template = SYSTEM_PROMPT_PLAN_FIRST.format(tools_desc=tools_desc,
tool_exe_name=self.tools[0].name, tool_names=tool_names,
**self._tool_names, **self._domain_map)
prompt = CRSChatPrompt(
table_info=self.item_corups.info(),
intermediate_steps="",
template=template,
tools=tools,
Expand Down
2 changes: 1 addition & 1 deletion llm4crs/agent_plan_first_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def setup_prompts(self, tools: List[Tool]):
tools_desc = "\n".join([f"{tool.name}: {tool.desc}" for tool in self._tools])
tool_names = "[" + ", ".join([f"{tool.name}" for tool in self._tools]) + "]"
template = SYSTEM_PROMPT_PLAN_FIRST.format(
table_info=self.item_corups.info(),
tools_desc=tools_desc,
tool_exe_name=self.toolbox.name,
tool_names=tool_names,
Expand Down Expand Up @@ -347,6 +346,7 @@ def run(
"history": "", # chat history
"input": inputs["input"],
"reflection": "" if not reflection else reflection,
"table_info": self.item_corups.info(query=inputs['input'])
}

self.toolbox.failed_times = 0
Expand Down
31 changes: 27 additions & 4 deletions llm4crs/corups/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def __init__(self, fpath: str, column_meaning_file: str, name: str='Item_Informa
self.fpath = fpath
self.name = name # name of the table
self.corups = self._read_file(fpath, columns, sep, parquet_engine)
# tags to be displayed to LLM: topk query-related tags + random selected tags
self.disp_cate_topk: int = 6
self.disp_cate_total: int = 10
self._fuzzy_bert_base = "BAAI/bge-base-en-v1.5"
self._required_columns_validate()
self.column_meaning = self._load_col_desc_file(column_meaning_file)

Expand All @@ -53,8 +57,8 @@ def __init__(self, fpath: str, column_meaning_file: str, name: str='Item_Informa
self.corups[col] = self.corups[col].apply(lambda x: ', '.join(x))

self.fuzzy_engine: Dict[str:SentBERTEngine] = {
col : SentBERTEngine(self.corups[col].to_numpy(), self.corups['id'].to_numpy(), case_sensitive=False) if col not in categorical_cols
else SentBERTEngine(self.categorical_col_values[col], np.arange(len(self.categorical_col_values[col])), case_sensitive=False)
col : SentBERTEngine(self.corups[col].to_numpy(), self.corups['id'].to_numpy(), case_sensitive=False, model_name=self._fuzzy_bert_base) if col not in categorical_cols
else SentBERTEngine(self.categorical_col_values[col], np.arange(len(self.categorical_col_values[col])), case_sensitive=False, model_name=self._fuzzy_bert_base)
for col in fuzzy_cols
}
# title as index
Expand Down Expand Up @@ -92,7 +96,7 @@ def __len__(self) -> int:
return len(self.corups)


def info(self, remove_game_titles: bool=False):
def info(self, remove_game_titles: bool=False, query: str=None):
prefix = 'Table information:'
table_name = f"Table Name: {self.name}"
cols_info = "Column Names, Data Types and Column meaning:"
Expand All @@ -103,7 +107,8 @@ def info(self, remove_game_titles: bool=False):
dtype = _pd_type_to_sql_type(self.corups[col])
cols_info += f"\n - {col}({dtype}): {self.column_meaning[col]}"
if col == 'tags':
_prefix = f" Such as [{', '.join(random.sample(self.categorical_col_values[col].tolist(), k=10))}]."
disp_values = self.sample_categoricol_values(col, total_n=self.disp_cate_total, query=query, topk=self.disp_cate_topk)
_prefix = f" Related values: [{', '.join(disp_values)}]."
cols_info += _prefix

if dtype in {'float', 'datetime', 'integer'}:
Expand All @@ -122,6 +127,24 @@ def info(self, remove_game_titles: bool=False):
res = prefix + res
return res

def sample_categoricol_values(self, col_name: str, total_n: int, query: str=None, topk: int=None) -> List:
    """Pick up to `total_n` values of a categorical column to display.

    Without a query, the values are a uniform random sample of the column's
    categorical values. With a query, the first `topk` entries come from the
    column's fuzzy engine and, when `total_n > topk`, the rest are filled
    with random values that are not already selected.

    Args:
        col_name: key into `self.categorical_col_values` / `self.fuzzy_engine`.
        total_n: maximum number of values to return.
        query: text used to retrieve related values; `None` means pure
            random sampling.
        topk: number of query-related values to request from the fuzzy
            engine. Defaults to `total_n` when a query is given.

    Returns:
        A list with at most `total_n` values. Fewer are returned when the
        column does not hold enough values.

    Raises:
        ValueError: if `topk` is larger than `total_n`.
    """
    values = self.categorical_col_values[col_name]
    # Presumably a numpy array (the original code calls `.tolist()` on it);
    # fall back to `list()` so plain sequences work as well.
    pool = values.tolist() if hasattr(values, "tolist") else list(values)

    if query is None:
        # Clamp the sample size: random.sample raises ValueError when asked
        # for more items than the population holds.
        return random.sample(pool, k=min(total_n, len(pool)))

    if topk is None:
        topk = total_n
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if topk > total_n:
        raise ValueError(
            f"`topk` must not be larger than `total_n`, while got {topk} > {total_n}."
        )

    # NOTE(review): assumes the engine returns an iterable of hashable
    # values (e.g. tag strings) when called with return_doc=True - confirm.
    result = list(self.fuzzy_engine[col_name](query, return_doc=True, topk=topk))

    if total_n > topk:
        missing = min(total_n, len(pool)) - len(result)
        if missing > 0:
            chosen = set(result)
            # De-duplicate the pool (order preserved) so a value can never
            # be returned twice, then draw all fillers in a single call
            # instead of retrying random.choice until an unseen value shows up.
            unused = [v for v in dict.fromkeys(pool) if v not in chosen]
            result.extend(random.sample(unused, k=min(missing, len(unused))))
    return result


def convert_id_2_info(self, id: Union[int, List[int], np.ndarray], col_names: Union[str, List[str]]=None) -> Union[Dict, List[Dict]]:
"""Given game id, get game informations.
Expand Down
4 changes: 2 additions & 2 deletions llm4crs/prompt/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
All SQL commands are used to search in the {item} information table (a sqlite3 table). The information of the table is listed below: \
{table_info}
{{table_info}}
First you need to think:
Expand Down Expand Up @@ -114,7 +114,7 @@
All SQL commands are used to search in the {item} information table (a sqlite3 table). The information of the table is listed below: \
{table_info}
{{table_info}}
If human is looking up information of {item}s, such as the description of {item}s, number of {item}s, price of {item}s and so on, use the {LookUpTool}. \
Expand Down
2 changes: 2 additions & 0 deletions llm4crs/utils/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class CRSChatPrompt(StringPromptTemplate):
memory: str
examples: str
reflection: str
table_info: str


def format(self, **kwargs: Any) -> str:
Expand All @@ -45,6 +46,7 @@ def format(self, **kwargs: Any) -> str:
kwargs["reflection"] = self.reflection
else:
kwargs["reflection"] = ''
kwargs["table_info"] = self.table_info
return self.template.format(**kwargs)


Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pandasql==0.7.3
langchain==0.0.275
langchain==0.0.312
gradio==3.40.1
loguru==0.7.0
faiss-cpu==1.7.4
Expand All @@ -11,5 +11,5 @@ tiktoken==0.4.0
guidance==0.0.64
SQLAlchemy==1.4.46
--extra-index-url https://download.pytorch.org/whl/cu113
torch==1.12.1
torch>=1.12.1,<=1.13.1
unirec==0.0.1a2
2 changes: 2 additions & 0 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ OPENAI_API_KEY=$API_KEY \
OPENAI_API_BASE=$API_BASE \
OPENAI_API_VERSION=$API_VERSION \
OPENAI_API_TYPE=$API_TYPE \
TOKENIZERS_PARALLELISM=false \
PYTHONPATH=$(pwd) \
python ./app.py \
--engine=$engine \
--bot_type=$bot_type \
Expand Down

0 comments on commit b84aeec

Please sign in to comment.