Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(solver): fix graph p zh name #118

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
21 changes: 9 additions & 12 deletions kag/common/benchmarks/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

from .evaUtils import get_em_f1
from ...solver.logic.core_modules.common.text_sim_by_vector import TextSimilarity


class Evaluate():
Expand All @@ -11,23 +12,19 @@ class Evaluate():
"""
def __init__(self, embedding_factory = "text-embedding-ada-002"):
self.embedding_factory = embedding_factory
self.text_similarity = TextSimilarity()

def evaForSimilarity(self, predictionlist: List[str], goldlist: List[str]):
"""
evaluate the similarity between prediction and gold #TODO
"""
# data_samples = {
# 'question': [],
# 'answer': predictionlist,
# 'ground_truth': goldlist
# }
# dataset = Dataset.from_dict(data_samples)
# run_config = RunConfig(timeout=240, thread_timeout=240, max_workers=16)
# embeddings = embedding_factory(self.embedding_factory, run_config)
#
# score = evaluate(dataset, metrics=[answer_similarity], embeddings = embeddings, run_config=run_config)
# return np.average(score.to_pandas()[['answer_similarity']])
return 0.0
total_score = 0.0
for i in range(len(predictionlist)):
scores = self.text_similarity.text_sim_result(predictionlist[i], [goldlist[i]], topk=1, low_score=0.2, is_cached=True)
if len(scores):
for score in scores:
total_score += score[1]
return total_score


def getBenchMark(self, predictionlist: List[str], goldlist: List[str]):
Expand Down
6 changes: 4 additions & 2 deletions kag/solver/implementation/default_kg_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, disable_exact_match=False, **kwargs):
self.fuzzy_match = FuzzyMatchRetrievalSpo(text_similarity=self.text_similarity, llm=self.llm_module)
self.exact_match = ExactMatchRetrievalSpo(self.schema)
self.parser = ParseLogicForm(self.schema, None)
self.exact_match_threshold = kwargs.get("exact_match_threshold", 0.9)

def retrieval_relation(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], **kwargs) -> KgGraph:
req_id = kwargs.get('req_id', '')
Expand Down Expand Up @@ -174,9 +175,10 @@ def _exact_match_spo(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphDa
for alias_name in total_one_kg_graph.entity_map.keys():
for e in total_one_kg_graph.entity_map[alias_name]:
score = e.score
if score < 0.9:
if score < self.exact_match_threshold:
total_one_kg_graph.rmv_node_ins(alias_name, [e.biz_id])
return total_one_kg_graph, False
if len(total_one_kg_graph.entity_map.get(alias_name, [])) == 0:
return total_one_kg_graph, False
return total_one_kg_graph, matched_flag

def _fuzzy_match_spo(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], req_id: str):
Expand Down
6 changes: 5 additions & 1 deletion kag/solver/implementation/default_reasoner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import time
from typing import List

from kag.interface.solver.kag_reasoner_abc import KagReasonerABC
Expand Down Expand Up @@ -60,10 +61,13 @@ def reason(self, question: str):
- history_log: A dictionary containing the history of QA pairs and re-ranked documents.
"""
# logic form planing
start_time = time.time()
lf_nodes: List[LFPlanResult] = self.lf_planner.lf_planing(question)
logger.info(f"plan cost={time.time() - start_time} lf_nodes = {lf_nodes}")


# logic form execution
solved_answer, sub_qa_pair, recall_docs, history_qa_log = self.lf_solver.solve(question, lf_nodes)
solved_answer, sub_qa_pair, recall_docs, history_qa_log, kg_graph = self.lf_solver.solve(question, lf_nodes)
# Generate supporting facts for sub question-answer pair
supporting_fact = '\n'.join(sub_qa_pair)

Expand Down
2 changes: 1 addition & 1 deletion kag/solver/implementation/default_reflector.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,4 @@ def _refine_query(self, memory: KagMemoryABC, instruction: str):
with_json_parse=False, with_except=True)
if len(update_reason_path) == 0:
return None
return update_reason_path[0]
return "\n".join(update_reason_path)
5 changes: 4 additions & 1 deletion kag/solver/logic/core_modules/common/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,4 +315,7 @@ def to_std(self, args):
class LFPlanResult:
def __init__(self, query: str, lf_nodes: List[LogicNode]):
self.query: str = query
self.lf_nodes: List[LogicNode] = lf_nodes
self.lf_nodes: List[LogicNode] = lf_nodes

def __repr__(self):
return "\n".join([str(n) for n in self.lf_nodes])
31 changes: 21 additions & 10 deletions kag/solver/logic/core_modules/common/one_hop_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def __init__(self):
self.end_entity: EntityData = None
self.end_alias = "o"
self.type: str = None
self.type_zh: str = None

