File size: 8,110 Bytes
1625bb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from typing import List, Dict, Callable, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    DirectoryLoader,
    UnstructuredMarkdownLoader,
    PyPDFLoader,
    TextLoader
)
import os
import requests
import base64
from PIL import Image
import io

class DocumentLoader:
    """Generic single-file document loader.

    The loader is selected from the lower-cased file extension: ``.md``,
    ``.pdf`` and ``.txt`` go through the matching LangChain loader, while
    images are described by a vision-language model and wrapped in a single
    ``Document``. API credentials are read from the ``API_KEY`` and
    ``BASE_URL`` environment variables when the loader is constructed.
    """

    # Supported image extensions -> MIME type advertised in the base64 data
    # URL sent to the model, so PNG/GIF/BMP files are not mislabelled as JPEG.
    IMAGE_MIME_TYPES = {
        '.png': 'image/png',
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.gif': 'image/gif',
        '.bmp': 'image/bmp',
    }
    # Seconds to wait for the VLM API. Without a timeout ``requests`` can
    # block forever on a stalled connection; the value is a generous default.
    REQUEST_TIMEOUT = 120
    # Text formats that are retried with GBK when UTF-8 decoding fails.
    TEXT_EXTENSIONS = ('.md', '.txt')

    def __init__(self, file_path: str):
        self.file_path = file_path
        self.extension = os.path.splitext(file_path)[1].lower()
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def process_image(self, image_path: str) -> str:
        """Describe an image using the SiliconFlow VLM model.

        The image is base64-encoded into a data URL and posted to
        ``{BASE_URL}/chat/completions`` together with a prompt asking for a
        detailed description.

        Args:
            image_path: Path of the image file to describe.

        Returns:
            The model's description, or the placeholder "图片处理失败"
            ("image processing failed") on any error (unreadable file, unset
            BASE_URL, network error or timeout, non-200 status, unexpected
            JSON). Errors are printed rather than raised, so one bad image
            does not abort a batch.
        """
        try:
            with open(image_path, 'rb') as image_file:
                base64_image = base64.b64encode(image_file.read()).decode('utf-8')

            # Label the data URL with the real format; unknown extensions keep
            # the JPEG label.
            image_ext = os.path.splitext(image_path)[1].lower()
            mime_type = self.IMAGE_MIME_TYPES.get(image_ext, 'image/jpeg')

            if not self.api_base:
                raise ValueError("环境变量 BASE_URL 未设置")
            # Tolerate a trailing slash in BASE_URL.
            endpoint = f"{self.api_base.rstrip('/')}/chat/completions"

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            response = requests.post(
                endpoint,
                headers=headers,
                json={
                    "model": "Qwen/Qwen2.5-VL-72B-Instruct",
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{base64_image}",
                                        "detail": "high"
                                    }
                                },
                                {
                                    "type": "text",
                                    "text": "请详细描述这张图片的内容,包括主要对象、场景、活动、颜色、布局等关键信息。"
                                }
                            ]
                        }
                    ],
                    "temperature": 0.7,
                    "max_tokens": 500
                },
                timeout=self.REQUEST_TIMEOUT,
            )

            if response.status_code != 200:
                raise Exception(f"图片处理API调用失败: {response.text}")

            return response.json()["choices"][0]["message"]["content"]

        except Exception as e:
            print(f"处理图片时出错: {str(e)}")
            return "图片处理失败"

    def load(self):
        """Load ``self.file_path`` and return a list of LangChain Documents.

        Text formats are read as UTF-8 first. If that fails with a decode
        error, ``.md`` and ``.txt`` files are re-read as GBK with
        ``TextLoader`` (so Markdown loses its structured parsing on that path).

        Raises:
            ValueError: if the file extension is not supported.
            Any other loader error, including a failed GBK retry.
        """
        try:
            return self._load_utf8()
        except Exception as exc:
            # TextLoader wraps UnicodeDecodeError in RuntimeError, so a plain
            # ``except UnicodeDecodeError`` never fires for .txt files; look
            # through the exception chain instead.
            if self.extension in self.TEXT_EXTENSIONS and self._is_decode_error(exc):
                loader = TextLoader(self.file_path, encoding='gbk')
                return loader.load()
            raise

    def _load_utf8(self):
        """Dispatch to the loader for ``self.extension``, reading text as UTF-8."""
        if self.extension == '.md':
            return UnstructuredMarkdownLoader(self.file_path, encoding='utf-8').load()
        if self.extension == '.pdf':
            return PyPDFLoader(self.file_path).load()
        if self.extension == '.txt':
            return TextLoader(self.file_path, encoding='utf-8').load()
        if self.extension in self.IMAGE_MIME_TYPES:
            return [self._load_image()]
        raise ValueError(f"不支持的文件格式: {self.extension}")

    def _load_image(self):
        """Build one Document holding the VLM description of the image."""
        description = self.process_image(self.file_path)
        from langchain.schema import Document
        return Document(
            page_content=description,
            metadata={
                'source': self.file_path,
                # Absolute path of the image file.
                'img_url': os.path.abspath(self.file_path),
            },
        )

    @staticmethod
    def _is_decode_error(exc: BaseException) -> bool:
        """Return True if ``exc`` or any exception it chains from is a UnicodeDecodeError."""
        seen = set()
        current = exc
        while current is not None and id(current) not in seen:
            if isinstance(current, UnicodeDecodeError):
                return True
            seen.add(id(current))
            current = current.__cause__ or current.__context__
        return False

