diff --git a/_toc.yml b/_toc.yml index 35074ac..b4cdb76 100644 --- a/_toc.yml +++ b/_toc.yml @@ -4,3 +4,6 @@ parts: - caption: 文本标记 chapters: - file: text_marking/a.ipynb + - caption: 文本分块 + chapters: + - file: text_marking/text_chunking_summary.ipynb diff --git a/text_marking/text_chunking_summary.ipynb b/text_marking/text_chunking_summary.ipynb new file mode 100644 index 0000000..7c9a537 --- /dev/null +++ b/text_marking/text_chunking_summary.ipynb @@ -0,0 +1,167 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 基于稀疏文本信息摘要的分块\n", + "\n", + "## 背景\n", + "\n", + "从长文本(可能是播客音频的文字稿、长篇的访谈文本或分析报告)中,完整的提取出文章中的关键信息。\n", + "\n", + "## 方案\n", + "\n", + "先对文本基于讨论的话题进行初步的分块,然后使用大模型对每个分块进行总结\n", + "\n", + "## 问题\n", + "\n", + "现有的文本分块模式基本上是服务于 RAG 来做的,需要考虑召回的效果等,因此天然的限制了每个块的大小。\n", + "\n", + "但我的需求是将文本基于所讨论的主要话题进行分块,并且每个块不能包含多个话题,因为不相关信息会损害总结的效果,并且一篇文章讨论的主要话题可能就只有四到五个,所以每个块的大小最终会到 2k-8k tokens 甚至更大。\n", + "\n", + "## 解决方案 1\n", + "\n", + "Kamradt 提出的 sematic chunking,其原理是使用正则表达式先对句子进行分割,利用嵌入来聚合语义相似的文本块,并通过监控嵌入距离的显着变化来识别分割点。\n", + "\n", + "缺陷是考虑的 embedding 结果是句子级别的(对单一句子进行 embedding,缺少上下文信息),因此分块结果很容易断在同一话题的句子转折处。\n", + "\n", + "## 解决方案 2 \n", + "\n", + "使用 \\n 将正文切割成段落,从每个段落提取摘要,对摘要进行类似上述 sematic chunking 的处理。\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\"\"\"\n", + "此处代码无法正常运行,后续会有补充,只是作为示例进行演示,最终版本为可运行的核心代码片段\n", + "\"\"\"\n", + "\n", + "# # 方案二的代码\n", + "\n", + "# # \\n 分段落 \n", + "\n", + "\n", + "# def split_sentences(text: str) -> List[Dict]:\n", + "# \"\"\"将文本分割成句子\"\"\"\n", + "# sentences = text.split('\\n')\n", + "# # 过滤掉空字符串\n", + "# sentences = [s.strip() for s in sentences if s.strip()]\n", + "# return [{'sentence': s, 'index': i} for i, s in enumerate(sentences)]\n", + "\n", + "# # 写摘要 \n", + "\n", + "# def summarize_sentences(sentences: List[Dict]) -> List[Dict]:\n", + "# \"\"\"为每个句子生成一句话摘要\"\"\"\n", + "# client = OpenAI()\n", + "\n", + "# for sentence in sentences:\n", + "# try:\n", + "# prompt = f\"请用一句话总结以下内容,要求简洁且包含关键信息:\\n{sentence['sentence']}\"\n", + " \n", + "# completion = client.chat.completions.create(\n", + "# model='Qwen/Qwen2.5-7B-Instruct',\n", + "# messages=[\n", + "# {\"role\": \"user\", \"content\": prompt}\n", + "# ],\n", + "# stream=True,\n", + "# )\n", + " \n", + "# print(f\"正在生成句子摘要: {sentence['sentence']}\\n摘要内容: \", end=\"\", flush=True)\n", + "# summary = \"\"\n", + "# for chunk in completion:\n", + "# content = chunk.choices[0].delta.content or \"\"\n", + "# print(content, end=\"\", flush=True)\n", + "# summary += content\n", + "# print(\"\\n\")\n", + "# sentence['summary'] = summary.strip()\n", + "\n", + " \n", + "# return sentences\n", + "\n", + "# # combine 摘要 \n", + "\n", + "# def combine_sentences(sentences: List[Dict]) -> List[Dict]:\n", + "# \"\"\"组合相邻句子的摘要\"\"\"\n", + "# for i in range(len(sentences)):\n", + "# combined_summary = ''\n", + "# for j in range(i - buffer_size, i + buffer_size + 1):\n", + "# if 0 <= j < len(sentences):\n", + "# combined_summary += sentences[j]['summary'] + ' '\n", + "# sentences[i]['combined_summary'] = combined_summary.strip()\n", + "# print(f\"Sentence {i+1} combined summary: {sentences[i]['combined_summary']}\")\n", + "# return sentences\n", + "\n", + "# # embedding \n", + "\n", + "# def get_embeddings(sentences: List[Dict]) -> List[Dict]:\n", + "# \"\"\"获取句子的嵌入向量\"\"\"\n", + "# combined_sentences = [s['combined_summary'] for s in sentences]\n", + "# embedded_sentences = embedding_function(combined_sentences)\n", + "# return embedded_sentences\n", + "\n", + "# # 计算相似度\n", + "\n", + "# def calculate_distances(embeddings: List) -> List[float]:\n", + "# \"\"\"计算相邻句子嵌入向量的余弦距离\"\"\"\n", + "# distances = []\n", + "# for i in range(len(embeddings) - 1):\n", + "# embedding_current = embeddings[i]\n", + "# embedding_next = embeddings[i + 1]\n", + "# similarity = cosine_similarity([embedding_current], [embedding_next])[0][0]\n", + "# distance = 1 - similarity\n", + "# distances.append(distance)\n", + "# return distances\n", + "\n", + "# # 基于阈值的分块\n", + "\n", + "# def get_filtered_breakpoints(distances, min_chunk_size=3, percentile_threshold=90):\n", + "# \"\"\"\n", + "# 根据句子间距离获取过滤后的断点索引。\n", + " \n", + "# Args:\n", + "# distances: 句子间的距离列表\n", + "# min_chunk_size: 每个分组最少包含的句子数,默认为3\n", + "# percentile_threshold: 用于确定断点的百分位数阈值,默认为90\n", + " \n", + "# Returns:\n", + "# filtered_breakpoints: 过滤后的断点索引列表\n", + "# \"\"\"\n", + "# # 使用百分位数确定距离阈值,用于识别异常值(断点)\n", + "# breakpoint_distance_threshold = np.percentile(distances, percentile_threshold)\n", + "\n", + "# # 获取初始断点,这些断点是句间距离大于阈值的点的索引\n", + "# initial_breakpoints = [i for i, x in enumerate(distances) if x > breakpoint_distance_threshold]\n", + "\n", + "# # 过滤断点,确保每个分组至少包含min_chunk_size个句子\n", + "# filtered_breakpoints = []\n", + "# start_idx = 0\n", + "# for bp in initial_breakpoints:\n", + "# # 检查分组大小是否至少为min_chunk_size个句子\n", + "# if bp - start_idx >= min_chunk_size - 1:\n", + "# filtered_breakpoints.append(bp)\n", + "# start_idx = bp + 1\n", + "\n", + "# # 检查最后一个分组\n", + "# if len(distances) - start_idx < min_chunk_size:\n", + "# # 如果最后一个分组太小,则移除最后一个断点\n", + "# if filtered_breakpoints:\n", + "# filtered_breakpoints.pop()\n", + "\n", + "# return filtered_breakpoints\n" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}