Merge pull request #77 from gomate-community/pipeline

Pipeline

yanqiangmiffy authored Nov 5, 2024
2 parents a5fb4ef + 95b20fa commit 4917ca2
Showing 4 changed files with 56 additions and 33 deletions.
14 changes: 7 additions & 7 deletions gomate/applications/rag.py
@@ -61,11 +61,11 @@ def chat(self, question: str = '', top_k: int = 5):
contents = self.reranker.rerank(query=question, documents=[content['text'] for content in contents])
content = '\n'.join([content['text'] for content in contents])
print(contents)
response, history = self.llm.chat(question, [], content)
result = self.mc.ground_response(
response=response,
evidences=[content['text'] for content in contents],
selected_idx=[idx for idx in range(len(contents))],
markdown=True
)
result, history = self.llm.chat(question, [], content)
# result = self.mc.ground_response(
# response=response,
# evidences=[content['text'] for content in contents],
# selected_idx=[idx for idx in range(len(contents))],
# markdown=True
# )
return result, history, contents
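Note: read in context, the chat flow after this commit reduces to the sketch below. It is illustrative only — the wrapper name is made up, and the reranker/llm objects are assumed to expose the same rerank and chat signatures shown in the hunk; citation grounding via mc.ground_response is no longer applied.

def chat_with_rerank(question, contents, reranker, llm):
    # contents: retrieved passages, each a dict with a 'text' field (as in the hunk above).
    contents = reranker.rerank(query=question,
                               documents=[c['text'] for c in contents])
    # Join the reranked passages into one context string for the LLM.
    content = '\n'.join([c['text'] for c in contents])
    # Grounded-citation post-processing is skipped after this commit;
    # the raw LLM answer and chat history are returned directly.
    result, history = llm.chat(question, [], content)
    return result, history, contents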
11 changes: 9 additions & 2 deletions gomate/modules/citation/source_citation.py
@@ -325,7 +325,14 @@ def ground_response(
# print(len(group_list))
for item in group_list[1:]:
item_tokens = set(jieba.lcut(self.remove_stopwords(item['chk_content'])))
if len(reference_tokens.intersection(item_tokens)) > 15:
if len(reference_tokens.intersection(item_tokens)) >= 12:
merged_group.append(item)
elif len(set(reference['chk_content']).intersection(set(item['chk_content'])))> 30:
# print("***"*20)
# print(reference['chk_content'])
# print(item['chk_content'])
# print(len(reference_tokens.intersection(item_tokens)))
# print(len(set(reference['chk_content']).intersection(set(item['chk_content']))))
merged_group.append(item)
else:
merged_group_list.append([item])
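Note: a minimal standalone sketch of the relaxed merge test introduced here — jieba token overlap of at least 12, with a character-level overlap above 30 as a fallback for near-duplicate chunks that tokenize differently. The stopword set and remove_stopwords helper are stand-ins for the class method; thresholds mirror the diff.

import jieba

STOPWORDS = {'的', '了', '和'}  # illustrative stand-in stopword set

def remove_stopwords(text: str) -> str:
    # Stand-in for SourceCitation.remove_stopwords.
    return ''.join(ch for ch in text if ch not in STOPWORDS)

def should_merge(reference: dict, item: dict) -> bool:
    ref_tokens = set(jieba.lcut(remove_stopwords(reference['chk_content'])))
    item_tokens = set(jieba.lcut(remove_stopwords(item['chk_content'])))
    # Merge when at least 12 word tokens overlap ...
    if len(ref_tokens & item_tokens) >= 12:
        return True
    # ... or when more than 30 distinct characters are shared between the raw chunks.
    return len(set(reference['chk_content']) & set(item['chk_content'])) > 30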
@@ -372,7 +379,7 @@ def ground_response(
if __name__ == '__main__':
mc = SourceCitation()

with open(f'{PROJECT_BASE}/data/docs/citations_samples/重复引用1.json', 'r', encoding='utf-8') as f:
with open(f'{PROJECT_BASE}/data/docs/citations_samples/完全内容重复3.json', 'r', encoding='utf-8') as f:
input_data = json.load(f)
# print(input_data)
result = mc.ground_response(
61 changes: 39 additions & 22 deletions gomate/modules/clusters/utils.py
@@ -1,4 +1,5 @@
import sys

sys.path.append("/data/users/searchgpt/yq/GoMate_dev")
import json
import os
@@ -15,6 +16,10 @@
from tqdm import tqdm
import loguru
from singlepass import SGCluster
from datetime import datetime
import uuid



keywords = [
"美国",
@@ -49,8 +54,8 @@ def __init__(self):
self.db = self.get_conn()

def get_conn(self):
client = pymongo.MongoClient("mongodb://root:golaxyintelligence@10.208.61.115:20000/")
# client = pymongo.MongoClient("mongodb://root:golaxyintelligence@10.60.1.145:20000/")
# client = pymongo.MongoClient("mongodb://root:golaxyintelligence@10.208.61.115:20000/")
client = pymongo.MongoClient("mongodb://root:golaxyintelligence@10.60.1.145:27017/")
db = client['goinv3_2409']
return db

@@ -172,8 +177,9 @@ def get_es_data():
os.makedirs("data/", exist_ok=True)

for word in keywords:
loguru.logger.info("正在获取es数据:"+word)
url = f"http://10.208.61.117:9200/document_share_data_30_news/_search?q={word}&size=6000&sort=publish_time:desc"
loguru.logger.info("正在获取es数据:" + word)
# url = f"http://10.208.61.117:9200/document_share_data_30_news/_search?q={word}&size=6000&sort=publish_time:desc"
url = f"http://10.208.61.117:9200/goinv3_document_news/_search?q={word}&sort=publish_time:desc&size=2000"
response = requests.get(url)
with open(f"data/{word}_data.json", "w", encoding="utf-8") as f:
json.dump(response.json(), f, ensure_ascii=False, indent=4)
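Note: condensed, the changed fetch loop amounts to the sketch below — a URI search against the new goinv3_document_news index, newest articles first, capped at 2000 hits per keyword (previously 6000 against document_share_data_30_news). Host, index name, and parameters are taken from the diff.

import json
import os
import requests

def get_es_data(keywords):
    os.makedirs("data/", exist_ok=True)
    for word in keywords:
        # URI search: query by keyword, sort by publish_time descending, up to 2000 hits.
        url = (f"http://10.208.61.117:9200/goinv3_document_news/_search"
               f"?q={word}&sort=publish_time:desc&size=2000")
        response = requests.get(url)
        with open(f"data/{word}_data.json", "w", encoding="utf-8") as f:
            json.dump(response.json(), f, ensure_ascii=False, indent=4)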
@@ -193,7 +199,7 @@ def get_es_data():
def run_cluster_data():
print("=========一级聚类==========")
for keyword in keywords:
loguru.logger.info("一级聚类:"+keyword)
loguru.logger.info("一级聚类:" + keyword)
data = pd.read_excel(f"data/{keyword}_data.xlsx", dtype={"id": str})
data = data.drop_duplicates(subset=["title"]).reset_index(drop=True)
data["id"] = data["id"].astype(str)
@@ -211,7 +217,7 @@ def run_cluster_data():
sc.classify(data)
print("=========二级聚类==========")
for keyword in keywords:
loguru.logger.info("二级聚类:"+keyword)
loguru.logger.info("二级聚类:" + keyword)
data = pd.read_excel(f"result/level1_{keyword}_result.xlsx", dtype={"id": str})
data = data.drop_duplicates(subset=["title"]).reset_index(drop=True)
data["id"] = data["id"].astype(str)
@@ -233,9 +239,10 @@ def run_cluster_data():
except:
pass


def generate_report():
for keyword in keywords:
loguru.logger.info("正在生成报告:"+keyword)
loguru.logger.info("正在生成报告:" + keyword)
dfs = []
for file in os.listdir("result"):
if file.endswith(".xlsx") and keyword in file and 'level2_' in file:
@@ -249,33 +256,35 @@ def generate_report():
if not os.path.exists(f"result/{keyword}_cluster_level1_index.jsonl"):
with open(f"result/{keyword}_cluster_level1_index.jsonl", "w", encoding="utf-8") as f:
for index, group in tqdm(df.groupby(by=["cluster_level1_index"])):
if len(group)>=3:
if len(group) >= 3:
titles = group["title"][:30].tolist()
contents = group["title"][:5].tolist()
response1 = llm_api.compress(titles, contents)
titles = group["title"][:5].tolist()
response2 = llm_report.compress(titles, contents)

f.write(json.dumps({"cluster_level1_index": index, "level1_title": response1["response"].strip(),
"level1_content": response2["response"].strip()}, ensure_ascii=False) + "\n")
f.write(
json.dumps({"cluster_level1_index": index, "level1_title": response1["response"].strip(),
"level1_content": response2["response"].strip()}, ensure_ascii=False) + "\n")

with open(f"result/{keyword}_cluster_level2_index.jsonl", "w", encoding="utf-8") as f:
for index, group in tqdm(df.groupby(by=["cluster_level2_index"])):
if len(group)>=3:
if len(group) >= 3:
titles = group["title"][:30].tolist()
contents = group["title"][:5].tolist()
response1 = llm_api.compress(titles, contents)
titles = group["title"][:5].tolist()
response2 = llm_report.compress(titles, contents)
f.write(json.dumps({"cluster_level2_index": index, "level2_title": response1["response"].strip(),
"level2_content": response2["response"].strip()}, ensure_ascii=False) + "\n")
f.write(
json.dumps({"cluster_level2_index": index, "level2_title": response1["response"].strip(),
"level2_content": response2["response"].strip()}, ensure_ascii=False) + "\n")


def insert_mongo_report():
mc = MongoCursor()
for idx,keyword in enumerate(keywords):
for idx, keyword in enumerate(keywords):
try:
loguru.logger.info("正在插入MongoDB成功:"+keyword)
loguru.logger.info("正在插入MongoDB成功:" + keyword)
df = pd.read_excel(f"result/{keyword}_cluster_double.xlsx")
level1_mapping = {}
with open(f"result/{keyword}_cluster_level1_index.jsonl", 'r', encoding='utf-8') as f:
@@ -307,9 +316,11 @@ def insert_mongo_report():
# 查看结果
# 获取当前日期并格式化为 YYYYMMDD 格式
current_date = datetime.now().strftime("%Y%m%d")
# 生成一个唯一的ID
unique_id = f"{current_date}_{uuid.uuid4().hex[:6]}" # 只取uuid的前6位
template = {
'_id': f'{current_date}_00{idx+1}',
'name': f'开源情报每日简报-{current_date}',
'_id': f'{current_date}_00{idx + 1}_{unique_id}',
'name': f'开源情报每日简报-{current_date}-{keyword}',
'description': '',
'tags': ['开源', '新闻', keyword],
'content': [],
Expand Down Expand Up @@ -337,8 +348,9 @@ def insert_mongo_report():
})
template['content'] = contents
mc.insert_one(template, 'report')
except:
loguru.logger.error("插入MongoDB失败:"+keyword)
except Exception as e:
print(e)
loguru.logger.error("插入MongoDB失败:" + keyword)


def run():
@@ -376,8 +388,13 @@ def main():
except (KeyboardInterrupt, SystemExit):
loguru.logger.info("调度器已关闭")

def sing_run():
generate_report()
insert_mongo_report()

# def sing_run():

# get_es_data()
# run_cluster_data()
# generate_report()
# insert_mongo_report()

if __name__ == '__main__':
main()
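Note: a minimal sketch of the new report-ID scheme in insert_mongo_report — date prefix, per-keyword counter, plus a 6-character uuid4 suffix so re-runs on the same day no longer collide, and the report name now carries the keyword. The helper name is illustrative; the _id and name formats mirror the diff. Failures are now logged with the caught exception instead of being silently swallowed.

import uuid
from datetime import datetime

def build_report_header(idx: int, keyword: str) -> dict:
    current_date = datetime.now().strftime("%Y%m%d")
    # First 6 hex characters of a uuid4 make the _id unique across runs on the same day.
    unique_id = f"{current_date}_{uuid.uuid4().hex[:6]}"
    return {
        '_id': f'{current_date}_00{idx + 1}_{unique_id}',
        'name': f'开源情报每日简报-{current_date}-{keyword}',
    }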
3 changes: 1 addition & 2 deletions gomate/modules/document/json_parser.py
@@ -46,8 +46,7 @@ def parse(self, fnm, from_page=0, to_page=100000, **kwargs):
'sec_num': section['sec_num'],
'content': section['sec_theme'] + '\n' + section['content'],
'chunks': [
data['file_name'] + data['date'] + '\n' +
data['title'] + '\n' + section['sec_theme'] + '\n' +
section['sec_theme'] + '\n' +
chunk['content'] for chunk in section['chunks']
]
}
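Note: after this change each chunk string is built only from the section theme and the chunk body; the file name, date, and document title prefixes are dropped. Roughly, with field names as in the diff:

def build_chunks(section):
    # Each chunk now carries only the section theme plus its own content.
    return [section['sec_theme'] + '\n' + chunk['content']
            for chunk in section['chunks']]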