def get_spo_type(self):
return f"{self.from_type}_{self.type}_{self.end_type}"
Expand All @@ -240,7 +241,8 @@ def to_json(self):
"from_type": self.from_type,
"end_entity_name": self.end_entity.name,
"end_type": self.end_type,
"type": self.type
"type": self.type,
"type_zh": self.type_zh
}

def _get_entity_description(self, entity: EntityData):
Expand Down Expand Up @@ -287,22 +289,25 @@ def __repr__(self):
from_entity_desc_str = "" if from_entity_desc is None else f"({from_entity_desc})"
to_entity_desc = self._get_entity_description(self.end_entity)
to_entity_desc_str = "" if to_entity_desc is None else f"({to_entity_desc})"
return f"({self.from_entity.name}{from_entity_desc_str} {self.type} {self.end_entity.name}{to_entity_desc_str})"
return f"({self.from_entity.name}{from_entity_desc_str} {self.type_zh} {self.end_entity.name}{to_entity_desc_str})"

@staticmethod
def from_dict(json_dict: dict, schema: SchemaUtils):
rel = RelationData()

rel.from_id = json_dict["__from_id__"]
rel.from_type = get_label_without_prefix(schema, json_dict["__from_id_type__"])
rel.from_type = json_dict["__from_id_type__"]
rel.end_id = json_dict["__to_id__"]
rel.end_type = get_label_without_prefix(schema, json_dict["__to_id_type__"])
rel.end_type = json_dict["__to_id_type__"]
rel.type = json_dict["__label__"]
spo_label_name = f"{rel.from_type}_{rel.type}_{rel.end_type}"
rel.type_zh = rel.type
from_type = get_label_without_prefix(schema, json_dict["__from_id_type__"])
end_type = get_label_without_prefix(schema, json_dict["__to_id_type__"])
spo_label_name = f"{from_type}_{rel.type}_{end_type}"
rel.prop = Prop.from_dict(json_dict, spo_label_name, schema)
if schema is not None:
if spo_label_name in schema.spo_en_zh.keys():
rel.type = schema.get_spo_with_p(schema.spo_en_zh[spo_label_name])
rel.type_zh = schema.get_spo_with_p(schema.spo_en_zh[spo_label_name])
return rel

def revert_spo(self):
Expand All @@ -323,6 +328,7 @@ def revert_spo(self):
def from_prop_value(s: EntityData, p: str, o: EntityData):
rel = RelationData()
rel.type = p
rel.type_zh = p

rel.from_id = s.biz_id
rel.from_type = s.type
Expand Down Expand Up @@ -424,7 +430,7 @@ def get_s_all_attribute_name(self):
return attribute_name_set
if len(self.s.prop.origin_prop_map) > 0:
for k in self.s.prop.origin_prop_map.keys():
attribute_name_set.append(self._schema_attr_en_to_zh(k))
attribute_name_set.append(k)
if len(self.s.prop.extend_prop_map) > 0:
for k in self.s.prop.extend_prop_map.keys():
attribute_name_set.append(k)
Expand Down Expand Up @@ -526,7 +532,9 @@ def get_std_p_value_by_spo_text(self, p, spo_text):
if p in prop.keys():
v_set = prop[p]
for rel in v_set:
relation_value_set.append(self._prase_attribute_relation(p, str(rel)))
attr_spo = f"{self.s.name} {p} {rel}"
if spo_text == attr_spo:
relation_value_set.append(self._prase_attribute_relation(p, str(rel)))
return relation_value_set


Expand Down Expand Up @@ -560,10 +568,10 @@ def get_s_all_relation_name(self):
relation_name_set = []
if len(self.in_relations) > 0:
for k in self.in_relations.keys():
relation_name_set.append(self.get_edge_en_to_zh(k))
relation_name_set.append(k)
if len(self.out_relations) > 0:
for k in self.out_relations.keys():
relation_name_set.append(self.get_edge_en_to_zh(k))
relation_name_set.append(k)
return relation_name_set


Expand Down Expand Up @@ -761,6 +769,9 @@ def to_json(self):
"edge_map": edge_dict
}

def __str__(self):
return self.to_json()

def to_edge_str(self):
return "\n".join(self.to_edge_evidence())

Expand Down
6 changes: 3 additions & 3 deletions kag/solver/logic/core_modules/common/text_sim_by_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,17 @@ def sentence_encode(self, sentences, is_cached=False):
ret.append(tmp_map[text])
return ret

