
Commit

Merge pull request #1 from franky519/main
Add the Text Chunking chapter "Chunking Based on Sparse Text Summaries" and update the '_toc.yml' file to include the new file
xiezipeng05 authored Nov 11, 2024
2 parents da4b3db + df7fb9c commit 6620356
Showing 2 changed files with 170 additions and 0 deletions.
3 changes: 3 additions & 0 deletions _toc.yml
@@ -4,3 +4,6 @@ parts:
  - caption: Text Marking
    chapters:
      - file: text_marking/a.ipynb
  - caption: Text Chunking
    chapters:
      - file: text_marking/text_chunking_summary.ipynb
167 changes: 167 additions & 0 deletions text_marking/text_chunking_summary.ipynb
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 基于稀疏文本信息摘要的分块\n",
"\n",
"## 背景\n",
"\n",
"从长文本(可能是播客音频的文字稿、长篇的访谈文本或分析报告)中,完整的提取出文章中的关键信息。\n",
"\n",
"## 方案\n",
"\n",
"先对文本基于讨论的话题进行初步的分块,然后使用大模型对每个分块进行总结\n",
"\n",
"## 问题\n",
"\n",
"现有的文本分块模式基本上是服务于 RAG 来做的,需要考虑召回的效果等,因此天然的限制了每个块的大小。\n",
"\n",
"但我的需求是将文本基于所讨论的主要话题进行分块,并且每个块不能包含多个话题,因为不相关信息会损害总结的效果,并且一篇文章讨论的主要话题可能就只有四到五个,所以每个块的大小最终会到 2k-8k tokens 甚至更大。\n",
"\n",
"## 解决方案 1\n",
"\n",
"Kamradt 提出的 sematic chunking,其原理是使用正则表达式先对句子进行分割,利用嵌入来聚合语义相似的文本块,并通过监控嵌入距离的显着变化来识别分割点。\n",
"\n",
"缺陷是考虑的 embedding 结果是句子级别的(对单一句子进行 embedding,缺少上下文信息),因此分块结果很容易断在同一话题的句子转折处。\n",
"\n",
"## 解决方案 2 \n",
"\n",
"使用 \\n 将正文切割成段落,从每个段落提取摘要,对摘要进行类似上述 sematic chunking 的处理。\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\"\"\"\n",
"此处代码无法正常运行,后续会有补充,只是作为示例进行演示,最终版本为可运行的核心代码片段\n",
"\"\"\"\n",
"\n",
"# # 方案二的代码\n",
"\n",
"# # \\n 分段落 \n",
"\n",
"\n",
"# def split_sentences(text: str) -> List[Dict]:\n",
"# \"\"\"将文本分割成句子\"\"\"\n",
"# sentences = text.split('\\n')\n",
"# # 过滤掉空字符串\n",
"# sentences = [s.strip() for s in sentences if s.strip()]\n",
"# return [{'sentence': s, 'index': i} for i, s in enumerate(sentences)]\n",
"\n",
"# # 写摘要 \n",
"\n",
"# def summarize_sentences(sentences: List[Dict]) -> List[Dict]:\n",
"# \"\"\"为每个句子生成一句话摘要\"\"\"\n",
"# client = OpenAI()\n",
"\n",
"# for sentence in sentences:\n",
"# try:\n",
"# prompt = f\"请用一句话总结以下内容,要求简洁且包含关键信息:\\n{sentence['sentence']}\"\n",
" \n",
"# completion = client.chat.completions.create(\n",
"# model='Qwen/Qwen2.5-7B-Instruct',\n",
"# messages=[\n",
"# {\"role\": \"user\", \"content\": prompt}\n",
"# ],\n",
"# stream=True,\n",
"# )\n",
" \n",
"# print(f\"正在生成句子摘要: {sentence['sentence']}\\n摘要内容: \", end=\"\", flush=True)\n",
"# summary = \"\"\n",
"# for chunk in completion:\n",
"# content = chunk.choices[0].delta.content or \"\"\n",
"# print(content, end=\"\", flush=True)\n",
"# summary += content\n",
"# print(\"\\n\")\n",
"# sentence['summary'] = summary.strip()\n",
"\n",
" \n",
"# return sentences\n",
"\n",
"# # combine 摘要 \n",
"\n",
"# def combine_sentences(sentences: List[Dict]) -> List[Dict]:\n",
"# \"\"\"组合相邻句子的摘要\"\"\"\n",
"# for i in range(len(sentences)):\n",
"# combined_summary = ''\n",
"# for j in range(i - buffer_size, i + buffer_size + 1):\n",
"# if 0 <= j < len(sentences):\n",
"# combined_summary += sentences[j]['summary'] + ' '\n",
"# sentences[i]['combined_summary'] = combined_summary.strip()\n",
"# print(f\"Sentence {i+1} combined summary: {sentences[i]['combined_summary']}\")\n",
"# return sentences\n",
"\n",
"# # embedding \n",
"\n",
"# def get_embeddings(sentences: List[Dict]) -> List[Dict]:\n",
"# \"\"\"获取句子的嵌入向量\"\"\"\n",
"# combined_sentences = [s['combined_summary'] for s in sentences]\n",
"# embedded_sentences = embedding_function(combined_sentences)\n",
"# return embedded_sentences\n",
"\n",
"# # 计算相似度\n",
"\n",
"# def calculate_distances(embeddings: List) -> List[float]:\n",
"# \"\"\"计算相邻句子嵌入向量的余弦距离\"\"\"\n",
"# distances = []\n",
"# for i in range(len(embeddings) - 1):\n",
"# embedding_current = embeddings[i]\n",
"# embedding_next = embeddings[i + 1]\n",
"# similarity = cosine_similarity([embedding_current], [embedding_next])[0][0]\n",
"# distance = 1 - similarity\n",
"# distances.append(distance)\n",
"# return distances\n",
"\n",
"# # 基于阈值的分块\n",
"\n",
"# def get_filtered_breakpoints(distances, min_chunk_size=3, percentile_threshold=90):\n",
"# \"\"\"\n",
"# 根据句子间距离获取过滤后的断点索引。\n",
" \n",
"# Args:\n",
"# distances: 句子间的距离列表\n",
"# min_chunk_size: 每个分组最少包含的句子数,默认为3\n",
"# percentile_threshold: 用于确定断点的百分位数阈值,默认为90\n",
" \n",
"# Returns:\n",
"# filtered_breakpoints: 过滤后的断点索引列表\n",
"# \"\"\"\n",
"# # 使用百分位数确定距离阈值,用于识别异常值(断点)\n",
"# breakpoint_distance_threshold = np.percentile(distances, percentile_threshold)\n",
"\n",
"# # 获取初始断点,这些断点是句间距离大于阈值的点的索引\n",
"# initial_breakpoints = [i for i, x in enumerate(distances) if x > breakpoint_distance_threshold]\n",
"\n",
"# # 过滤断点,确保每个分组至少包含min_chunk_size个句子\n",
"# filtered_breakpoints = []\n",
"# start_idx = 0\n",
"# for bp in initial_breakpoints:\n",
"# # 检查分组大小是否至少为min_chunk_size个句子\n",
"# if bp - start_idx >= min_chunk_size - 1:\n",
"# filtered_breakpoints.append(bp)\n",
"# start_idx = bp + 1\n",
"\n",
"# # 检查最后一个分组\n",
"# if len(distances) - start_idx < min_chunk_size:\n",
"# # 如果最后一个分组太小,则移除最后一个断点\n",
"# if filtered_breakpoints:\n",
"# filtered_breakpoints.pop()\n",
"\n",
"# return filtered_breakpoints\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
