Merge pull request #75 from gomate-community/pipeline
Pipeline
yanqiangmiffy authored Oct 22, 2024
2 parents fbfae2c + 415b43e commit d221d90
Showing 16 changed files with 2,232 additions and 101 deletions.
19 changes: 19 additions & 0 deletions examples/citations/match_example.py
@@ -0,0 +1,19 @@
from gomate.modules.citation.match_citation import MatchCitation
import json

mc = MatchCitation()

with open('sample5.json', 'r', encoding='utf-8') as f:
    input_data = json.load(f)
print(input_data)
result = mc.ground_response(
    question=input_data["question"],
    response=input_data["response"],
    evidences=input_data["evidences"],
    selected_idx=input_data["selected_idx"],
    markdown=True,
    show_code=True,
    selected_docs=input_data["selected_docs"],
)

print(result)
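The example drives MatchCitation.ground_response with five fields read from sample5.json. A minimal sketch of the shape that file appears to need, inferred from the keys accessed here and in match_citation.py below; the names and values are invented for illustration:

# Hypothetical input for MatchCitation.ground_response; values are
# illustrative only, inferred from the keys this diff reads.
sample_input = {
    "question": "What is RAG?",
    "response": "RAG pairs a retriever with a generator. Answers stay grounded in retrieved text.",
    "evidences": ["RAG pairs a retriever with a generator ..."],
    "selected_idx": [1],  # 1-based; ground_response shifts these to 0-based
    "selected_docs": [
        {
            "doc_id": "doc-001",   # used to de-duplicate citation groups
            "chk_id": "chk-003",
            "content": "RAG pairs a retriever with a generator ...",
            "newsinfo": {"source": "example-news", "date": "2024-10-22", "title": "Intro to RAG"},
        }
    ],
}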
166 changes: 166 additions & 0 deletions examples/citations/sample5.json

Large diffs are not rendered by default.

204 changes: 125 additions & 79 deletions gomate/modules/citation/match_citation.py
@@ -1,9 +1,11 @@
 import json
 from typing import List
-from gomate.modules.document.utils import PROJECT_BASE

 import jieba
 import loguru

+from gomate.modules.document.utils import PROJECT_BASE


 class MatchCitation:
     def __init__(self):
@@ -54,9 +56,17 @@ def remove_stopwords(self, query: str):
             query = query.replace(word, " ")
         return query

+    def highlight_common_substrings(self, sentence, evidence_sentence, evidence, min_length=6):
+        evidence_sentences = self.cut(evidence)
+        current_sentence_index = next(i for i, s in enumerate(evidence_sentences) if evidence_sentence == s)
+        highlighted_text = evidence_sentences[current_sentence_index]
+        start_evidence = evidence.index(highlighted_text)
+        end_evidence = start_evidence + len(highlighted_text)
+        return [[start_evidence, end_evidence - 1]]

     def ground_response(
             self,
-            question:str,
+            question: str,
             response: str,
             evidences: List[str],
             selected_idx: List[int],
@@ -74,103 +84,139 @@ def ground_response(
"selected_idx": selected_idx,
"selected_docs": selected_docs
}

# Log using loguru
# loguru.logger.info(f"Response: {response}")
# loguru.logger.info(f"Evidences: {evidences}")
# loguru.logger.info(f"Selected indices: {selected_idx}")
# loguru.logger.info(f"Selected documents: {selected_docs}")

# Save to JSON file
output_file = "citation.json"
with open("citation.json", 'w', encoding='utf-8') as f:
output_file = "citation_match.json"
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(json_data, f, ensure_ascii=False, indent=4)

loguru.logger.info(f"Parameters saved to {output_file}")

print(response)
sentences = self.cut(response)
print(sentences)
selected_idx = [i - 1 for i in selected_idx]

-        quote_list = []
-        final_response = []
-        quote_index_map = {}  # To keep track of existing quotes
-        best_idx = 0
-
-        for sentence in sentences:
-            print("===================sentence", sentence)
-            if not sentence.strip():
-                # continue
-                final_response.append(sentence)
-            else:
-                sentence_seg_cut = set(jieba.lcut(self.remove_stopwords(sentence)))
-                sentence_seg_cut_length = len(sentence_seg_cut)
-                threshold = 0.6
-                final_response.append(f"{sentence}")
-                group_list = []
-                for i, idx in enumerate(selected_idx):
-                    evidence = evidences[i]
-                    evidence_sentences = self.cut(evidence)
-                    for j, evidence_sentence in enumerate(evidence_sentences):
+        contents = [{"content": sentence} for sentence in sentences]
+        for cit_idx, citation in enumerate(contents):
+            citation['citation_content'] = []
+            citation['best_idx'] = []
+            citation['best_ratio'] = []
+            citation['highlighted_start_end'] = []
+            sentence = citation['content']
+            # print("===================sentence", sentence)
+            # tokenize the answer content
+            sentence_seg_cut = set(jieba.lcut(self.remove_stopwords(sentence)))
+            sentence_seg_cut_length = len(sentence_seg_cut)
+            threshold = 0.5
+            # retrieved content
+            for doc_idx, doc in enumerate(selected_docs):
+                evidence_sentences = self.cut(doc['content'])
+                for es_idx, evidence_sentence in enumerate(evidence_sentences):
                     # there may be empty segments
                     if evidence_sentence.strip() and sentence.strip():
                         evidence_seg_cut = set(jieba.lcut(self.remove_stopwords(evidence_sentence)))
                         overlap = sentence_seg_cut.intersection(evidence_seg_cut)
                         ratio = len(overlap) / sentence_seg_cut_length
                         # print(sentence_seg_cut, evidence_seg_cut, ratio)

                         if ratio > threshold:
                             best_ratio = ratio
-                            highlighted_start_end = self.highlight_common_substrings(sentence, evidence_sentence, evidence)
+                            best_idx = doc_idx
+                            best_sentence = evidence_sentence
+                            highlighted_start_end = self.highlight_common_substrings(sentence, evidence_sentence,
+                                                                                     doc['content'])
+                            if best_idx not in citation['best_idx']:
+                                citation['citation_content'].append(doc['content'])
+                                citation['best_idx'].append(best_idx)
+                                citation['best_ratio'].append(best_ratio)
+                                citation['highlighted_start_end'].append(highlighted_start_end)
+        print(contents)

+        citation_cnt = 0
+        is_citation_exists = []
+        for citation in contents:
+            best_idx = citation['best_idx']
+            if best_idx not in is_citation_exists:
+                is_citation_exists.append(best_idx)
+                citation_cnt += 1
+
+        is_content_exists = []
+        final_response = []
+        quote_list = []
+        best_indices = 0
+
+        is_group_exists = []
+        for citation_idx, citation in enumerate(contents):
+            final_response.append(f"{citation['content']}")

+            best_idxes = citation['best_idx']
+            if len(best_idxes) > 0:
+                is_doc_id_exists = []
+                group_list = []
+                # check whether this group of citations has already been cited by the current paragraph
+                if best_idxes not in is_content_exists:
+                    for idx, best_idx in enumerate(best_idxes):
+                        # check whether the current group contains duplicate documents
+                        if selected_docs[best_idx]["doc_id"] not in is_doc_id_exists:
                             group_item = {
-                                "doc_id": selected_docs[i]["doc_id"],
-                                "chk_id": selected_docs[i]["chk_id"],
-                                "doc_source": selected_docs[i]["newsinfo"]["source"],
-                                "doc_date": selected_docs[i]["newsinfo"]["date"],
-                                "doc_title": selected_docs[i]["newsinfo"]["title"],
-                                "chk_content": evidence,
-                                "best_ratio": best_ratio,
-                                "highlight": highlighted_start_end,
+                                "doc_id": selected_docs[best_idx]["doc_id"],
+                                "chk_id": selected_docs[best_idx]["chk_id"],
+                                "doc_source": selected_docs[best_idx]["newsinfo"]["source"],
+                                "doc_date": selected_docs[best_idx]["newsinfo"]["date"],
+                                "doc_title": selected_docs[best_idx]["newsinfo"]["title"],
+                                # "chk_content": selected_docs[best_idx]['content'],
+                                "chk_content": citation['citation_content'][idx],
+                                "best_ratio": citation['best_ratio'][idx],
+                                "highlight": citation['highlighted_start_end'][idx],
                             }
                             group_list.append(group_item)
+                            is_doc_id_exists.append(selected_docs[best_idx]["doc_id"])
+                    # merge citations
+                    group_list.sort(key=lambda x: x['best_ratio'], reverse=True)
+
+                    merged_group_list = []
+                    reference = group_list[0]
+                    reference_tokens = set(jieba.lcut(self.remove_stopwords(reference['chk_content'])))
+                    merged_group = [reference]
+                    for item in group_list[1:]:
+                        item_tokens = set(jieba.lcut(self.remove_stopwords(item['chk_content'])))
+                        if len(reference_tokens.intersection(item_tokens)) > 5:
+                            merged_group.append(item)
+                        else:
+                            merged_group_list.append([item])
+                            # merged_group = [item]
+                    if merged_group:
+                        merged_group_list.append(merged_group)
+                    for group in merged_group_list:
+                        group_data = {
+                            "doc_list": group,
+                            "chk_content": group[0]["chk_content"],
+                            "highlight": group[0]["highlight"],
+                        }
+                        doc_id_list = [doc['doc_id'] for doc in group_data['doc_list']]
+                        # print(doc_id_list)
+                        if doc_id_list not in is_group_exists:
+                            quote_list.append(group_data)
+                            best_indices += 1
+                            final_response.append(f"{[best_indices]}")
+                            is_group_exists.append(doc_id_list)
+                        else:
+                            # already cited
+                            final_response.append(f"{[is_group_exists.index(doc_id_list) + 1]}")

+        data = {'result': ''.join(final_response), 'quote_list': quote_list, 'summary': ''}
+        # Save to JSON file
+        json_data['result'] = ''.join(final_response)
+        json_data['quote_list'] = quote_list
+        output_file = "citation_match_res.json"
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump(json_data, f, ensure_ascii=False, indent=4)

-                if group_list:
-                    # Create a unique key for the group_list based on its content
-                    group_key = tuple(sorted((item["doc_id"], item["chk_id"]) for item in group_list))
-
-                    if group_key in quote_index_map:
-                        # If this group already exists, use its index
-                        existing_idx = quote_index_map[group_key]
-                        final_response.append(f"[{existing_idx}]")
-                    else:
-                        # If this is a new group, add it to quote_list and update the index
-                        best_idx += 1
-                        quote_index_map[group_key] = best_idx
-                        quote_list.append({
-                            "doc_list": group_list,
-                            "chk_content": group_list[0]["chk_content"],
-                            "highlight": group_list[0]["highlight"],
-                        })
-                        final_response.append(f"[{best_idx}]")

-        # final_response.append("。")
-        # final_response.append("\n")
-        # print(''.join(final_response))
-        data = {'result': ''.join(final_response), 'quote_list': quote_list,'summary':''}
+        loguru.logger.info(f"Parameters saved to {output_file}")
+        print(json_data)
         return data

-    def highlight_common_substrings(self, sentence, evidence_sentence, evidence, min_length=6):
-        evidence_sentences = self.cut(evidence)
-        current_sentence_index = next(i for i, s in enumerate(evidence_sentences) if evidence_sentence == s)
-        highlighted_text = evidence_sentences[current_sentence_index]
-        start_evidence = evidence.index(highlighted_text)
-        end_evidence = start_evidence + len(highlighted_text)
-        return [[start_evidence, end_evidence-1]]


 if __name__ == '__main__':
     mc = MatchCitation()

-    with open(f'{PROJECT_BASE}/data/docs/citations_samples/sample1.json','r',encoding='utf-8') as f:
-        input_data =json.load(f)
-    print(input_data)
+    with open(f'{PROJECT_BASE}/data/docs/citations_samples/sample19.json', 'r', encoding='utf-8') as f:
+        input_data = json.load(f)
     result = mc.ground_response(
         question=input_data["question"],
         response=input_data["response"],
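Stripped of bookkeeping, the new matching loop in match_citation.py reduces to: segment both the answer sentence and each evidence sentence with jieba, drop stopwords, and keep any evidence sentence whose token-overlap ratio against the answer sentence clears the 0.5 threshold. A self-contained sketch of that idea (the stopword set below is a stand-in for the list the class actually loads):

import jieba

STOPWORDS = {"的", "了", "是"}  # stand-in; MatchCitation loads its own stopword list

def match_evidence(answer_sentence, evidence_sentences, threshold=0.5):
    """Return (index, ratio) pairs for evidence sentences that overlap enough."""
    answer_tokens = {t for t in jieba.lcut(answer_sentence) if t not in STOPWORDS}
    matches = []
    for i, ev in enumerate(evidence_sentences):
        # skip empty segments, mirroring the strip() checks in the diff
        if not ev.strip() or not answer_tokens:
            continue
        ev_tokens = {t for t in jieba.lcut(ev) if t not in STOPWORDS}
        ratio = len(answer_tokens & ev_tokens) / len(answer_tokens)
        if ratio > threshold:
            matches.append((i, ratio))
    return matches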
66 changes: 44 additions & 22 deletions gomate/modules/citation/source_citation.py
@@ -157,7 +157,7 @@ def ground_response(
             sentence_seg_cut = set(jieba.lcut(self.remove_stopwords(sentence)))
             sentence_seg_cut_length = len(sentence_seg_cut)

-            threshold = 0.2
+            threshold = 0.3
             # retrieved content
             for doc_idx, doc in enumerate(selected_docs):
                 evidence_sentences = self.cut(doc['content'])
@@ -193,19 +193,16 @@ def ground_response(
         quote_list = []
         best_indices = 0

-        for citation in contents:
-            is_doc_id_exists = []
-            group_list = []
-
+        for citation_idx, citation in enumerate(contents):
             if citation_cnt > 1:
-                citation['title'] = self.convert_to_chinese(str(best_indices + 1)) + '、' + citation['title']
-                citation['title'] = "**" + citation['title'] + "**"
-            else:
-                citation['title'] = "**" + citation['title'] + "**"
+                citation['title'] = self.convert_to_chinese(str(citation_idx + 1)) + '、' + citation['title']
+            citation['title'] = "**" + citation['title'] + "**"
             final_response.append(f"{citation['title']}")

             best_idxes = citation['best_idx']
             print(best_idxes)

+            is_doc_id_exists = []
+            group_list = []
             # check whether this group of citations has already been cited by the current paragraph
             if best_idxes not in is_content_exists:
                 for idx, best_idx in enumerate(best_idxes):
@@ -225,17 +222,42 @@ def ground_response(
                         group_list.append(group_item)
                         is_doc_id_exists.append(selected_docs[best_idx]["doc_id"])

-                quote_list.append({
-                    "doc_list": group_list,
-                    "chk_content": group_list[0]["chk_content"],
-                    "highlight": group_list[0]["highlight"],
-                })
-                best_indices += 1
-                final_response.append(f"{citation['title']}{[best_indices]}\n\n")
-                # final_response.append(f"{citation['title']}\n")
-                # final_response.append(f"\n{citation['content']}{[best_indices]}\n\n")
-
-                is_content_exists.append(best_idxes)
+                # merge citations
+                group_list.sort(key=lambda x: x['best_ratio'], reverse=True)
+
+                merged_group_list = []
+                reference = group_list[0]
+                reference_tokens = set(jieba.lcut(self.remove_stopwords(reference['chk_content'])))
+                merged_group = [reference]
+                # print(len(group_list))
+                for item in group_list[1:]:
+                    item_tokens = set(jieba.lcut(self.remove_stopwords(item['chk_content'])))
+                    if len(reference_tokens.intersection(item_tokens)) > 15:
+                        merged_group.append(item)
+                    else:
+                        merged_group_list.append([item])
+                        # merged_group = [item]
+                if merged_group:
+                    print(len(merged_group))
+                    merged_group_list.append(merged_group)
+                for group in merged_group_list:
+                    quote_list.append({
+                        "doc_list": group,
+                        "chk_content": group[0]["chk_content"],
+                        "highlight": group[0]["highlight"],
+                    })
+
+                    best_indices += 1
+                    final_response.append(f"{[best_indices]}")
+                # quote_list.append({
+                #     "doc_list": group_list,
+                #     "chk_content": group_list[0]["chk_content"],
+                #     "highlight": group_list[0]["highlight"],
+                # })
+                # best_indices += 1
+                # final_response.append(f"{citation['title']}{[best_indices]}\n\n")
+                #
+                # is_content_exists.append(best_idxes)

         data = {'result': ''.join(final_response), 'quote_list': quote_list, 'summary': response['summary']}
         # Save to JSON file
@@ -266,4 +288,4 @@ def ground_response(
         selected_docs=input_data["selected_docs"],
     )

-    # print(result)
+    # print(result)
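Both citation modules now share the same merge step: sort a sentence's candidate citations by best_ratio, treat the top-ranked chunk as the reference, and fold any other chunk into its group when the two share more than N jieba tokens (N is 5 in match_citation.py and 15 here). A standalone sketch of that grouping, with the class's stopword removal omitted for brevity and assuming a non-empty group_list:

import jieba

def merge_citation_groups(group_list, min_shared_tokens=15):
    """Group near-duplicate citation chunks around the best-ranked one.

    Sketch of the merge logic this commit adds; items carry 'chk_content'
    and 'best_ratio' keys, as built inside ground_response.
    """
    ranked = sorted(group_list, key=lambda x: x["best_ratio"], reverse=True)
    reference_tokens = set(jieba.lcut(ranked[0]["chk_content"]))
    merged_group = [ranked[0]]      # chunks folded into the reference group
    merged_group_list = []          # singleton groups for dissimilar chunks
    for item in ranked[1:]:
        item_tokens = set(jieba.lcut(item["chk_content"]))
        if len(reference_tokens & item_tokens) > min_shared_tokens:
            merged_group.append(item)
        else:
            merged_group_list.append([item])
    merged_group_list.append(merged_group)
    return merged_group_list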