class DocumentProcessor:
    """Load documents from a file or directory and split them into chunks.

    Chunking uses ``RecursiveCharacterTextSplitter`` with 1000-character
    chunks, 200 characters of overlap and ``len`` as the length function.
    """

    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

    def get_index_name(self, path: str) -> str:
        """Derive an index name from a file or directory path.

        Directories use their own name and files use their name without the
        last extension. The result is lower-cased and prefixed with ``rag_``.
        The path is made absolute first, so trailing separators and ``.``
        resolve to the real name (``"docs/"`` -> ``"rag_docs"``, not ``"rag_"``).
        """
        name = os.path.basename(os.path.abspath(path))
        if not os.path.isdir(path):
            name = os.path.splitext(name)[0]
        return f"rag_{name.lower()}"

    def process(self, path: str, progress_callback: Optional[Callable] = None) -> List[Dict]:
        """Load documents from ``path`` and split them into chunks.

        Args:
            path: A single file, or a directory that is walked recursively.
            progress_callback: Optional ``callback(value, message)``.
                - Directory mode: ``value`` is the cumulative size in bytes
                  of the files started so far.
                - Single-file mode: ``value`` is 30%, 60% and 90% of the file
                  size at the load, content and chunking stages.

        Returns:
            One dict per chunk with keys ``id`` (``"doc_<n>"``), ``content``
            and ``metadata`` (loader metadata plus ``file_name``).

        Raises:
            In single-file mode, any load error (printed, then re-raised).
            In directory mode, per-file errors are printed and the file is
            skipped.
        """
        is_dir = os.path.isdir(path)
        if is_dir:
            documents, processed_size = self._load_directory(path, progress_callback)
        else:
            documents = self._load_file(path, progress_callback)

        chunks = self.text_splitter.split_documents(documents)

        if progress_callback:
            progress_value = processed_size if is_dir else os.path.getsize(path) * 0.9
            progress_callback(progress_value, f"文档分块完成,共{len(chunks)}个文档片段")

        # Normalize chunks into plain dicts.
        return [
            {
                'id': f'doc_{i}',
                'content': chunk.page_content,
                'metadata': chunk.metadata,
            }
            for i, chunk in enumerate(chunks)
        ]

    @staticmethod
    def _load_directory(path: str, progress_callback: Optional[Callable]):
        """Load every file under ``path``; return (documents, bytes reported)."""
        # Walk once and reuse the list so the progress total always matches
        # the files actually visited.
        file_paths = [
            os.path.join(root, name)
            for root, _, files in os.walk(path)
            for name in files
        ]
        total_files = len(file_paths)
        documents = []
        processed_files = 0
        processed_size = 0

        for file_path in file_paths:
            file_name = os.path.basename(file_path)
            try:
                if progress_callback:
                    processed_size += os.path.getsize(file_path)
                    processed_files += 1
                    progress_callback(processed_size, f"处理文件 {processed_files}/{total_files}: {file_name}")

                docs = DocumentLoader(file_path).load()
                for doc in docs:
                    doc.metadata['file_name'] = file_name
                documents.extend(docs)
            except Exception as e:
                # Best effort: one unreadable/unsupported file must not abort the batch.
                print(f"警告:加载文件 {file_path} 时出错: {str(e)}")
                continue

        return documents, processed_size

    @staticmethod
    def _load_file(path: str, progress_callback: Optional[Callable]):
        """Load one file and tag each resulting document with its file name."""
        try:
            if progress_callback:
                file_size = os.path.getsize(path)
                progress_callback(file_size * 0.3, f"加载文件: {os.path.basename(path)}")

            documents = DocumentLoader(path).load()

            if progress_callback:
                progress_callback(file_size * 0.6, "处理文件内容...")

            file_name = os.path.basename(path)
            for doc in documents:
                doc.metadata['file_name'] = file_name
        except Exception as e:
            print(f"加载文件时出错: {str(e)}")
            raise
        return documents