def text_sim_result(self, mention, candidates: List[str], topk=1, low_score=0.63):
def text_sim_result(self, mention, candidates: List[str], topk=1, low_score=0.63, is_cached=False):
'''
output: [(candi_name, candi_score),...]
'''
if mention is None:
return []
mention_emb = self.sentence_encode(mention)
mention_emb = self.sentence_encode(mention, is_cached)
candidates = [cand for cand in candidates if cand is not None and cand.strip() != '']
if len(candidates) == 0:
return []
candidates_emb = self.sentence_encode(candidates)
candidates_emb = self.sentence_encode(candidates, is_cached)
candidates_dis = {}
for candidate, candidate_emb in zip(candidates, candidates_emb):
cosine = cosine_similarity(np.array(mention_emb), np.array(candidate_emb))
Expand Down
2 changes: 1 addition & 1 deletion kag/solver/logic/core_modules/lf_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,4 +286,4 @@ def _execute_lf(self, sub_logic_nodes):
kg_qa_result += self.output_executor.executor(n, self.req_id, self.params)
else:
logger.warning(f"unknown operator: {n.operator}")
return kg_qa_result, spo_set
return list(set(kg_qa_result)), list(set(spo_set))
4 changes: 3 additions & 1 deletion kag/solver/logic/core_modules/lf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from kag.interface.retriever.chunk_retriever_abc import ChunkRetrieverABC
from kag.interface.retriever.kg_retriever_abc import KGRetrieverABC
from kag.solver.logic.core_modules.common.base_model import LFPlanResult
from kag.solver.logic.core_modules.common.one_hop_graph import KgGraph
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.common.text_sim_by_vector import TextSimilarity
from kag.solver.logic.core_modules.common.utils import generate_random_string
Expand Down Expand Up @@ -126,6 +127,7 @@ def solve(self, query, lf_nodes: List[LFPlanResult]):
Returns:
tuple: A tuple containing the final answer, sub-query-answer pairs, relevant documents, and history.
"""
kg_graph = KgGraph()
try:
start_time = time.time()
executor = LogicExecutor(
Expand Down Expand Up @@ -163,4 +165,4 @@ def solve(self, query, lf_nodes: List[LFPlanResult]):
docs = self._flat_passages_set([cur_step_recall_docs])
if len(docs) != 0:
self.last_iter_docs = docs
return ",".join(kg_qa_result), sub_qa_pair, docs, history
return ",".join(kg_qa_result), sub_qa_pair, docs, history, kg_graph
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,15 @@ def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> list:
s_biz_id_set = []
for s_data in s_data_set:
if isinstance(s_data, EntityData):
if s_data.name == '':
s_biz_id_set.append(s_data.biz_id)
else:
if s_data.name != '':
kg_qa_result.append(s_data.name)
if isinstance(s_data, RelationData):
kg_qa_result.append(str(s_data))
if len(kg_qa_result) == 0:
for s_data in s_data_set:
if isinstance(s_data, EntityData):
if s_data.name == '':
s_biz_id_set.append(s_data.biz_id)
if len(s_biz_id_set) > 0:
one_hop_cached_map = self.dsl_runner.query_vertex_property_by_s_ids(s_biz_id_set,
n.s.get_entity_first_type(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def _convert_node_to_json(self, node_str):
try:
import json
node = json.loads(node_str)
except:
except Exception as e:
logger.warning(f"_convert_node_to_json failed {e}")
return {}
return {
'id': node['id'],
Expand Down Expand Up @@ -442,7 +443,7 @@ def parse_one_hot_graph_graph_detail_with_id_map(self, task_resp: ReasonTask, ad
s_entity = EntityData()
s_entity.type = s_type_name
s_entity.type_zh = self._get_node_type_zh(s_type_name)
s_entity.prop = Prop.from_dict(prop_values, s_entity.type, None)
s_entity.prop = Prop.from_dict(prop_values, s_entity.type, self.schema)
s_entity.biz_id = s_biz_id
s_entity.name = prop_values.get("name", "")
if "description" in prop_values.keys():
Expand Down Expand Up @@ -474,7 +475,7 @@ def parse_one_hot_graph_graph_detail_with_id_map(self, task_resp: ReasonTask, ad
o_entity = EntityData()
o_entity.type = o_type_name
o_entity.type_zh = self._get_node_type_zh(o_type_name)
o_entity.prop = Prop.from_dict(prop_values, o_entity.type, None)
o_entity.prop = Prop.from_dict(prop_values, o_entity.type, self.schema)
o_entity.biz_id = o_biz_id

o_entity.name = prop_values.get("name", "")
Expand All @@ -499,7 +500,7 @@ def parse_one_hot_graph_graph_detail_with_id_map(self, task_resp: ReasonTask, ad
continue
p_json = self._convert_edge_to_json(data[p_index])
p_json = self._trans_normal_p_json(p_json, s_json, o_json)
rel = RelationData.from_dict(p_json, None)
rel = RelationData.from_dict(p_json, self.schema)
# if rel.type in ['similarity', 'source']:
# continue
if s_entity is None:
Expand Down
Loading