Commit
add long-term memory and support toolllama as brains; (#11)
* add long-term memory and support toolllama as brains;
support partial match of tool names

* add some docstrings
XuHwang authored Feb 6, 2024
1 parent 0b8d6e1 commit 63f85bb
Showing 18 changed files with 768 additions and 333 deletions.
445 changes: 296 additions & 149 deletions eval/user_simulator.py

Large diffs are not rendered by default.

220 changes: 140 additions & 80 deletions llm4crs/agent_plan_first_openai.py

Large diffs are not rendered by default.
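The diff for llm4crs/agent_plan_first_openai.py is hidden above, but the commit message mentions partial matching of tool names. Below is a minimal, hypothetical sketch of how embedding-based matching can map a loosely written tool name to a registered tool, using sentence-transformers directly; the candidate tool names, model choice, and helper function are illustrative assumptions, not code from this commit.

from sentence_transformers import SentenceTransformer, util

# Illustrative tool registry; the real agent keeps its own tool list.
CANDIDATE_TOOLS = ["ItemSearchTool", "HardConditionFilterTool", "SoftConditionRankTool"]

model = SentenceTransformer("thenlper/gte-base")
tool_embs = model.encode(CANDIDATE_TOOLS, normalize_embeddings=True)

def resolve_tool_name(query: str) -> str:
    # Encode the partial name and return the registered tool with the highest cosine similarity.
    query_emb = model.encode([query], normalize_embeddings=True)
    scores = util.cos_sim(query_emb, tool_embs)[0]
    return CANDIDATE_TOOLS[int(scores.argmax())]

print(resolve_tool_name("item search"))  # expected to resolve to "ItemSearchTool"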

92 changes: 53 additions & 39 deletions llm4crs/corups/base.py
@@ -10,6 +10,8 @@

from pandasql import sqldf
from pandas.api.types import is_integer_dtype, is_bool_dtype, is_float_dtype, is_datetime64_dtype, is_object_dtype, is_categorical_dtype
from sentence_transformers import SentenceTransformer
import torch

from llm4crs.utils import raise_error, SentBERTEngine

@@ -34,16 +36,14 @@ def _pd_type_to_sql_type(col: pd.Series) -> str:


class BaseGallery:

def __init__(self, fpath: str, column_meaning_file: str, name: str='Item_Information', columns: List[str]=None, sep: str=',', parquet_engine: str='pyarrow',
fuzzy_cols: List[str]=['title'], categorical_cols: List[str]=['tags']) -> None:
self.fpath = fpath
self.name = name # name of the table
self.corups = self._read_file(fpath, columns, sep, parquet_engine)
# tags to be displayed to LLM: topk query-related tags + random selected tags
self.disp_cate_topk: int = 6
self.disp_cate_total: int = 10
self._fuzzy_bert_base = "BAAI/bge-base-en-v1.5"
self._fuzzy_bert_base = "thenlper/gte-base"
self._required_columns_validate()
self.column_meaning = self._load_col_desc_file(column_meaning_file)

@@ -56,11 +56,33 @@ def __init__(self, fpath: str, column_meaning_file: str, name: str='Item_Informa
else:
self.corups[col] = self.corups[col].apply(lambda x: ', '.join(x))

self.fuzzy_engine: Dict[str:SentBERTEngine] = {
col : SentBERTEngine(self.corups[col].to_numpy(), self.corups['id'].to_numpy(), case_sensitive=False, model_name=self._fuzzy_bert_base) if col not in categorical_cols
else SentBERTEngine(self.categorical_col_values[col], np.arange(len(self.categorical_col_values[col])), case_sensitive=False, model_name=self._fuzzy_bert_base)
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
_fuzzy_bert_engine = SentenceTransformer(self._fuzzy_bert_base, device=device)
self.fuzzy_engine: Dict[str, SentBERTEngine] = {
col: SentBERTEngine(
self.corups[col].to_numpy(),
self.corups["id"].to_numpy(),
case_sensitive=False,
model=_fuzzy_bert_engine
)
if col not in categorical_cols
else SentBERTEngine(
self.categorical_col_values[col],
np.arange(len(self.categorical_col_values[col])),
case_sensitive=False,
model=_fuzzy_bert_engine
)
for col in fuzzy_cols
}
self.fuzzy_engine['sql_cols'] = SentBERTEngine(
np.array(columns),
np.arange(len(columns)),
case_sensitive=False,
model=_fuzzy_bert_engine
) # fuzzy engine for column names
# title as index
self.corups_title = self.corups.set_index('title', drop=True)
# id as index
@@ -76,25 +98,18 @@ def __call__(self, sql: str, corups: pd.DataFrame=None, return_id_only: bool=Tru
Returns:
list: the result represents by id
"""
try:
if corups is None:
res = sqldf(sql, {self.name: self.corups}) # all games
else:
res = sqldf(sql, {self.name: corups}) # games in buffer
if corups is None:
result = sqldf(sql, {self.name: self.corups}) # all games
else:
result = sqldf(sql, {self.name: corups}) # games in buffer

if return_id_only:
res = res[self.corups.index.name].to_list()
else:
pass
return res
except Exception as e:
print(e)
return []
if return_id_only:
result = result[self.corups.index.name].to_list()
return result


def __len__(self) -> int:
return len(self.corups)


def info(self, remove_game_titles: bool=False, query: str=None):
prefix = 'Table information:'
@@ -110,7 +125,7 @@ def info(self, remove_game_titles: bool=False, query: str=None):
disp_values = self.sample_categoricol_values(col, total_n=self.disp_cate_total, query=query, topk=self.disp_cate_topk)
_prefix = f" Related values: [{', '.join(disp_values)}]."
cols_info += _prefix

if dtype in {'float', 'datetime', 'integer'}:
_min = self.corups[col].min()
_max = self.corups[col].max()
@@ -120,17 +135,18 @@ def sample_categoricol_values(self, col_name: str, total_n: int, query: str=None
cols_info += _prefix

primary_key = f"Primary Key: {self.corups.index.name}"
foreign_key = f"Foreign Key: None"
categorical_cols = list(self.categorical_col_values.keys())
note = f"Note that [{','.join(categorical_cols)}] columns are categorical, must use related values to search otherwise no result would be returned."
res = ''
for i, s in enumerate([table_name, cols_info, primary_key, foreign_key]):
for i, s in enumerate([table_name, cols_info, primary_key, note]):
res += f"\n{i}. {s}"
res = prefix + res
return res

def sample_categoricol_values(self, col_name: str, total_n: int, query: str=None, topk: int=None) -> List:
# Select topk related tags according to query and sample (total_n-topk) tags
if query is None:
result = random.sample(self.categorical_col_values[col_name].tolist(), k=total_n)
result = random.sample(self.categorical_col_values[col_name], k=total_n)
else:
if topk is None:
topk = total_n
@@ -146,11 +162,11 @@ def sample_categoricol_values(self, col_name: str, total_n: int, query: str=None
return result


def convert_id_2_info(self, id: Union[int, List[int], np.ndarray], col_names: Union[str, List[str]]=None) -> Union[Dict, List[Dict]]:
"""Given game id, get game informations.
def convert_id_2_info(self, item_id: Union[int, List[int], np.ndarray], col_names: Union[str, List[str]]=None) -> Union[Dict, List[Dict]]:
"""Given game item_id, get game informations.
Args:
- id: game ids.
- item_id: game ids.
- col_names: column names to be returned
Returns:
@@ -167,13 +183,13 @@ def convert_id_2_info(self, id: Union[int, List[int], np.ndarray], col_names: Un
else:
raise_error(TypeError, "Not supported type for `col_names`.")

if isinstance(id, int):
items = self.corups.loc[id][col_names].to_dict()
elif isinstance(id, list) or isinstance(id, np.ndarray):
items = self.corups.loc[id][col_names].to_dict(orient='list')
if isinstance(item_id, int):
items = self.corups.loc[item_id][col_names].to_dict()
elif isinstance(item_id, list) or isinstance(item_id, np.ndarray):
items = self.corups.loc[item_id][col_names].to_dict(orient='list')
else:
raise_error(TypeError, "Not supported type for `id`.")
raise_error(TypeError, "Not supported type for `item_id`.")

return items


@@ -204,7 +220,7 @@ def convert_title_2_info(self, titles: Union[int, List[int], np.ndarray], col_na
items = self.corups_title.loc[titles][col_names].to_dict(orient='list')
else:
raise_error(TypeError, "Not supported type for `titles`.")

return items


@@ -225,15 +241,14 @@ def _read_file(self, fpath: str, columns: List[str]=None, sep: str=',', parquet_

def _load_col_desc_file(self, fpath: str) -> Dict:
assert fpath.endswith('.json'), "Only support json file now."
with open(fpath, 'r') as f:
with open(fpath, 'r', encoding='utf-8') as f:
return json.load(f)


def _required_columns_validate(self) -> None:
for col in _REQUIRED_COLUMNS:
if col not in self.corups.columns:
raise_error(ValueError, f"`id` and `name` are required in item corups table but {col} not found, please check the table file `{self.fpath}`.")


def fuzzy_match(self, value: Union[str, List[str]], col: str) -> Union[str, List[str]]:
if col not in self.fuzzy_engine:
@@ -244,9 +259,8 @@ def fuzzy_match(self, value: Union[str, List[str]], col: str) -> Union[str, List
return res



if __name__ == '__main__':
from llm4crs.environ_variables import *
from llm4crs.environ_variables import GAME_INFO_FILE, TABLE_COL_DESC_FILE
gallery = BaseGallery(GAME_INFO_FILE, column_meaning_file=TABLE_COL_DESC_FILE)
print(gallery.info())
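For context, the refactor above builds one shared SentenceTransformer (on GPU when available) and passes it to every SentBERTEngine through the new model= argument, instead of each engine loading its own copy from a model name. A minimal sketch of that pattern, assuming SentBERTEngine takes the corpus, ids, case_sensitive, and model arguments exactly as shown in the diff; the toy corpus values are illustrative.

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from llm4crs.utils import SentBERTEngine  # repo utility; signature assumed from the diff above

device = "cuda" if torch.cuda.is_available() else "cpu"
shared_model = SentenceTransformer("thenlper/gte-base", device=device)  # encoder loaded once

titles = np.array(["Stardew Valley", "Slay the Spire"])
tags = np.array(["farming", "roguelike deckbuilder"])

# Both engines reuse the same encoder weights, so only one model sits in memory.
title_engine = SentBERTEngine(titles, np.arange(len(titles)), case_sensitive=False, model=shared_model)
tag_engine = SentBERTEngine(tags, np.arange(len(tags)), case_sensitive=False, model=shared_model)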

13 changes: 5 additions & 8 deletions llm4crs/critic/base.py
@@ -86,13 +86,10 @@ def _call(self, request: str, answer: str, history: str, tracks: str):
answer=answer,
**TOOL_NAMES,
)
if self.bot_type == "chat":
prompt = [
{"role": "system", "content": sys_msg},
{"role": "user", "content": usr_msg},
]
else:
prompt = f"{sys_msg}\n{usr_msg}"

reply = self.bot.call(prompt, max_tokens=128)
reply = self.bot.call(
sys_prompt=sys_msg,
user_prompt=usr_msg,
max_tokens=128
)
return reply
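The chat-vs-completion branching removed above is now expected to live inside the bot's call method, which accepts the system and user prompts separately. A hypothetical sketch of such a wrapper follows; the class name and backend callable are illustrative and not the repo's actual OpenAICall implementation.

class SimpleLLMCaller:
    # Illustrative wrapper: formats sys/user prompts per backend type before calling the client.
    def __init__(self, backend, bot_type: str = "chat"):
        self.backend = backend    # any callable LLM client; an assumption for this sketch
        self.bot_type = bot_type  # "chat" -> message list, anything else -> single string

    def call(self, user_prompt: str, sys_prompt: str = "", max_tokens: int = 128, temperature: float = 0.0):
        if self.bot_type == "chat":
            prompt = [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt},
            ]
        else:
            prompt = f"{sys_prompt}\n{user_prompt}"
        return self.backend(prompt, max_tokens=max_tokens, temperature=temperature)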
Empty file added llm4crs/memory/__init__.py
Empty file.
93 changes: 93 additions & 0 deletions llm4crs/memory/memory.py
@@ -0,0 +1,93 @@

import json

from llm4crs.utils.open_ai import OpenAICall

_FEW_SHOT_EXAMPLES = \
"""
> Conversations
User: My history is ITEM-1, ITEM-2, ITEM-3. Now I want something new.
Assistent: Based on your preference, I recommend you ITEM-17, ITEM-19, ITEM-30.
User: I don't like those items, give me more options.
Assistent: Based on your feedbacks, I recommend you ITEM-5, ITEM-100.
User: I think ITEM-100 may be very interesting. I may like it.
> Profiles
{"history": ["ITEM-1", "ITEM-2", "ITEM-3"], "like": ["ITEM-100"], "unwanted": ["ITEM-17", "ITEM-19", "ITEM-30"]}
> Conversations
User: I used to enjoy ITEM-89, ITEM-11, ITEM-78, ITEM-67. Now I want something new.
Assistent: Based on your preference, I recommend you ITEM-53, ITEM-10.
User: I think ITEM-10 may be very interesting, but I don't like it.
Assistent: Based on your feedbacks, I recommend you ITEM-88, ITEM-70.
User: I don't like those items, give me more options.
> Profiles
{"history": ["ITEM-89", "ITEM-11", "ITEM-78", "ITEM-67"], "like": [], "unwanted": ["ITEM-10", "ITEM-88", "ITEM-70"]}
"""

class UserProfileMemory:
"""
The memory is used to store long-term user profile. It can be updated by the conversation and used as the input for recommendation tool.
The memory consists of three parts: history, like and unwanted. Each part is a set. The history is a set of items that the user has interacted with. The like is a set of items that the user likes. The unwanted is a set of items that the user dislikes.
"""
def __init__(self, llm_engine=None, **kwargs) -> None:
if llm_engine:
self.llm_engine = llm_engine
else:
self.llm_engine = OpenAICall(**kwargs)
self.profile = {
"history": set([]),
"like": set([]),
"unwanted": set([]),
}

def conclude_user_profile(self, conversation: str) -> str:
prompt = "Your task is to extract user profile from the conversation."
prompt += f"The profile consists of three parts: history, like and unwanted.Each part is a list. You should return a json-format string.\nHere are some examples.\n{_FEW_SHOT_EXAMPLES}\nNow extract user profiles from below conversation: \n> Conversation\n{conversation}\n> Profiles\n"
return self.llm_engine.call(
user_prompt=prompt,
temperature=0.0
)


def correct_format(self, err_resp: str) -> str:
prompt = "Your task is to correct the string to json format. Here are two examples of the format:\n{\"history\": [\"ITEM-1\", \"ITEM-2\", \"ITEM-3\"], \"like\": [\"ITEM-100\"], \"unwanted\": [\"ITEM-17\", \"ITEM-19\", \"ITEM-30\"]}\nThe string to be corrected is {err_resp}. It can not be parsed by Python json.loads(). Now give the corrected json format string.".replace("{err_resp}", err_resp)
return self.llm_engine.call(
user_prompt=prompt,
sys_prompt="You are an assistent and good at writing json string.",
temperature=0.0
)


def update(self, conversation: str):
cur_profile: str = self.conclude_user_profile(conversation)
parse_success = False
limit = 3
tries = 0
while not parse_success and tries < limit:
try:
cur_profile_dict = json.loads(cur_profile)
parse_success = True
except json.decoder.JSONDecodeError as e:
cur_profile = self.correct_format(cur_profile)
tries += 1
if parse_success:
# update profile
self.profile['like'] -= set(cur_profile_dict.get('unwanted', []))
self.profile['like'].update(cur_profile_dict.get('like', []))
self.profile['unwanted'] -= set(cur_profile_dict.get('like', []))
self.profile['unwanted'].update(cur_profile_dict.get('unwanted', []))
self.profile['history'].update(cur_profile_dict.get('history', []))

def get(self) -> dict:
return {k: list(v) for k, v in self.profile.items()}


def clear(self):
self.profile = {
"history": set([]),
"like": set([]),
"unwanted": set([]),
}
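A short usage sketch for the new memory class, with a stub standing in for the OpenAI-backed engine so the update/get flow is visible without network calls; the stub and its canned reply are illustrative only.

from llm4crs.memory.memory import UserProfileMemory  # import path assumed from this diff

class _StubEngine:
    # Stands in for OpenAICall; always returns a well-formed profile string.
    def call(self, user_prompt: str, sys_prompt: str = "", temperature: float = 0.0, **kwargs):
        return '{"history": ["ITEM-1"], "like": ["ITEM-100"], "unwanted": ["ITEM-17"]}'

memory = UserProfileMemory(llm_engine=_StubEngine())
memory.update("User: My history is ITEM-1. I like ITEM-100 but not ITEM-17.")
print(memory.get())  # {'history': ['ITEM-1'], 'like': ['ITEM-100'], 'unwanted': ['ITEM-17']}
memory.clear()       # resets all three sets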

2 changes: 2 additions & 0 deletions llm4crs/prompt/critic.py
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.



CRITIC_PROMPT = \
"""
{{#system~}}
