File size: 5,977 Bytes
7756dfe
868518d
1b59f0f
49bc88a
 
cfb951e
a9e075d
49bc88a
8187a7a
75dc93c
6183cb0
a9e075d
6183cb0
 
1b59f0f
49bc88a
015c037
 
252fd1e
 
49bc88a
5a49e7f
49bc88a
 
 
 
a9e075d
5a49e7f
 
 
 
 
 
7260359
5a49e7f
 
a9e075d
5a49e7f
52111a1
5a49e7f
a9e075d
5a49e7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51359d6
8187a7a
 
 
 
51359d6
5a49e7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49bc88a
015c037
d99c445
868518d
a9e075d
931e174
40ef0e2
931e174
49bc88a
931e174
8187a7a
931e174
 
 
e2a605c
 
d99c445
 
 
49bc88a
1ce6f7d
931e174
 
 
 
 
 
 
 
 
 
 
 
 
 
75dc93c
 
 
a9e075d
49bc88a
931e174
 
a7ea7d8
931e174
868518d
49bc88a
080637e
a9e075d
 
dbbaf71
080637e
 
49bc88a
080637e
a9e075d
 
080637e
a9e075d
080637e
 
 
 
 
 
 
868518d
dbbaf71
49bc88a
a9e075d
5a49e7f
 
 
 
 
 
8187a7a
5a49e7f
de12dbe
 
 
 
 
 
 
 
 
 
a7ea7d8
 
 
 
 
de12dbe
49bc88a
 
a9e075d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import gradio as gr
import numpy as np
import os
import torch
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
import time
import sys
import json

# 创建安全缓存目录
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# 减少内存占用
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
# torch.set_num_threads(1)
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # 防止tokenizer内存泄漏

# 全局变量 - 移除了clip_model和clip_processor
index = None
metadata = None

def load_resources():
    """加载所有必要资源(768维专用)"""
    # 清理残留锁文件
    lock_files = [f for f in os.listdir(CACHE_DIR) if f.endswith('.lock')]
    for lock_file in lock_files:
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 清理锁文件: {lock_file}")
        except: pass
            
    global index, metadata
    
    # 仅当资源未加载时才初始化
    if index is None or metadata is None:
        print("🔄 正在加载所有资源...")
        
        # 加载FAISS索引(768维)
        if index is None:
            print("📥 正在下载FAISS索引...")
            try:
                INDEX_PATH = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="faiss_index.bin",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN")
                )
                index = faiss.read_index(INDEX_PATH)
                
                # 验证索引维度
                if index.d != 768:
                    raise ValueError(f"❌ 索引维度错误:预期768维,实际{index.d}维")

                # if index and not index.is_trained:
                #     print("🔧 训练量化索引...")
                #     index.train(np.random.rand(10000, 768).astype('float32'))
                #     print("✅ 索引训练完成")
                    
                print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
            except Exception as e:
                print(f"❌ FAISS索引加载失败: {str(e)}")
                raise
        
        # 加载元数据
        if metadata is None:
            print("📄 正在下载元数据...")
            try:
                METADATA_PATH = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="metadata.csv",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN")
                )
                metadata = pd.read_csv(METADATA_PATH)
                print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
            except Exception as e:
                print(f"❌ 元数据加载失败: {str(e)}")
                raise

from functools import lru_cache
@lru_cache(maxsize=100)
def predict(vector):
    """处理768维向量输入并返回答案"""
    start_time = time.time()
    print(f"输入向量维度: {np.array(vector).shape}")
    
    try:
        # 验证输入格式
        if not isinstance(vector, list) or len(vector) == 0:
            error_msg = "错误:输入格式无效"
            print(error_msg)
            return error_msg

        if len(vector) != 1 or len(vector[0]) != 768:
            error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
            print(error_msg)
            return error_msg
        
        # 添加实际处理逻辑
        vector_array = np.array(vector, dtype=np.float32)
        D, I = index.search(vector_array, k=3)
        
        results = []
        for i in range(3):
            try:
                result = metadata.iloc[I[0][i]]
                confidence = 1/(1+D[0][i])
                results.append(f"匹配结果 {i+1}: {result['source']} | 置信度: {confidence:.2f}")
            except Exception as e:
                print(f"结果处理错误: {str(e)}")
                results.append(f"结果 {i+1}: 数据获取失败")
        
        print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
        return json.dumps({
            "results": results  # 确保嵌套结构合法
        })
        
    except Exception as e:
        import traceback
        error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return "处理错误,请重试或联系管理员"

# 创建简化接口
with gr.Blocks() as demo:
    gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
    gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
    
    with gr.Row():
        vector_input = gr.Dataframe(
            headers=["向量值"], 
            type="array",
            label="输入向量 (768维)",
            value=[[0.1]*768]  # 768维默认值
        )
        output = gr.Textbox(label="智能回答", lines=5)
    
    submit_btn = gr.Button("生成回答")
    submit_btn.click(
        fn=predict, 
        inputs=vector_input, 
        outputs=output
    )

# 启动应用
if __name__ == "__main__":
    # 预加载资源
    if index is None or metadata is None:
        print("🚀 启动前预加载资源...")
        try:
            load_resources()
        except Exception as e:
            print(f"⛔ 资源加载失败: {str(e)}")
            sys.exit(1)

    # 确保缓存目录存在
    # import pathlib
    # pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

    try:
        dummy_vector = [0.1] * 768
        predict([dummy_vector])
    except:
        pass

    print("="*50)
    print("Space启动完成 | 准备接收请求")
    print(f"索引维度: {index.d if index else '未加载'}")
    print(f"元数据记录: {len(metadata) if metadata is not None else 0}")
    print("="*50)
    
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )