from typing import List, Dict, Callable, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    UnstructuredMarkdownLoader,
    PyPDFLoader,
    TextLoader
)
from langchain.schema import Document
import os
import mimetypes
import requests
import base64


class DocumentLoader:
    """Generic document loader."""

    def __init__(self, file_path: str, original_filename: Optional[str] = None):
        self.file_path = file_path
        # Use the supplied original filename, or fall back to the one in the path
        self.original_filename = original_filename or os.path.basename(file_path)
        # Take the extension from the original filename so files with Chinese
        # (non-ASCII) names still have their type detected correctly
        self.extension = os.path.splitext(self.original_filename)[1].lower()
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def process_image(self, image_path: str) -> str:
        """Describe an image with the SiliconFlow VLM model."""
        try:
            # Read the image and encode it as base64
            with open(image_path, 'rb') as image_file:
                image_data = image_file.read()
            base64_image = base64.b64encode(image_data).decode('utf-8')

            # Guess the MIME type from the filename instead of assuming JPEG
            mime_type, _ = mimetypes.guess_type(image_path)
            mime_type = mime_type or "image/jpeg"

            # Call the SiliconFlow chat-completions API
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            response = requests.post(
                f"{self.api_base}/chat/completions",
                headers=headers,
                json={
                    "model": "Qwen/Qwen2.5-VL-72B-Instruct",
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{base64_image}",
                                        "detail": "high"
                                    }
                                },
                                {
                                    "type": "text",
                                    "text": (
                                        "Please describe this image in detail, including "
                                        "the main objects, scene, activities, colors, "
                                        "layout, and other key information."
                                    )
                                }
                            ]
                        }
                    ],
                    "temperature": 0.7,
                    "max_tokens": 500
                },
                timeout=60  # don't hang indefinitely on a stalled API call
            )

            if response.status_code != 200:
                raise Exception(f"Image-processing API call failed: {response.text}")

            description = response.json()["choices"][0]["message"]["content"]
            return description
        except Exception as e:
            print(f"Error while processing image: {str(e)}")
            return "Image processing failed"

    def load(self) -> List[Document]:
        try:
            print(f"Loading file: {self.file_path}, original filename: "
                  f"{self.original_filename}, extension: {self.extension}")

            if self.extension == '.md':
                try:
                    loader = UnstructuredMarkdownLoader(self.file_path, encoding='utf-8')
                    return loader.load()
                except UnicodeDecodeError:
                    # If UTF-8 fails, retry with GBK
                    loader = UnstructuredMarkdownLoader(self.file_path, encoding='gbk')
                    return loader.load()
            elif self.extension == '.pdf':
                loader = PyPDFLoader(self.file_path)
                return loader.load()
            elif self.extension == '.txt':
                try:
                    loader = TextLoader(self.file_path, encoding='utf-8')
                    return loader.load()
                except UnicodeDecodeError:
                    # If UTF-8 fails, retry with GBK
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
            elif self.extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
                # Images: replace the pixels with a VLM-generated description
                description = self.process_image(self.file_path)
                # Wrap the description in a Document
                doc = Document(
                    page_content=description,
                    metadata={
                        'source': self.file_path,
                        'file_name': self.original_filename,  # keep the original filename
                        'img_url': os.path.abspath(self.file_path)  # absolute path to the image
                    }
                )
                return [doc]
            else:
                print(f"Unsupported file extension: {self.extension}")
                raise ValueError(f"Unsupported file format: {self.extension}")
        except UnicodeDecodeError:
            # If the default encodings fail, fall back to GBK via TextLoader
            print(f"File encoding error, trying another encoding: {self.file_path}")
            if self.extension in ['.md', '.txt']:
                try:
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
                except Exception as e:
                    print(f"GBK fallback also failed: {str(e)}")
                    raise
            raise  # other extensions: propagate the decode error
        except Exception as e:
            print(f"Error while loading file {self.file_path}: {str(e)}")
            import traceback
            traceback.print_exc()
            raise


class DocumentProcessor:
    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

    def get_index_name(self, path: str) -> str:
        """Derive an index name from a file or directory path."""
        if os.path.isdir(path):
            # For a directory, use the directory name
            return f"rag_{os.path.basename(path).lower()}"
        else:
            # For a file, use the filename without its extension
            return f"rag_{os.path.splitext(os.path.basename(path))[0].lower()}"

    def process(self, path: str, progress_callback: Optional[Callable] = None,
                original_filename: Optional[str] = None) -> List[Dict]:
        """
        Load and process documents from a directory or a single file.

        Args:
            path: path to the document(s)
            progress_callback: callback used to report processing progress
            original_filename: the original filename (may contain Chinese characters)

        Returns:
            the list of processed document chunks
        """
        if os.path.isdir(path):
            documents = []
            total_files = sum(len(files) for _, _, files in os.walk(path))
            processed_files = 0
            processed_size = 0

            for root, _, files in os.walk(path):
                for file in files:
                    file_path = os.path.join(root, file)
                    try:
                        # Report progress
                        if progress_callback:
                            file_size = os.path.getsize(file_path)
                            processed_size += file_size
                            processed_files += 1
                            progress_callback(
                                processed_size,
                                f"Processing file {processed_files}/{total_files}: {file}"
                            )

                        # Pass each file's own name as the original filename
                        loader = DocumentLoader(file_path, original_filename=file)
                        docs = loader.load()
                        # Record the original filename in the metadata
                        for doc in docs:
                            doc.metadata['file_name'] = file
                        documents.extend(docs)
                    except Exception as e:
                        print(f"Warning: error loading file {file_path}: {str(e)}")
                        continue
        else:
            try:
                if progress_callback:
                    file_size = os.path.getsize(path)
                    progress_callback(
                        file_size * 0.3,
                        f"Loading file: {original_filename or os.path.basename(path)}"
                    )

                # Pass the original filename through for the single file
                loader = DocumentLoader(path, original_filename=original_filename)
                documents = loader.load()

                # Report progress
                if progress_callback:
                    progress_callback(file_size * 0.6, "Processing file content...")

                # Prefer the original filename over the stored one
                file_name = original_filename or os.path.basename(path)
                for doc in documents:
                    doc.metadata['file_name'] = file_name
            except Exception as e:
                print(f"Error while loading file: {str(e)}")
                raise

        # Split into chunks
        chunks = self.text_splitter.split_documents(documents)

        # Report progress
        if progress_callback:
            if os.path.isdir(path):
                progress_callback(
                    processed_size,
                    f"Splitting done, {len(chunks)} document chunks in total"
                )
            else:
                file_size = os.path.getsize(path)
                progress_callback(
                    file_size * 0.9,
                    f"Splitting done, {len(chunks)} document chunks in total"
                )

        # Normalize into a uniform dict format
        processed_docs = []
        for i, chunk in enumerate(chunks):
            processed_docs.append({
                'id': f'doc_{i}',
                'content': chunk.page_content,
                'metadata': chunk.metadata
            })

        return processed_docs
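

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal example of driving DocumentProcessor end to end. The path
# "./docs/sample.md" is a hypothetical placeholder; processing image files
# additionally requires the API_KEY and BASE_URL environment variables.
if __name__ == "__main__":
    def print_progress(processed_bytes: float, message: str) -> None:
        # The callback receives a byte count and a human-readable status line
        print(f"[{processed_bytes:.0f} B] {message}")

    processor = DocumentProcessor()
    chunks = processor.process("./docs/sample.md", progress_callback=print_progress)
    print(f"Produced {len(chunks)} chunks for index "
          f"{processor.get_index_name('./docs/sample.md')!r}")
    if chunks:
        print(chunks[0]['id'], chunks[0]['metadata'].get('file_name'))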