From 60311089ec887c0f87caff0ed4cae534f2d6b0d6 Mon Sep 17 00:00:00 2001
From: Xu Huang
Date: Mon, 23 Oct 2023 00:40:05 +0800
Subject: [PATCH 1/3] enhance fuzzy match with BGE model

---
 llm4crs/agent_plan_first_openai.py |  2 +-
 llm4crs/corups/base.py             | 31 ++++++++++++++++++++++++++----
 llm4crs/prompt/system.py           |  2 +-
 requirements.txt                   |  2 +-
 4 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/llm4crs/agent_plan_first_openai.py b/llm4crs/agent_plan_first_openai.py
index 5655c35..36093e7 100644
--- a/llm4crs/agent_plan_first_openai.py
+++ b/llm4crs/agent_plan_first_openai.py
@@ -301,7 +301,6 @@ def setup_prompts(self, tools: List[Tool]):
         tools_desc = "\n".join([f"{tool.name}: {tool.desc}" for tool in self._tools])
         tool_names = "[" + ", ".join([f"{tool.name}" for tool in self._tools]) + "]"
         template = SYSTEM_PROMPT_PLAN_FIRST.format(
-            table_info=self.item_corups.info(),
             tools_desc=tools_desc,
             tool_exe_name=self.toolbox.name,
             tool_names=tool_names,
@@ -347,6 +346,7 @@ def run(
             "history": "",  # chat history
             "input": inputs["input"],
             "reflection": "" if not reflection else reflection,
+            "table_info": self.item_corups.info(query=inputs['input'])
         }
         self.toolbox.failed_times = 0
 
diff --git a/llm4crs/corups/base.py b/llm4crs/corups/base.py
index 50592da..42bdb0c 100644
--- a/llm4crs/corups/base.py
+++ b/llm4crs/corups/base.py
@@ -40,6 +40,10 @@ def __init__(self, fpath: str, column_meaning_file: str, name: str='Item_Informa
         self.fpath = fpath
         self.name = name   # name of the table
         self.corups = self._read_file(fpath, columns, sep, parquet_engine)
+        # tags to be displayed to LLM: topk query-related tags + random selected tags
+        self.disp_cate_topk: int = 6
+        self.disp_cate_total: int = 10
+        self._fuzzy_bert_base = "BAAI/bge-base-en-v1.5"
         self._required_columns_validate()
         self.column_meaning = self._load_col_desc_file(column_meaning_file)
 
@@ -53,8 +57,8 @@ def __init__(self, fpath: str, column_meaning_file: str, name: str='Item_Informa
             self.corups[col] = self.corups[col].apply(lambda x: ', '.join(x))
 
         self.fuzzy_engine: Dict[str:SentBERTEngine] = {
-            col : SentBERTEngine(self.corups[col].to_numpy(), self.corups['id'].to_numpy(), case_sensitive=False) if col not in categorical_cols
-            else SentBERTEngine(self.categorical_col_values[col], np.arange(len(self.categorical_col_values[col])), case_sensitive=False)
+            col : SentBERTEngine(self.corups[col].to_numpy(), self.corups['id'].to_numpy(), case_sensitive=False, model_name=self._fuzzy_bert_base) if col not in categorical_cols
+            else SentBERTEngine(self.categorical_col_values[col], np.arange(len(self.categorical_col_values[col])), case_sensitive=False, model_name=self._fuzzy_bert_base)
            for col in fuzzy_cols
         }
         # title as index
@@ -92,7 +96,7 @@ def __len__(self) -> int:
         return len(self.corups)
 
 
-    def info(self, remove_game_titles: bool=False):
+    def info(self, remove_game_titles: bool=False, query: str=None):
         prefix = 'Table information:'
         table_name = f"Table Name: {self.name}"
         cols_info = "Column Names, Data Types and Column meaning:"
@@ -103,7 +107,8 @@ def info(self, remove_game_titles: bool=False):
             dtype = _pd_type_to_sql_type(self.corups[col])
             cols_info += f"\n - {col}({dtype}): {self.column_meaning[col]}"
             if col == 'tags':
-                _prefix = f" Such as [{', '.join(random.sample(self.categorical_col_values[col].tolist(), k=10))}]."
+                disp_values = self.sample_categoricol_values(col, total_n=self.disp_cate_total, query=query, topk=self.disp_cate_topk)
+                _prefix = f" Related values: [{', '.join(disp_values)}]."
                 cols_info += _prefix
 
             if dtype in {'float', 'datetime', 'integer'}:
@@ -122,6 +127,24 @@ def info(self, remove_game_titles: bool=False):
             res = prefix + res
         return res
 
+    def sample_categoricol_values(self, col_name: str, total_n: int, query: str=None, topk: int=None) -> List:
+        # Select topk related tags according to query and sample (total_n-topk) tags
+        if query is None:
+            result = random.sample(self.categorical_col_values[col_name], k=total_n)
+        else:
+            if topk is None:
+                topk = total_n
+            assert total_n >= topk, f"`topk` must be smaller than `total_n`, while got {topk} > {total_n}."
+            topk_values = self.fuzzy_engine[col_name](query, return_doc=True, topk=topk)
+            topk_values = list(topk_values)
+            result = topk_values
+            if total_n > topk:
+                while (len(result) < total_n) and (len(result) < len(self.categorical_col_values[col_name])):
+                    random_values = random.choice(self.categorical_col_values[col_name])
+                    if random_values not in result:
+                        result.append(random_values)
+        return result
+
 
     def convert_id_2_info(self, id: Union[int, List[int], np.ndarray], col_names: Union[str, List[str]]=None) -> Union[Dict, List[Dict]]:
         """Given game id, get game informations.
diff --git a/llm4crs/prompt/system.py b/llm4crs/prompt/system.py
index 1d8aad8..80291ce 100644
--- a/llm4crs/prompt/system.py
+++ b/llm4crs/prompt/system.py
@@ -114,7 +114,7 @@
 
 All SQL commands are used to search in the {item} information table (a sqlite3 table). The information of the table is listed below: \
 
-{table_info}
+{{table_info}}
 
 If human is looking up information of {item}s, such as the description of {item}s, number of {item}s, price of {item}s and so on, use the {LookUpTool}. \
 
diff --git a/requirements.txt b/requirements.txt
index a306cee..52ec196 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,5 +11,5 @@ tiktoken==0.4.0
 guidance==0.0.64
 SQLAlchemy==1.4.46
 --extra-index-url https://download.pytorch.org/whl/cu113
-torch==1.12.1
+torch>=1.12.1,<=1.13.1
 unirec==0.0.1a2
\ No newline at end of file
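The core of patch 1 is the tag-selection policy: the item corpus now builds its fuzzy engines on the "BAAI/bge-base-en-v1.5" sentence-embedding model, and the table description shown to the LLM lists the top-k tags related to the current user query, padded with random tags up to a fixed display budget, instead of ten purely random tags. The sketch below illustrates only that selection policy; the function name and the keyword-overlap `related` callback are illustrative stand-ins, not the repository's actual SentBERTEngine API.

import random
from typing import Callable, List, Optional

def sample_display_tags(all_tags: List[str], total_n: int, query: Optional[str] = None,
                        topk: Optional[int] = None,
                        related_fn: Optional[Callable[[str, int], List[str]]] = None) -> List[str]:
    # No query (or no fuzzy engine available): fall back to purely random tags.
    if query is None or related_fn is None:
        return random.sample(all_tags, k=min(total_n, len(all_tags)))
    topk = total_n if topk is None else topk
    assert total_n >= topk, "`topk` must not exceed `total_n`"
    result = list(related_fn(query, topk))      # query-related tags first
    while len(result) < total_n and len(result) < len(all_tags):
        candidate = random.choice(all_tags)     # pad with random, not-yet-chosen tags
        if candidate not in result:
            result.append(candidate)
    return result

# Toy keyword-overlap stand-in for the BGE-based fuzzy engine.
tags = ["open world", "roguelike", "co-op", "racing", "puzzle", "shooter"]
related = lambda q, k: [t for t in tags if any(w in t for w in q.lower().split())][:k]
print(sample_display_tags(tags, total_n=4, query="open world shooter", topk=2))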
From 5bd218463fd353af5829545f90f4a8f8433fb79d Mon Sep 17 00:00:00 2001
From: Xu Huang
Date: Mon, 23 Oct 2023 19:56:01 +0800
Subject: [PATCH 2/3] fix bug for dynamic table info; update requirement

---
 app.py                      | 10 +++++-----
 llm4crs/agent.py            |  4 +++-
 llm4crs/agent_plan_first.py |  3 ++-
 llm4crs/corups/base.py      |  2 +-
 llm4crs/utils/prompt.py     |  2 ++
 requirements.txt            |  2 +-
 run.sh                      |  2 ++
 7 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/app.py b/app.py
index d8ff91d..49c294d 100644
--- a/app.py
+++ b/app.py
@@ -252,11 +252,11 @@ def user(user_message, history):
 """
 with gr.Blocks(css=css, elem_id="chatbot") as demo:
     with gr.Row(visible=True) as btn_raws:
-        with gr.Column(scale=0.5):
+        with gr.Column(scale=5):
             mode = gr.Radio(
                 ["diversity", "accuracy"], value="accuracy", label="Recommendation Mode"
             )
-        with gr.Column(scale=0.5):
+        with gr.Column(scale=5):
             style = gr.Radio(
                 ["concise", "detailed"], value=getattr(bot, 'reply_style', 'concise'), label="Reply Style"
             )
@@ -265,13 +265,13 @@
     )
     state = gr.State([])
     with gr.Row(visible=True) as input_raws:
-        with gr.Column(scale=0.6):
+        with gr.Column(scale=4):
            txt = gr.Textbox(
                show_label=False, placeholder="Enter text and press enter", container=False
            )
-        with gr.Column(scale=0.15, min_width=0):
+        with gr.Column(scale=1, min_width=0):
            send = gr.Button(value="Send", elem_id="send", variant="primary")
-        with gr.Column(scale=0.15, min_width=0):
+        with gr.Column(scale=1, min_width=0):
            clear = gr.ClearButton(value="Clear")
     state.value = [default_chat_value]
 
diff --git a/llm4crs/agent.py b/llm4crs/agent.py
index efec57c..bb2ed17 100644
--- a/llm4crs/agent.py
+++ b/llm4crs/agent.py
@@ -7,7 +7,7 @@
 from typing import *
 
 import openai
-from langchain import OpenAI
+from langchain.llms import OpenAI
 from langchain.agents import (AgentExecutor, LLMSingleActionAgent, Tool)
 from langchain.callbacks import get_openai_callback
 from langchain.chains import LLMChain
@@ -128,6 +128,7 @@ def setup_tools(self, tools: List[Callable]) -> List[Tool]:
 
     def setup_prompts(self, tools: List[Tool]):
         prompt = CRSChatPrompt(
+            table_info=self.item_corups.info(),
             intermediate_steps="",
             template=SYSTEM_PROMPT.format(table_info=self.item_corups.info(), **self._tool_names, **self._domain_map),
             tools=tools,
@@ -156,6 +157,7 @@ def run(self, input: Dict[str, str], chat_history: str=None):
         self.prompt.memory = self.memory.buffer
         if self.selector:
             self.prompt.examples = self.selector(input['input'])
+        self.prompt.table_info = self.item_corups.info(query=input['input'])
 
         try:
             response = self.agent_exe.run(input)
diff --git a/llm4crs/agent_plan_first.py b/llm4crs/agent_plan_first.py
index 0400a77..5930496 100644
--- a/llm4crs/agent_plan_first.py
+++ b/llm4crs/agent_plan_first.py
@@ -75,10 +75,11 @@ def setup_tools(self, tools: List[Callable[..., Any]]) -> List[Tool]:
     def setup_prompts(self, tools: List[Tool]):
         tools_desc = "\n".join([f"{tool.name}: {tool.desc}" for tool in self._tools])
         tool_names = "[" + ", ".join([f"{tool.name}" for tool in self._tools]) + "]"
-        template = SYSTEM_PROMPT_PLAN_FIRST.format(table_info=self.item_corups.info(), tools_desc=tools_desc,
+        template = SYSTEM_PROMPT_PLAN_FIRST.format(tools_desc=tools_desc,
                                                    tool_exe_name=self.tools[0].name, tool_names=tool_names,
                                                    **self._tool_names, **self._domain_map)
         prompt = CRSChatPrompt(
+            table_info=self.item_corups.info(),
             intermediate_steps="",
             template=template,
             tools=tools,
diff --git a/llm4crs/corups/base.py b/llm4crs/corups/base.py
index 42bdb0c..127a950 100644
--- a/llm4crs/corups/base.py
+++ b/llm4crs/corups/base.py
@@ -130,7 +130,7 @@ def info(self, remove_game_titles: bool=False, query: str=None):
     def sample_categoricol_values(self, col_name: str, total_n: int, query: str=None, topk: int=None) -> List:
         # Select topk related tags according to query and sample (total_n-topk) tags
         if query is None:
-            result = random.sample(self.categorical_col_values[col_name], k=total_n)
+            result = random.sample(self.categorical_col_values[col_name].tolist(), k=total_n)
         else:
             if topk is None:
                 topk = total_n
diff --git a/llm4crs/utils/prompt.py b/llm4crs/utils/prompt.py
index d0c2339..9ff57bd 100644
--- a/llm4crs/utils/prompt.py
+++ b/llm4crs/utils/prompt.py
@@ -21,6 +21,7 @@ class CRSChatPrompt(StringPromptTemplate):
     memory: str
     examples: str
     reflection: str
+    table_info: str
 
 
     def format(self, **kwargs: Any) -> str:
@@ -45,6 +46,7 @@ def format(self, **kwargs: Any) -> str:
             kwargs["reflection"] = self.reflection
         else:
             kwargs["reflection"] = ''
+        kwargs["table_info"] = self.table_info
 
         return self.template.format(**kwargs)
 
diff --git a/requirements.txt b/requirements.txt
index 52ec196..ed9c3de 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 pandasql==0.7.3
-langchain==0.0.275
+langchain==0.0.312
 gradio==3.40.1
 loguru==0.7.0
 faiss-cpu==1.7.4
diff --git a/run.sh b/run.sh
index 145ba07..27d20f5 100644
--- a/run.sh
+++ b/run.sh
@@ -53,6 +53,8 @@ OPENAI_API_KEY=$API_KEY \
 OPENAI_API_BASE=$API_BASE \
 OPENAI_API_VERSION=$API_VERSION \
 OPENAI_API_TYPE=$API_TYPE \
+TOKENIZERS_PARALLELISM=false \
+PYTHONPATH=$(pwd) \
 python ./app.py \
     --engine=$engine \
     --bot_type=$bot_type \
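Patch 2 makes the table description dynamic by carrying it as a mutable field on the prompt template: setup builds the template once, and each run() call refreshes table_info from the corpus with the current user query before the prompt is formatted. The stand-in below (plain Python, not langchain's StringPromptTemplate, with made-up strings) shows the pattern in isolation.

class MiniPrompt:
    # Simplified stand-in for CRSChatPrompt: table_info is a plain field that is
    # refreshed before every turn and injected when the prompt is formatted.
    def __init__(self, template: str):
        self.template = template
        self.table_info = ""

    def format(self, **kwargs) -> str:
        kwargs["table_info"] = self.table_info
        return self.template.format(**kwargs)

prompt = MiniPrompt("Table info:\n{table_info}\n\nUser: {input}")
# Per turn: e.g. item_corups.info(query=user_input) in the real agent.
prompt.table_info = "Table Name: items\n - tags(TEXT): related values depend on the query"
print(prompt.format(input="recommend some open-world games"))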
From 0b0a7bc43a330117733e0599e085140b931579cd Mon Sep 17 00:00:00 2001
From: Xu Huang
Date: Mon, 23 Oct 2023 20:07:05 +0800
Subject: [PATCH 3/3] fix bugs in prompt template

---
 llm4crs/prompt/system.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llm4crs/prompt/system.py b/llm4crs/prompt/system.py
index 80291ce..ef914cd 100644
--- a/llm4crs/prompt/system.py
+++ b/llm4crs/prompt/system.py
@@ -36,7 +36,7 @@
 
 All SQL commands are used to search in the {item} information table (a sqlite3 table). The information of the table is listed below: \
 
-{table_info}
+{{table_info}}
 
 First you need to think:
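Patch 3, together with the matching edit in patch 1, is a str.format escaping fix: the system prompt is now formatted twice, once at setup with tool names and domain terms and once per turn with the query-specific table description, so the placeholder must be written as {{table_info}} to survive the first pass. A minimal illustration follows; the prompt text and tool name are invented for the example.

# Why the braces are doubled: the template goes through str.format twice.
SYSTEM_PROMPT = "Use {tool_names} to answer.\nTable information:\n{{table_info}}"

# Pass 1 (agent setup): fills tool names; {{table_info}} collapses to {table_info}.
template = SYSTEM_PROMPT.format(tool_names="[LookUpTool]")
assert "{table_info}" in template

# Pass 2 (every turn): fills the query-specific table description.
print(template.format(table_info="Table Name: items\n - tags(TEXT): ..."))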