doc-etl/table/text_splitter.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import re
import json
import argparse


def count_chinese_tokens(text):
    """
    Estimate the number of tokens in mixed Chinese/English text.
    1 Chinese character ~= 1.5 tokens
    1 English word      ~= 1 token
    1 punctuation mark  ~= 1 token
    """
    # Match Chinese characters
    chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
    # Match English words
    english_words = len(re.findall(r'[a-zA-Z]+', text))
    # Match punctuation marks
    punctuations = len(re.findall(r'[^\w\s]', text))
    # Rough estimate of the total token count
    total_tokens = chinese_chars * 1.5 + english_words + punctuations
    return int(total_tokens)
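
# Rough worked example of the heuristic above (illustrative only, not part of the
# original file): count_chinese_tokens("你好 world!") -> 2 * 1.5 + 1 + 1 = 5.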


def process_table_content(table_content):
    """
    Process table content: remove the table markers and split the body into
    paragraphs.
    Strategy:
    1. Clean out invalid content
    2. Split into paragraphs
    3. Preserve semantic integrity
    4. Keep each paragraph within the token limit
    """
    # Remove the table markers and collapse extra whitespace
    content = re.sub(r'表格\s*\d+\s*(?:开始|结束)', '', table_content)
    content = re.sub(r'\s+', ' ', content).strip()
    # Paragraph accumulation
    paragraphs = []
    current_para = []
    # Split into sentences
    sentences = re.split(r'([。!?\n])', content)
    for i in range(0, len(sentences), 2):
        sentence = sentences[i].strip()
        if not sentence:
            continue
        # Re-attach the trailing punctuation (if any)
        if i + 1 < len(sentences):
            sentence += sentences[i + 1]
        # Check whether this sentence starts a new paragraph
        if (re.match(r'^[的]', sentence) or        # starts with "的"
                re.match(r'^[在]', sentence) or    # starts with "在"
                re.match(r'^[\w()]+[:]', sentence)):  # starts like a key-value pair
            # Flush the current paragraph
            if current_para:
                full_para = ''.join(current_para).strip()
                if full_para:
                    # Enforce the token limit
                    if count_chinese_tokens(full_para) > 512:
                        split_paras = split_long_paragraph(full_para)
                        paragraphs.extend(split_paras)
                    else:
                        paragraphs.append(full_para)
                current_para = []
        current_para.append(sentence)
    # Flush the last paragraph
    if current_para:
        full_para = ''.join(current_para).strip()
        if full_para:
            if count_chinese_tokens(full_para) > 512:
                split_paras = split_long_paragraph(full_para)
                paragraphs.extend(split_paras)
            else:
                paragraphs.append(full_para)
    return paragraphs
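
# Illustrative table block for the handler above (assumed shape, matching the
# 表格 ... 开始/结束 regexes in this file, not taken from real data):
#   "表格 1 开始 项目:道路改造 在2023年完成施工。预算金额:100万元。表格 1 结束"
# process_table_content() strips the 开始/结束 markers and regroups the body into
# paragraphs of at most ~512 estimated tokens each.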


def split_long_paragraph(paragraph):
    """Split an over-long paragraph while preserving semantic integrity."""
    result = []
    # Try splitting on commas and sentence-final punctuation first
    parts = re.split(r'([,。!?])', paragraph)
    current_part = ""
    current_tokens = 0
    for i in range(0, len(parts), 2):
        part = parts[i].strip()
        if not part:
            continue
        # Re-attach the trailing punctuation (if any)
        if i + 1 < len(parts):
            part += parts[i + 1]
        part_tokens = count_chinese_tokens(part)
        if current_tokens + part_tokens > 512:
            if current_part:
                result.append(current_part)
            current_part = part
            current_tokens = part_tokens
        else:
            current_part += part
            current_tokens += part_tokens
    if current_part:
        result.append(current_part)
    return result


def format_group_to_text(group):
    """Format a grouped record into readable text using a generic rule."""
    if not group:
        return ""
    parts = []
    # Generic handling: walk every key-value pair and build a text fragment
    for key, value in group.items():
        # Skip empty values
        if not value:
            continue
        # Clean up and normalize the key name
        clean_key = re.sub(r'[_\(\)]', ' ', key).strip()
        # Strip any "表格无有效数据" (no valid table data) placeholders from the value
        if isinstance(value, str):
            value = re.sub(r'[【\[]*表格无[有效]*数据[】\]]*', '', value)
            if not value.strip():  # skip it if nothing is left after cleaning
                continue
        # Build the text fragment
        text = f"{clean_key}{value}"
        parts.append(text)
    # Join all fragments with commas and make sure no placeholder text survives
    result = ",".join(parts)
    result = re.sub(r'[【\[]*表格无[有效]*数据[】\]]*', '', result)
    return result.strip()
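
# Illustrative call with hypothetical data (not from the original file):
#   format_group_to_text({"项目_名称": "道路改造", "金额": "100万元"})
# After key cleanup and joining, this yields roughly "项目 名称道路改造,金额100万元".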


def split_long_text(text):
    """Split a long text into chunks that respect the token limit."""
    if count_chinese_tokens(text) <= 512:
        return [text]
    result = []
    parts = re.split(r'([。])', text)
    current_part = ""
    current_tokens = 0
    for i in range(0, len(parts), 2):
        sentence = parts[i]
        if i + 1 < len(parts):
            sentence += parts[i + 1]  # re-attach the full stop
        sentence_tokens = count_chinese_tokens(sentence)
        if current_tokens + sentence_tokens > 512:
            if current_part:
                result.append(current_part)
            current_part = sentence
            current_tokens = sentence_tokens
        else:
            current_part += sentence
            current_tokens += sentence_tokens
    if current_part:
        result.append(current_part)
    return result


def split_text_into_paragraphs(text):
    """
    Split continuous text into paragraphs.
    Strategy:
    1. Primary splits on headings and chapter/section markers
    2. Secondary splits on discourse markers
    3. Content splits based on how adjacent sentences relate
    4. Token-length splits as a fallback, so no paragraph exceeds 512 tokens
    5. Preserve the semantic integrity of each paragraph
    6. Handle table content separately
    """
    # Collapse redundant spaces and line breaks
    text = re.sub(r'\s+', ' ', text).strip()
    # Handle table content first
    table_pattern = re.compile(r'(表格\s*\d+\s*开始.*?表格\s*\d+\s*结束)', re.DOTALL)
    parts = []
    last_end = 0
    for match in table_pattern.finditer(text):
        # Add the text that precedes this table
        if match.start() > last_end:
            parts.append(("text", text[last_end:match.start()]))
        # Process the table content
        table_content = match.group(1)
        table_paragraphs = process_table_content(table_content)
        for para in table_paragraphs:
            # Make sure table paragraphs do not start with a colon
            para = re.sub(r'^[:]+\s*', '', para.strip())
            if para:  # keep non-empty paragraphs only
                parts.append(("table", para))
        last_end = match.end()
    # Add the text after the last table
    if last_end < len(text):
        parts.append(("text", text[last_end:]))
    # If no parts were collected, treat the whole text as a single text part
    if not parts:
        parts = [("text", text)]
    # Primary split markers (titles, chapters, etc.)
    major_markers = [
        r'^第[一二三四五六七八九十百千]+[章节篇]',  # chapters numbered with Chinese numerals
        r'^第\d+[章节篇]',  # chapters numbered with Arabic numerals
        r'^[一二三四五六七八九十][、.]',  # Chinese-numeral list items
        r'^\d+[、.]',  # Arabic-numeral list items
        r'^[(][一二三四五六七八九十][)]',  # parenthesized Chinese numerals
        r'^[(]\d+[)]',  # parenthesized Arabic numerals
        r'^[IVX]+[、.]',  # Roman-numeral list items
    ]
    # Secondary split markers (discourse transitions)
    minor_markers = [
        r'然而[,]',
        r'但是[,]',
        r'不过[,]',
        r'相反[,]',
        r'因此[,]',
        r'所以[,]',
        r'总的来说',
        r'综上所述',
        r'总而言之',
        r'例如[,]',
        r'比如[,]',
        r'首先[,]',
        r'其次[,]',
        r'最后[,]',
        r'另外[,]',
    ]
    # Special section markers
    special_markers = [
        r'^摘要',
        r'^引言',
        r'^前言',
        r'^结论',
        r'^致谢',
        r'^参考文献',
        r'^注释',
        r'^附录',
    ]
    # Combine the marker patterns
    all_markers = major_markers + special_markers
    marker_pattern = '|'.join(all_markers)
    minor_marker_pattern = '|'.join(minor_markers)
    # Sentence-level separators
    sentence_separators = r'([。!?\!\?])'
    # Paragraph assembly
    paragraphs = []
    for part_type, content in parts:
        if part_type == "table":
            # Table content has already been processed; add it directly
            paragraphs.append(content)
            continue
        # Process ordinary text
        current_para = ""
        current_tokens = 0
        # Split on the primary markers
        text_parts = re.split(f'({marker_pattern})', content)
        for i, part in enumerate(text_parts):
            if not part.strip():  # skip empty parts
                continue
            # Strip a leading colon
            part = re.sub(r'^[:]+\s*', '', part.strip())
            if not part:  # skip parts that are empty after cleaning
                continue
            if i % 2 == 1:  # this part is a marker
                if current_para:
                    paragraphs.append(current_para)
                current_para = part
                current_tokens = count_chinese_tokens(part)
            else:  # this part is content
                sentences = re.split(sentence_separators, part)
                for j, sentence in enumerate(sentences):
                    if not sentence.strip():
                        continue
                    # Strip a leading colon from the sentence
                    sentence = re.sub(r'^[:]+\s*', '', sentence.strip())
                    if not sentence:
                        continue
                    sentence_tokens = count_chinese_tokens(sentence)
                    # Check for a secondary (discourse) split marker
                    has_minor_marker = bool(re.search(minor_marker_pattern, sentence))
                    if has_minor_marker and current_para:
                        paragraphs.append(current_para)
                        current_para = sentence
                        current_tokens = sentence_tokens
                    elif current_tokens + sentence_tokens > 512:
                        if current_para:
                            paragraphs.append(current_para)
                        current_para = sentence
                        current_tokens = sentence_tokens
                    else:
                        if current_para:
                            current_para += sentence
                        else:
                            current_para = sentence
                        current_tokens += sentence_tokens
        if current_para:
            paragraphs.append(current_para)
    # Final pass: make sure no paragraph starts with a colon
    cleaned_paragraphs = []
    for para in paragraphs:
        para = re.sub(r'^[:]+\s*', '', para.strip())
        if para:  # keep non-empty paragraphs only
            cleaned_paragraphs.append(para)
    return cleaned_paragraphs
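
# Minimal programmatic sketch (hypothetical input text, not from the original file):
#   paras = split_text_into_paragraphs("第一章 概述。这是一段示例文本。然而,还有补充说明。")
#   for p in paras:
#       print(p)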


def save_to_json(paragraphs, output_file):
    """Save the paragraphs as JSON."""
    data = {
        "total_paragraphs": len(paragraphs),
        "paragraphs": paragraphs
    }
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    print(f"成功将文本分成 {len(paragraphs)} 个段落并保存到 {output_file}")


def save_to_txt(paragraphs, output_file):
    """Save the paragraphs as plain text, with a blank line between paragraphs."""
    with open(output_file, 'w', encoding='utf-8') as f:
        for paragraph in paragraphs:
            f.write(paragraph + '\n\n')  # two newlines make the paragraph breaks clearer
    print(f"成功将文本分成 {len(paragraphs)} 个段落并保存到 {output_file}")


def main():
    parser = argparse.ArgumentParser(description="将连续文本智能分段并保存为TXT或JSON")
    parser.add_argument("input_file", help="输入文件路径,例如 sample_continuous_text.txt")
    parser.add_argument("--output", "-o", default="paragraphs.txt",
                        help="输出文件路径,默认为当前目录下的 paragraphs.txt")
    parser.add_argument("--format", "-f", choices=['txt', 'json'], default='txt',
                        help="输出文件格式,支持txt和json,默认为txt")
    args = parser.parse_args()
    # Read the input file
    try:
        with open(args.input_file, 'r', encoding='utf-8') as f:
            text = f.read()
    except Exception as e:
        print(f"读取文件出错: {e}")
        return
    # Split into paragraphs
    paragraphs = split_text_into_paragraphs(text)
    # Save in the requested format
    if args.format == 'json':
        save_to_json(paragraphs, args.output)
    else:
        save_to_txt(paragraphs, args.output)


if __name__ == "__main__":
    main()
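

# Example invocations (sample_continuous_text.txt is only the placeholder name used
# in the --help text above, not a file shipped with this script):
#   python text_splitter.py sample_continuous_text.txt
#   python text_splitter.py sample_continuous_text.txt -o paragraphs.json -f json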