diff --git a/api/apps/core/citation/bodys.py b/api/apps/core/citation/bodys.py index feef575..f90e152 100644 --- a/api/apps/core/citation/bodys.py +++ b/api/apps/core/citation/bodys.py @@ -33,3 +33,11 @@ class CitationBody(BaseModel): [1,2], description="文档对应的索引" ) + show_code:bool=Field( + default=True, + description="是否显示引用代码块" + ) + selected_docs:List[dict]=Field( + default=[] + + ) \ No newline at end of file diff --git a/api/apps/core/citation/views.py b/api/apps/core/citation/views.py index 1d022ae..f66a37a 100644 --- a/api/apps/core/citation/views.py +++ b/api/apps/core/citation/views.py @@ -25,14 +25,18 @@ async def citation(citation_body: CitationBody): response = citation_body.response evidences = citation_body.evidences - selected_idx=citation_body.selected_idx + selected_idx = citation_body.selected_idx + show_code = citation_body.show_code + selected_docs=citation_body.selected_docs # loguru.logger.info(response) # loguru.logger.info(evidences) citation_response = mc.ground_response( response=response, evidences=evidences, selected_idx=selected_idx, - markdown=True + markdown=True, + show_code=show_code, + selected_docs=selected_docs ) loguru.logger.info(citation_response) return ApiResponse(citation_response, message="答案引用成功") diff --git a/api/apps/core/parser/views.py b/api/apps/core/parser/views.py index 771fffd..67166fc 100644 --- a/api/apps/core/parser/views.py +++ b/api/apps/core/parser/views.py @@ -60,8 +60,9 @@ async def parser(file: UploadFile = File(...),chunk_size=512): raise NotImplementedError( "file type not supported yet(pdf, xlsx, doc, docx, txt supported)") contents = parser.parse(content) - loguru.logger.info(contents[0]) + # loguru.logger.info(contents[0]) contents = tc.chunk_sentences(contents, chunk_size=512) + loguru.logger.info(len(contents)) # 返回成功响应 return JSONResponse(content=contents, status_code=200) except Exception as e: diff --git a/gomate/modules/citation/match_citation.py b/gomate/modules/citation/match_citation.py index 5cad252..aa8e7ab 100644 --- a/gomate/modules/citation/match_citation.py +++ b/gomate/modules/citation/match_citation.py @@ -1,167 +1,427 @@ -import re -from abc import ABC +import pickle from typing import List import jieba - - -class MatchCitation(ABC): +class MatchCitation: def __init__(self): - self.stopwords = [ - "的" - ] + self.stopwords = ["的"] def cut(self, para: str): - """""" - pattern = [ - '([。!?\?])([^”’])', # 单字符断句符 - '(\.{6})([^”’])', # 英文省略号 - '(\…{2})([^”’])', # 中文省略号 - '([。!?\?][”’])([^,。!?\?])' - ] - for i in pattern: - para = re.sub(i, r"\1\n\2", para) - para = para.rstrip() - return para.split("\n") + return para.split("。") def remove_stopwords(self, query: str): for word in self.stopwords: query = query.replace(word, " ") return query - def ground_response(self, - response: str, evidences: List[str], - selected_idx: List[int] = None, markdown: bool = False - ) -> List[dict]: - # {'type': 'default', 'texts': ['xxx', 'xxx']} - # {'type': 'quote', 'texts': ['1', '2']} - # if field == 'video': - # return [{'type': 'default', 'texts': [response]}] - - # Step 1: cut response into sentences, line break is removed - # print(response) + # def ground_response( + # self, + # response: str, + # evidences: List[str], + # selected_idx: List[int], + # markdown: bool = True, + # show_code=True, + # selected_docs=List[dict] + # ): + # """ + # # selected_docs:[ {"file_name": 'source', "content":'xxxx' , "chk_id": 1,"doc_id": '11', "newsinfo": {'title':'xxx','source':'xxx','date':'2024-08-25'}}] + # # if best_ratio > threshold: + # # final_response.append(f"{sentence}[{best_idx+1}]。") + # # if show_code: + # # final_response.append(f"\n```\n{best_match}。\n```\n") + # # else: + # # final_response.append(f"{sentence}。") + # """ + # # selected_idx[-1]=selected_idx[-1]-1 + # print(selected_idx) + # sentences = self.cut(response) + # final_response = [] + # print("\n==================response===================\n",response) + # print("\n==================evidences===================\n",evidences) + # print("\n==================selected_idx===================\n",selected_idx) + # selected_idx=[i-1 for i in selected_idx] + # print("\n==================selected_idx===================\n", selected_idx) + # print("\n==================len(evidences)===================\n", len(evidences)) + # print("\n==================len(selected_docs)===================\n", len(selected_docs)) + # + # for sentence in sentences: + # if not sentence.strip(): + # continue + # + # sentence_seg_cut = set(jieba.lcut(self.remove_stopwords(sentence))) + # sentence_seg_cut_length = len(sentence_seg_cut) + # + # best_match = None + # best_ratio = 0 + # best_idx = None + # best_i=None + # for i,idx in enumerate(selected_idx): + # evidence = evidences[i] + # evidence_sentences = self.cut(evidence) + # + # for evidence_sentence in evidence_sentences: + # evidence_seg_cut = set(jieba.lcut(self.remove_stopwords(evidence_sentence))) + # overlap = sentence_seg_cut.intersection(evidence_seg_cut) + # ratio = len(overlap) / sentence_seg_cut_length + # + # if ratio > best_ratio: + # best_ratio = ratio + # best_match = evidence_sentence + # best_idx = idx + 1 + # best_i=i + # threshold = 0.7 if len(sentence) > 20 else 0.6 + # + # + # if best_ratio > threshold: + # final_response.append(f"{sentence}[{best_idx+1}]。") + # if show_code: + # doc_info = selected_docs[best_i] + # newsinfo = doc_info.get('newsinfo', {}) + # source = newsinfo.get('source', '') + # date = newsinfo.get('date', '') + # title = newsinfo.get('title', '') + # + # info_string = f"来源: {source}, 日期: {date}, 标题: {title}" + # final_response.append(f"\n```python\n{info_string}\n\n{best_match}。\n```\n") + # else: + # final_response.append(f"{sentence}。") + # + # return ''.join(final_response) + + # def highlight_matching_segments(self, sentence, text): + # # 将句子和文本分词 + # sentence_words = jieba.lcut(sentence) + # text_words = jieba.lcut(text) + # + # # 找出匹配的词 + # matching_words = set(sentence_words) & set(text_words) + # + # # 高亮匹配的词 + # highlighted_words = [] + # for word in text_words: + # if word in matching_words: + # highlighted_words.append(f"\033[1;33m{word}\033[0m") # 黄色高亮 + # else: + # highlighted_words.append(word) + # + # return ''.join(highlighted_words) + # def ground_response( + # self, + # response: str, + # evidences: List[str], + # selected_idx: List[int], + # markdown: bool = True, + # show_code=True, + # selected_docs=List[dict] + # ): + # sentences = self.cut(response) + # final_response = [] + # selected_idx = [i - 1 for i in selected_idx] + # + # for sentence in sentences: + # if not sentence.strip(): + # continue + # + # sentence_seg_cut = set(jieba.lcut(self.remove_stopwords(sentence))) + # sentence_seg_cut_length = len(sentence_seg_cut) + # + # best_match = None + # best_ratio = 0 + # best_idx = None + # best_i = None + # + # for i, idx in enumerate(selected_idx): + # evidence = evidences[i] + # evidence_sentences = self.cut(evidence) + # + # for j, evidence_sentence in enumerate(evidence_sentences): + # evidence_seg_cut = set(jieba.lcut(self.remove_stopwords(evidence_sentence))) + # overlap = sentence_seg_cut.intersection(evidence_seg_cut) + # ratio = len(overlap) / sentence_seg_cut_length + # + # if ratio > best_ratio: + # best_ratio = ratio + # best_match = evidence_sentence + # best_idx = idx + 1 + # best_i = i + # best_j = j + # + # threshold = 0.7 if len(sentence) > 20 else 0.6 + # + # if best_ratio > threshold: + # final_response.append(f"{sentence}[{best_idx + 1}]。") + # if show_code: + # doc_info = selected_docs[best_i] + # newsinfo = doc_info.get('newsinfo', {}) + # source = newsinfo.get('source', '') + # date = newsinfo.get('date', '') + # title = newsinfo.get('title', '') + # + # info_string = f"来源: {source}, 日期: {date}, 标题: {title}" + # + # # 优化1: 如果best_match长度小于80,拼接上下文 + # evidence_sentences = self.cut(evidences[best_i]) + # if len(best_match) < 80: + # start = max(0, best_j - 1) + # end = min(len(evidence_sentences), best_j + 2) + # best_match = ' '.join(evidence_sentences[start:end]) + # + # # 优化2: 高亮匹配片段 + # highlighted_match = self.highlight_matching_segments(sentence, best_match) + # + # # 优化3: 灰色显示info_string + # final_response.append(f"\n```python\n\033[90m{info_string}\033[0m\n\n{highlighted_match}。\n```\n") + # else: + # final_response.append(f"{sentence}。") + # + # return ''.join(final_response) + + # def ground_response( + # self, + # response: str, + # evidences: List[str], + # selected_idx: List[int], + # markdown: bool = True, + # show_code=True, + # selected_docs=List[dict] + # ): + # sentences = self.cut(response) + # final_response = [] + # selected_idx = [i - 1 for i in selected_idx] + # + # for sentence in sentences: + # if not sentence.strip(): + # continue + # + # sentence_seg_cut = set(jieba.lcut(self.remove_stopwords(sentence))) + # sentence_seg_cut_length = len(sentence_seg_cut) + # + # best_match = None + # best_ratio = 0 + # best_idx = None + # best_i = None + # + # for i, idx in enumerate(selected_idx): + # evidence = evidences[i] + # evidence_sentences = self.cut(evidence) + # + # for j, evidence_sentence in enumerate(evidence_sentences): + # evidence_seg_cut = set(jieba.lcut(self.remove_stopwords(evidence_sentence))) + # overlap = sentence_seg_cut.intersection(evidence_seg_cut) + # ratio = len(overlap) / sentence_seg_cut_length + # + # if ratio > best_ratio: + # best_ratio = ratio + # best_match = evidence_sentence + # best_idx = idx + 1 + # best_i = i + # best_j = j + # + # threshold = 0.7 if len(sentence) > 20 else 0.6 + # + # if best_ratio > threshold: + # final_response.append(f"{sentence}[{best_idx + 1}]。") + # if show_code: + # doc_info = selected_docs[best_i] + # newsinfo = doc_info.get('newsinfo', {}) + # source = newsinfo.get('source', '') + # date = newsinfo.get('date', '') + # title = newsinfo.get('title', '') + # + # info_string = f"来源: {source}, 日期: {date}, 标题: {title}" + # + # # 优化1: 如果best_match长度小于80,拼接上下文 + # evidence_sentences = self.cut(evidences[best_i]) + # if len(best_match) < 80: + # start = max(0, best_j - 1) + # end = min(len(evidence_sentences), best_j + 2) + # best_match = ' '.join(evidence_sentences[start:end]) + # + # # 优化2: 使用下划线标记匹配片段 + # highlighted_match = self.underline_matching_segments(sentence, best_match) + # + # # 优化3: 使用 markdown 语法为 info_string 添加灰色 + # final_response.append(f"\n```python\n*{info_string}*\n\n{highlighted_match}。\n```\n") + # else: + # final_response.append(f"{sentence}。") + # + # return ''.join(final_response) + # + # def underline_matching_segments(self, sentence, text): + # # 将句子和文本分词 + # sentence_words = jieba.lcut(sentence) + # text_words = jieba.lcut(text) + # + # # 找出匹配的词 + # matching_words = set(sentence_words) & set(text_words) + # + # # 为匹配的词添加下划线 + # underlined_words = [] + # for word in text_words: + # if word in matching_words: + # underlined_words.append(f"__{word}__") # 使用双下划线标记 + # else: + # underlined_words.append(word) + # + # return ''.join(underlined_words) + + def ground_response( + self, + response: str, + evidences: List[str], + selected_idx: List[int], + markdown: bool = True, + show_code=True, + selected_docs=List[dict] + ): sentences = self.cut(response) - # print(sentences) - # get removed line break position - line_breaks = [] - sentences = [s for s in sentences if s] - for i in range(len(sentences) - 1): - current_index = response.index(sentences[i]) - next_sentence_index = response.index(sentences[i + 1]) - dummy_next_sentence_index = current_index + len(sentences[i]) - line_breaks.append(response[dummy_next_sentence_index:next_sentence_index]) - line_breaks.append('') final_response = [] + selected_idx = [i - 1 for i in selected_idx] - citations = [i + 1 for i in selected_idx] - paragraph_have_citation = False - paragraph = "" - for sentence, line_break in zip(sentences, line_breaks): - origin_sentence = sentence - paragraph += origin_sentence - sentence = self.remove_stopwords(sentence) - sentence_seg_cut = set(jieba.lcut(sentence)) - sentence_seg_cut_length = len(sentence_seg_cut) - if sentence_seg_cut_length <= 0: + for sentence in sentences: + if not sentence.strip(): continue - topk_evidences = [] - for evidence, idx in zip(evidences, selected_idx): - evidence_cuts = self.cut(evidence) - for j in range(len(evidence_cuts)): - evidence_cuts[j] = self.remove_stopwords(evidence_cuts[j]) - evidence_seg_cut = set(jieba.lcut(evidence_cuts[j])) + sentence_seg_cut = set(jieba.lcut(self.remove_stopwords(sentence))) + sentence_seg_cut_length = len(sentence_seg_cut) + + best_match = None + best_ratio = 0 + best_idx = None + best_i = None + + for i, idx in enumerate(selected_idx): + evidence = evidences[i] + evidence_sentences = self.cut(evidence) + + for j, evidence_sentence in enumerate(evidence_sentences): + evidence_seg_cut = set(jieba.lcut(self.remove_stopwords(evidence_sentence))) overlap = sentence_seg_cut.intersection(evidence_seg_cut) - topk_evidences.append((len(overlap) / sentence_seg_cut_length, idx)) + ratio = len(overlap) / sentence_seg_cut_length + + if ratio > best_ratio: + best_ratio = ratio + best_match = evidence_sentence + best_idx = idx + 1 + best_i = i + best_j = j + + threshold = 0.7 if len(sentence) > 20 else 0.6 + + if best_ratio > threshold: + final_response.append(f"{sentence}[{best_idx + 1}]。") + if show_code: + doc_info = selected_docs[best_i] + newsinfo = doc_info.get('newsinfo', {}) + source = newsinfo.get('source', '') + date = newsinfo.get('date', '') + title = newsinfo.get('title', '') - topk_evidences.sort(key=lambda x: x[0], reverse=True) + info_string = f"来源: {source}, 日期: {date}, 标题: {title}" - idx = 0 - sentence_citations = [] - if len(sentence) > 20: - threshold = 0.4 + # 如果best_match长度小于80,拼接上下文 + evidence_sentences = self.cut(evidences[best_i]) + print(best_match) + print(len(best_match)) + if best_match and len(best_match) < 80: + start = max(0, best_j - 1) + end = min(len(evidence_sentences), best_j + 2) + best_match = ' '.join(evidence_sentences[start:end]) + print(f"Extended Best Match: {best_match}") + + # 优化2: 使用HTML标签标记匹配片段 + highlighted_match = self.highlight_common_substrings(sentence, best_match) + + # 优化3: 使用HTML标签为info_string添加灰色 + final_response.append( + f"\n> {info_string}\n>\n> {highlighted_match}。\n\n") else: - threshold = 0.5 - - while (idx < len(topk_evidences)) and (topk_evidences[idx][0] > threshold): - paragraph_have_citation = True - sentence_citations.append(topk_evidences[idx][1] + 1) - if topk_evidences[idx][1] + 1 in citations: - citations.remove(topk_evidences[idx][1] + 1) - idx += 1 - - if sentence != sentences[-1] and line_break and line_break[0] == '\n' or sentence == sentences[-1] and len( - citations) == 0: - if not paragraph_have_citation and len(selected_idx) > 0: - topk_evidences = [] - for evidence, idx in zip(evidences, selected_idx): - evidence = self.remove_stopwords(evidence) - paragraph_seg = set(jieba.lcut(paragraph)) - evidence_seg = set(jieba.lcut(evidence)) - overlap = paragraph_seg.intersection(evidence_seg) - paragraph_seg_length = len(paragraph_seg) - topk_evidences.append((len(overlap) / paragraph_seg_length, idx)) - topk_evidences.sort(key=lambda x: x[0], reverse=True) - if len(paragraph) > 60: - threshold = 0.2 - else: - threshold = 0.3 - if topk_evidences[0][0] > threshold: - sentence_citations.append(topk_evidences[0][1] + 1) - if topk_evidences[0][1] + 1 in citations: - citations.remove(topk_evidences[0][1] + 1) - paragraph_have_citation = False - paragraph = "" - - # Add citation to response, need to consider the punctuation and line break - if origin_sentence[-1] not in [':', ':'] and len(origin_sentence) > 10 and len(sentence_citations) > 0: - sentence_citations = list(set(sentence_citations)) - if origin_sentence[-1] in ['。', ',', '!', '?', '.', ',', '!', '?', ':', ':']: - if markdown: - final_response.append( - origin_sentence[:-1] + ''.join([f'[{c}]' for c in sentence_citations]) + origin_sentence[ - -1]) - else: - final_response.append({'type': 'default', 'texts': [origin_sentence[:-1]]}) - final_response.append({'type': 'quote', 'texts': [str(c) for c in sentence_citations]}) - final_response.append({'type': 'default', 'texts': [origin_sentence[-1]]}) - else: - if markdown: - final_response.append(origin_sentence + ''.join([f'[{c}]' for c in sentence_citations])) + final_response.append(f"{sentence}。") + + return ''.join(final_response) + + # def highlight_common_substrings(self, str1, str2, min_length=6): + # def find_common_substrings(s1, s2, min_len): + # m, n = len(s1), len(s2) + # dp = [[0] * (n + 1) for _ in range(m + 1)] + # substrings = [] + # + # for i in range(1, m + 1): + # for j in range(1, n + 1): + # if s1[i - 1] == s2[j - 1]: + # dp[i][j] = dp[i - 1][j - 1] + 1 + # if dp[i][j] >= min_len: + # substrings.append((i - dp[i][j], i, j - dp[i][j], j)) + # else: + # dp[i][j] = 0 + # + # return sorted(substrings, key=lambda x: x[2], reverse=True) # 按str2中的起始位置排序 + # + # common_substrings = find_common_substrings(str1, str2, min_length) + # + # # 标记需要高亮的部分 + # marked_positions = [0] * len(str2) + # for _, _, start2, end2 in common_substrings: + # for i in range(start2, end2): + # marked_positions[i] = 1 + # + # # 构建带有高亮标记的字符串 + # result = [] + # in_mark = False + # for i, char in enumerate(str2): + # if marked_positions[i] and not in_mark: + # result.append("") + # in_mark = True + # elif not marked_positions[i] and in_mark: + # result.append("") + # in_mark = False + # result.append(char) + # + # if in_mark: + # result.append("") + # + # return ''.join(result) + def highlight_common_substrings(self, str1, str2, min_length=6): + def find_common_substrings(s1, s2, min_len): + m, n = len(s1), len(s2) + dp = [[0] * (n + 1) for _ in range(m + 1)] + substrings = [] + + for i in range(1, m + 1): + for j in range(1, n + 1): + if s1[i - 1] == s2[j - 1]: + dp[i][j] = dp[i - 1][j - 1] + 1 + if dp[i][j] >= min_len: + substrings.append((i - dp[i][j], i, j - dp[i][j], j)) else: - final_response.append({'type': 'default', 'texts': [origin_sentence]}) - final_response.append({'type': 'quote', 'texts': [str(c) for c in sentence_citations]}) - else: - if markdown: - final_response.append(origin_sentence) - else: - final_response.append({'type': 'default', 'texts': [origin_sentence]}) - - if line_break: - if markdown: - final_response.append(line_break) - else: - final_response.append({'type': 'default', 'texts': [line_break]}) - if markdown: - final_response = ''.join(final_response) - return final_response - - def concatenate_citations(self, result: list[dict] = None): - """ - - :param result: - :return: - """ - final_text = "" - for item in result: - if item['type'] == 'default': - final_text += ''.join(item['texts']) - elif item['type'] == 'quote': - quotes = ''.join([f'[{q}]' for q in item['texts']]) - final_text += quotes - return final_text + dp[i][j] = 0 + + return sorted(substrings, key=lambda x: x[2], reverse=True) # 按str2中的起始位置排序 + + common_substrings = find_common_substrings(str1, str2, min_length) + # 标记需要高亮的部分 + marked_positions = [0] * len(str2) + for _, _, start2, end2 in common_substrings: + for i in range(start2, end2): + marked_positions[i] = 1 + # 构建带有蓝色高亮标记的字符串 + result = [] + in_mark = False + for i, char in enumerate(str2): + if marked_positions[i] and not in_mark: + result.append("") + in_mark = True + elif not marked_positions[i] and in_mark: + result.append("") + in_mark = False + result.append(char) + + if in_mark: + result.append("") + + return ''.join(result) if __name__ == '__main__': mc = MatchCitation() @@ -172,23 +432,8 @@ def concatenate_citations(self, result: list[dict] = None): "本·维特利 编剧:乔·霍贝尔埃里希·霍贝尔迪恩·乔格瑞斯 国家地区:中国 | 美国 发行公司:上海华人影业有限公司五洲电影发行有限公司中国电影股份有限公司北京电影发行分公司 出品公司:上海华人影业有限公司华纳兄弟影片公司北京登峰国际文化传播有限公司 更多片名:巨齿鲨2 剧情:海洋霸主巨齿鲨,今夏再掀狂澜!乔纳斯·泰勒(杰森·斯坦森 饰)与科学家张九溟(吴京 饰)双雄联手,进入海底7000米深渊执行探索任务。他们意外遭遇史前巨兽海洋霸主巨齿鲨群的攻击,还将对战凶猛危险的远古怪兽群。惊心动魄的深渊冒险,巨燃巨爽的深海大战一触即发……" ], selected_idx=[0, 1], - markdown=True + markdown=True, + show_code=True ) print(result) - - # result = [ - # {'type': 'default', - # 'texts': ['巨齿鲨2是一部科幻冒险电影,由本·维特利执导,杰森·斯坦森、吴京、蔡书雅和克利夫·柯蒂斯主演。']}, - # {'type': 'default', 'texts': ['电影讲述了海洋霸主巨齿鲨,今夏再掀狂澜']}, - # {'type': 'quote', 'texts': ['2', '3']}, {'type': 'default', 'texts': ['!']}, - # {'type': 'default', - # 'texts': ['乔纳斯·泰勒(杰森·斯坦森饰)与科学家张九溟(吴京饰)双雄联手,进入海底7000米深渊执行探索任务']}, - # {'type': 'quote', 'texts': ['2', '3']}, {'type': 'default', 'texts': ['。']}, - # {'type': 'default', 'texts': ['他们意外遭遇史前巨兽海洋霸主巨齿鲨群的攻击,还将对战凶猛危险的远古怪兽群']}, - # {'type': 'quote', 'texts': ['2', '3']}, {'type': 'default', 'texts': ['。']}, - # {'type': 'default', 'texts': ['惊心动魄的深渊冒险,巨燃巨爽的深海大战一触即发']}, - # {'type': 'quote', 'texts': ['2', '3']}, {'type': 'default', 'texts': ['。']} - # ] - # response=mc.concatenate_citations(result) - # print(response) diff --git a/gomate/modules/document/chunk.py b/gomate/modules/document/chunk.py index 155e982..9c2e642 100644 --- a/gomate/modules/document/chunk.py +++ b/gomate/modules/document/chunk.py @@ -7,20 +7,60 @@ class TextChunker: def __init__(self, ): self.tokenizer = rag_tokenizer + # def split_sentences(self, text): + # # 使用正则表达式按中英文标点符号进行分句 + # sentence_endings = re.compile(r'([。!?])') + # sentences = sentence_endings.split(text) + # + # # 将标点符号和前面的句子合并 + # sentences = ["".join(i) for i in zip(sentences[0::2], sentences[1::2])] + sentences[2::2] + # sentences=[sentence.strip() for sentence in sentences if sentence.strip()] + # return sentences + def split_sentences(self, text): - # 使用正则表达式按中英文标点符号进行分句 + # 使用正则表达式按中文标点符号进行分句 sentence_endings = re.compile(r'([。!?])') sentences = sentence_endings.split(text) # 将标点符号和前面的句子合并 - sentences = ["".join(i) for i in zip(sentences[0::2], sentences[1::2])] + sentences[2::2] - sentences=[sentence.strip() for sentence in sentences if sentence.strip()] - return sentences + result = [] + for i in range(0, len(sentences) - 1, 2): + if sentences[i]: + result.append(sentences[i] + sentences[i + 1]) + + # 处理最后一个可能没有标点的句子 + if sentences[-1]: + result.append(sentences[-1]) + + # 去除空白并过滤空句子 + result = [sentence.strip() for sentence in result if sentence.strip()] + + return result + + def process_text_chunks(self, chunks): + processed_chunks = [] + for chunk in chunks: + # 处理连续的四个及以上换行符 + while '\n\n\n\n' in chunk: + chunk = chunk.replace('\n\n\n\n', '\n\n') + + # 处理连续的四个及以上空格 + while ' ' in chunk: + chunk = chunk.replace(' ', ' ') + + processed_chunks.append(chunk) + + return processed_chunks def chunk_sentences(self, paragraphs, chunk_size): """ 将段落列表按照指定的块大小进行分块。 + 首先对拼接的paragraphs进行分句,然后按照句子token个数拼接成chunk + + 如果小于chunk_size,添加下一个句子,直到超过chunk_size,那么形成一个chunk + 依次生成新的chuank + 参数: paragraphs (list): 要分块的段落列表。 @@ -32,8 +72,10 @@ def chunk_sentences(self, paragraphs, chunk_size): # 分句 sentences = self.split_sentences(text) - if len(sentences)==0: - sentences=paragraphs + # print(sentences) + + if len(sentences) == 0: + sentences = paragraphs chunks = [] current_chunk = [] current_chunk_tokens = 0 @@ -50,7 +92,7 @@ def chunk_sentences(self, paragraphs, chunk_size): if current_chunk: chunks.append(''.join(current_chunk)) - + chunks = self.process_text_chunks(chunks) return chunks @@ -208,7 +250,7 @@ def chunk_sentences(self, paragraphs, chunk_size): '[5]欧阳金雨 .及早抓住疫情过后的消费机会 [N].湖南日报 ,2020', '-02-23(003).', '作者简介 :谭诗怡(1999.10- ),女,籍贯:湖南湘乡 ,湘潭大学,本', '科在读,研究方向 :电子商务', '11Copyright©博看网 www.bookan.com.cn. All Rights Reserved.'] - paragraphs=['Hello!\nHi!\nGoodbye!'] + paragraphs = ['Hello!\nHi!\nGoodbye!'] tc = TextChunker() chunk_size = 512 chunks = tc.chunk_sentences(paragraphs, chunk_size) diff --git a/gomate/modules/document/json_parser.py b/gomate/modules/document/json_parser.py index a94c423..4eaaf23 100644 --- a/gomate/modules/document/json_parser.py +++ b/gomate/modules/document/json_parser.py @@ -16,23 +16,17 @@ def parse(self, fnm, from_page=0, to_page=100000, **kwargs): if not isinstance(fnm, str): encoding = find_codec(fnm) txt = fnm.decode(encoding, errors="ignore") - json_lines = txt.split('}\n') - # print("json_parser", json_lines[0] + '}') else: with open(fnm, "r", encoding=get_encoding(fnm)) as f: txt = f.read() - json_lines = json.loads(txt) - # print(len(json_lines)) - # print("json_parser", json_lines[0] + '}') + # print(txt) + data = json.loads(txt) sections = [] - # for sec in txt.split("\n"): - # sections.append(sec) - for line in json_lines: - try: - sections.append(line['title'] + line['content']) - except: - pass - print(len(sections),len(json_lines)) + try: + sections.append(data['title'] +'\n'+data['content']) + except: + pass + # print(len(sections),len(json_lines)) return sections def crop(self, ck, need_position): @@ -45,5 +39,5 @@ def remove_tag(txt): if __name__ == '__main__': jp = JsonParser() - data = jp.parse('/data/users/searchgpt/yq/GoMate_dev/data/docs/JSON格式/习语录1_list.json') + data = jp.parse(r'H:\Projects\GoMate\data\modified_demo.json') print(data[0]) diff --git a/gomate/modules/judger/bge_judger.py b/gomate/modules/judger/bge_judger.py index 78a8108..dac47c8 100644 --- a/gomate/modules/judger/bge_judger.py +++ b/gomate/modules/judger/bge_judger.py @@ -66,7 +66,7 @@ def judge(self, query: str, documents: List[str], k: int = 5, is_sorted: bool = { 'text': doc, 'score': score, - 'label': 1 if score >= 0.5 else 0 + 'label': 1 if score >= 0.35 else 0 } for doc, score in zip(documents, scores) ] diff --git a/gomate/modules/refiner/compressor.py b/gomate/modules/refiner/compressor.py index ed296d7..89c41a0 100644 --- a/gomate/modules/refiner/compressor.py +++ b/gomate/modules/refiner/compressor.py @@ -21,7 +21,7 @@ def __init__(self): """ # self.api_url = '' # 根据自己api地址修改 - self.api_url = 'http://127.0.0.1:8888' + self.api_url = 'http://10.208.63.29:8888' def compress(self, query, contexts): prompt = self.prompt_template.format(query=query, contexts=contexts)