from typing import List, Dict, Callable, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    UnstructuredMarkdownLoader,
    PyPDFLoader,
    TextLoader
)
from langchain.schema import Document
import os
import mimetypes
import requests
import base64
class DocumentLoader:
    """Generic document loader for Markdown, PDF, plain-text, and image files."""

    def __init__(self, file_path: str):
        self.file_path = file_path
        self.extension = os.path.splitext(file_path)[1].lower()
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")
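    # The vision step below needs an OpenAI-compatible endpoint. Both values
    # come from the environment; the values shown here are only illustrative
    # (sketch, not part of the original module):
    #
    #     export API_KEY="sk-..."                          # your SiliconFlow key
    #     export BASE_URL="https://api.siliconflow.cn/v1"  # or a compatible endpoint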
    def process_image(self, image_path: str) -> str:
        """Describe an image using a SiliconFlow VLM model."""
        try:
            # Read the image and encode it as base64.
            with open(image_path, 'rb') as image_file:
                image_data = image_file.read()
            base64_image = base64.b64encode(image_data).decode('utf-8')

            # Guess the MIME type from the file name; the original code
            # hard-coded image/jpeg, which mislabeled PNG/GIF/BMP uploads.
            mime_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"

            # Call the SiliconFlow chat-completions API.
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            response = requests.post(
                f"{self.api_base}/chat/completions",
                headers=headers,
                json={
                    "model": "Qwen/Qwen2.5-VL-72B-Instruct",
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{base64_image}",
                                        "detail": "high"
                                    }
                                },
                                {
                                    "type": "text",
                                    "text": "Describe this image in detail, including the main objects, scene, activities, colors, layout, and other key information."
                                }
                            ]
                        }
                    ],
                    "temperature": 0.7,
                    "max_tokens": 500
                }
            )
            if response.status_code != 200:
                raise Exception(f"Image-processing API call failed: {response.text}")
            description = response.json()["choices"][0]["message"]["content"]
            return description
        except Exception as e:
            print(f"Error while processing image: {str(e)}")
            return "Image processing failed"
    def load(self):
        try:
            if self.extension == '.md':
                loader = UnstructuredMarkdownLoader(self.file_path, encoding='utf-8')
                return loader.load()
            elif self.extension == '.pdf':
                loader = PyPDFLoader(self.file_path)
                return loader.load()
            elif self.extension == '.txt':
                loader = TextLoader(self.file_path, encoding='utf-8')
                return loader.load()
            elif self.extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
                # Describe the image with the VLM and wrap the description in
                # a Document so it flows through the same pipeline as text.
                description = self.process_image(self.file_path)
                doc = Document(
                    page_content=description,
                    metadata={
                        'source': self.file_path,
                        'img_url': os.path.abspath(self.file_path)  # absolute path to the image
                    }
                )
                return [doc]
            else:
                raise ValueError(f"Unsupported file format: {self.extension}")
        except UnicodeDecodeError:
            # If utf-8 fails, fall back to gbk for text-based formats.
            if self.extension in ['.md', '.txt']:
                loader = TextLoader(self.file_path, encoding='gbk')
                return loader.load()
            raise
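
# Example (sketch) of using DocumentLoader on its own; the path is hypothetical:
#
#     docs = DocumentLoader("notes/meeting.md").load()
#     print(docs[0].metadata["source"])   # -> "notes/meeting.md"
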
class DocumentProcessor:
    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
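    # Rough splitting behaviour implied by the settings above: a 2,500-character
    # document becomes chunks of at most 1,000 characters, each overlapping its
    # neighbour by about 200, e.g. roughly [0:1000], [800:1800], [1600:2500]
    # (exact boundaries depend on the separators the splitter finds).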
    def get_index_name(self, path: str) -> str:
        """Derive an index name from a file or directory path."""
        if os.path.isdir(path):
            # For a directory, use the directory name.
            return f"rag_{os.path.basename(path).lower()}"
        else:
            # For a file, use the file name without its extension.
            return f"rag_{os.path.splitext(os.path.basename(path))[0].lower()}"
    def process(self, path: str, progress_callback: Optional[Callable] = None) -> List[Dict]:
        """
        Load and process documents from a directory or a single file.

        Args:
            path: path to the document(s)
            progress_callback: optional callback used to report progress

        Returns: the list of processed document chunks
        """
        if os.path.isdir(path):
            documents = []
            total_files = sum([len(files) for _, _, files in os.walk(path)])
            processed_files = 0
            processed_size = 0
            for root, _, files in os.walk(path):
                for file in files:
                    file_path = os.path.join(root, file)
                    try:
                        # Report progress.
                        if progress_callback:
                            file_size = os.path.getsize(file_path)
                            processed_size += file_size
                            processed_files += 1
                            progress_callback(processed_size, f"Processing file {processed_files}/{total_files}: {file}")
                        loader = DocumentLoader(file_path)
                        docs = loader.load()
                        # Record the file name in the metadata.
                        for doc in docs:
                            doc.metadata['file_name'] = os.path.basename(file_path)
                        documents.extend(docs)
                    except Exception as e:
                        print(f"Warning: failed to load file {file_path}: {str(e)}")
                        continue
        else:
            try:
                if progress_callback:
                    file_size = os.path.getsize(path)
                    progress_callback(file_size * 0.3, f"Loading file: {os.path.basename(path)}")
                loader = DocumentLoader(path)
                documents = loader.load()
                # Report progress.
                if progress_callback:
                    progress_callback(file_size * 0.6, "Processing file contents...")
                # Record the file name in the metadata.
                file_name = os.path.basename(path)
                for doc in documents:
                    doc.metadata['file_name'] = file_name
            except Exception as e:
                print(f"Error while loading file: {str(e)}")
                raise
        # Split the documents into chunks.
        chunks = self.text_splitter.split_documents(documents)
        # Report progress.
        if progress_callback:
            if os.path.isdir(path):
                progress_callback(processed_size, f"Splitting done: {len(chunks)} document chunks")
            else:
                file_size = os.path.getsize(path)
                progress_callback(file_size * 0.9, f"Splitting done: {len(chunks)} document chunks")
        # Normalize into a uniform dict format.
        processed_docs = []
        for i, chunk in enumerate(chunks):
            processed_docs.append({
                'id': f'doc_{i}',
                'content': chunk.page_content,
                'metadata': chunk.metadata
            })
        return processed_docs
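
if __name__ == "__main__":
    # Minimal usage sketch (the path and callback are illustrative, not part
    # of the original module): process a document tree and print a summary.
    def report(done, message):
        print(f"[{done} bytes] {message}")

    processor = DocumentProcessor()
    chunks = processor.process("docs", progress_callback=report)  # hypothetical path
    print(f"Index name: {processor.get_index_name('docs')}")
    print(f"Produced {len(chunks)} chunks")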