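"""Document loading and chunking utilities for a RAG pipeline.

DocumentLoader reads Markdown, PDF, plain-text and image files (images are
described with a SiliconFlow vision-language model), and DocumentProcessor
splits the loaded documents into overlapping chunks ready for indexing.
"""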
from typing import List, Dict, Callable, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
DirectoryLoader,
UnstructuredMarkdownLoader,
PyPDFLoader,
TextLoader
)
import os
import requests
import base64
from PIL import Image
import io
class DocumentLoader:
"""通用文档加载器"""
def __init__(self, file_path: str, original_filename: str = None):
self.file_path = file_path
        # Use the provided original filename, or fall back to the basename of the path
        self.original_filename = original_filename or os.path.basename(file_path)
        # Take the extension from the original filename so that non-ASCII (e.g. Chinese) filenames are typed correctly
self.extension = os.path.splitext(self.original_filename)[1].lower()
self.api_key = os.getenv("API_KEY")
self.api_base = os.getenv("BASE_URL")
def process_image(self, image_path: str) -> str:
"""使用 SiliconFlow VLM 模型处理图片"""
try:
            # Read the image and encode it as base64
with open(image_path, 'rb') as image_file:
image_data = image_file.read()
base64_image = base64.b64encode(image_data).decode('utf-8')
            # Call the SiliconFlow chat completions API
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = requests.post(
f"{self.api_base}/chat/completions",
headers=headers,
json={
"model": "Qwen/Qwen2.5-VL-72B-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": "high"
}
},
{
"type": "text",
"text": "请详细描述这张图片的内容,包括主要对象、场景、活动、颜色、布局等关键信息。"
}
]
}
],
"temperature": 0.7,
"max_tokens": 500
}
)
if response.status_code != 200:
raise Exception(f"图片处理API调用失败: {response.text}")
description = response.json()["choices"][0]["message"]["content"]
return description
except Exception as e:
print(f"处理图片时出错: {str(e)}")
return "图片处理失败"
def load(self):
try:
print(f"正在加载文件: {self.file_path}, 原始文件名: {self.original_filename}, 扩展名: {self.extension}")
if self.extension == '.md':
try:
loader = UnstructuredMarkdownLoader(self.file_path, encoding='utf-8')
return loader.load()
except UnicodeDecodeError:
                    # Fall back to GBK if UTF-8 decoding fails
loader = UnstructuredMarkdownLoader(self.file_path, encoding='gbk')
return loader.load()
elif self.extension == '.pdf':
loader = PyPDFLoader(self.file_path)
return loader.load()
elif self.extension == '.txt':
try:
loader = TextLoader(self.file_path, encoding='utf-8')
return loader.load()
except UnicodeDecodeError:
                    # Fall back to GBK if UTF-8 decoding fails
loader = TextLoader(self.file_path, encoding='gbk')
return loader.load()
elif self.extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
                # Handle image files
                description = self.process_image(self.file_path)
                # Wrap the generated description in a Document
from langchain.schema import Document
doc = Document(
page_content=description,
metadata={
'source': self.file_path,
                        'file_name': self.original_filename,  # keep the original filename
                        'img_url': os.path.abspath(self.file_path)  # absolute path to the image file
}
)
return [doc]
else:
print(f"不支持的文件扩展名: {self.extension}")
raise ValueError(f"不支持的文件格式: {self.extension}")
except UnicodeDecodeError:
            # If the default encoding fails, retry with another encoding
            print(f"Encoding error, retrying with another encoding: {self.file_path}")
            if self.extension in ['.md', '.txt']:
                try:
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
                except Exception as e:
                    print(f"GBK encoding also failed: {str(e)}")
                    raise
            # Re-raise for file types that cannot be retried with another encoding
            raise
except Exception as e:
print(f"加载文件 {self.file_path} 时出错: {str(e)}")
import traceback
traceback.print_exc()
raise
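# A minimal, illustrative sketch of using DocumentLoader on its own (the path and
# filename below are hypothetical; image files additionally require the API_KEY and
# BASE_URL environment variables to point at a SiliconFlow-compatible endpoint):
#
#     loader = DocumentLoader("/tmp/upload_abc123", original_filename="报告.pdf")
#     docs = loader.load()          # returns a list of langchain Documents
#     print(len(docs), docs[0].page_content[:200])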
class DocumentProcessor:
def __init__(self):
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
length_function=len,
)
def get_index_name(self, path: str) -> str:
"""根据文件路径生成索引名称"""
if os.path.isdir(path):
# 如果是目录,使用目录名
return f"rag_{os.path.basename(path).lower()}"
else:
# 如果是文件,使用文件名(不含扩展名)
return f"rag_{os.path.splitext(os.path.basename(path))[0].lower()}"
    def process(self, path: str, progress_callback: Optional[Callable] = None, original_filename: Optional[str] = None) -> List[Dict]:
        """
        Load and process documents; supports either a directory or a single file.
        Args:
            path: path to the document or directory
            progress_callback: progress callback used to report processing progress
            original_filename: the original filename (including non-ASCII names such as Chinese)
        Returns: the list of processed documents
        """
if os.path.isdir(path):
documents = []
total_files = sum([len(files) for _, _, files in os.walk(path)])
processed_files = 0
processed_size = 0
for root, _, files in os.walk(path):
for file in files:
file_path = os.path.join(root, file)
try:
                        # Report processing progress
if progress_callback:
file_size = os.path.getsize(file_path)
processed_size += file_size
processed_files += 1
                            progress_callback(processed_size, f"Processing file {processed_files}/{total_files}: {file}")
                        # Pass the original filename for each file in the directory
loader = DocumentLoader(file_path, original_filename=file)
docs = loader.load()
                        # Add the filename to the metadata
                        for doc in docs:
                            doc.metadata['file_name'] = file  # use the original filename
documents.extend(docs)
except Exception as e:
print(f"警告:加载文件 {file_path} 时出错: {str(e)}")
continue
else:
try:
if progress_callback:
file_size = os.path.getsize(path)
                    progress_callback(file_size * 0.3, f"Loading file: {original_filename or os.path.basename(path)}")
                # Pass the original filename for a single file
loader = DocumentLoader(path, original_filename=original_filename)
documents = loader.load()
                # Report progress
                if progress_callback:
                    progress_callback(file_size * 0.6, "Processing file contents...")
                # Use the original filename rather than the stored one
file_name = original_filename or os.path.basename(path)
for doc in documents:
doc.metadata['file_name'] = file_name
except Exception as e:
print(f"加载文件时出错: {str(e)}")
raise
        # Split the documents into chunks
        chunks = self.text_splitter.split_documents(documents)
        # Report progress
        if progress_callback:
            if os.path.isdir(path):
                progress_callback(processed_size, f"Splitting complete: {len(chunks)} document chunks")
            else:
                file_size = os.path.getsize(path)
                progress_callback(file_size * 0.9, f"Splitting complete: {len(chunks)} document chunks")
        # Normalize the chunks into a uniform output format
processed_docs = []
for i, chunk in enumerate(chunks):
processed_docs.append({
'id': f'doc_{i}',
'content': chunk.page_content,
'metadata': chunk.metadata
})
return processed_docs
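# Minimal end-to-end sketch, assuming a local file or directory path of your own.
# The default "./docs" path and the print-based show_progress callback below are
# purely illustrative and not part of the original pipeline; image inputs also
# need the API_KEY and BASE_URL environment variables to be set.
if __name__ == "__main__":
    import sys

    def show_progress(done: float, message: str) -> None:
        # Mirror progress reports on stdout
        print(f"[{done:.0f} bytes] {message}")

    target = sys.argv[1] if len(sys.argv) > 1 else "./docs"
    processor = DocumentProcessor()
    print(f"Index name: {processor.get_index_name(target)}")
    chunks = processor.process(target, progress_callback=show_progress)
    print(f"Produced {len(chunks)} chunks")
    if chunks:
        print(chunks[0]['id'], chunks[0]['metadata'].get('file_name'))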