File size: 5,880 Bytes
7756dfe
868518d
1b59f0f
49bc88a
cfb951e
a9e075d
49bc88a
75dc93c
223da69
 
 
 
 
 
6183cb0
a9e075d
6183cb0
 
1b59f0f
49bc88a
252fd1e
e931572
49bc88a
9e5ab6c
49bc88a
 
 
 
a9e075d
e931572
 
5a49e7f
 
 
 
 
 
ffcb61e
 
e931572
52111a1
5a49e7f
e931572
5a49e7f
 
 
 
 
 
 
 
 
 
 
 
 
e931572
 
5a49e7f
 
 
 
e931572
5a49e7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49bc88a
c3d3c82
 
 
868518d
a9e075d
931e174
dbdb4ab
e931572
49bc88a
2e36b27
dbdb4ab
 
 
 
c3d3c82
 
dbdb4ab
 
931e174
2e36b27
e2a605c
 
d99c445
 
2e36b27
e931572
ffcb61e
931e174
 
e931572
931e174
 
 
 
 
2e36b27
 
 
 
931e174
 
2e36b27
e931572
931e174
dbdb4ab
2e36b27
 
 
 
 
a9e075d
49bc88a
931e174
 
a7ea7d8
2e36b27
 
 
 
 
223da69
2e36b27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
868518d
c3d3c82
49bc88a
c3d3c82
a7ea7d8
 
c3d3c82
 
a7ea7d8
c3d3c82
223da69
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import gradio as gr
import numpy as np
import os
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
import time
import json
import fastapi
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import threading

# Writable cache directory for Hugging Face Hub downloads.
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Reduce CPU/memory footprint.
# NOTE: numpy and faiss are imported above, so their OpenMP/BLAS runtimes may
# already have read OMP_NUM_THREADS by now; setting the env var here can be
# too late to take effect. Apply the limit to faiss explicitly as well so
# index searches actually honour it.
_OMP_THREADS = 2
os.environ["OMP_NUM_THREADS"] = str(_OMP_THREADS)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
faiss.omp_set_num_threads(_OMP_THREADS)

# Module-level state populated by load_resources():
index = None  # FAISS index; load_resources() requires it to be 768-dimensional
metadata = None  # pandas DataFrame read from metadata.csv

def _download_from_hub(filename, repo_id="GOGO198/GOGO_rag_index"):
    """Download *filename* from *repo_id* into CACHE_DIR and return the local path.

    The ``HF_TOKEN`` environment variable, if set, is passed as the Hub token.
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        cache_dir=CACHE_DIR,
        token=os.getenv("HF_TOKEN")
    )


def _clear_stale_locks():
    """Best-effort removal of leftover ``*.lock`` entries directly inside CACHE_DIR.

    NOTE(review): only the top level of CACHE_DIR is scanned; newer
    huggingface_hub versions may keep their locks in a ``.locks/``
    subdirectory, which this does not reach — confirm intent.
    """
    for lock_file in os.listdir(CACHE_DIR):
        if not lock_file.endswith('.lock'):
            continue
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 清理锁文件: {lock_file}")
        except OSError as e:
            # A lock we cannot delete must not block startup, but say so
            # instead of silently ignoring it.
            print(f"⚠️ 无法清理锁文件 {lock_file}: {e}")


def load_resources():
    """Load the FAISS index and the metadata table into module globals (768-dim only).

    Stale lock files in CACHE_DIR are cleared first. Then whichever of
    ``index`` / ``metadata`` is still None is downloaded from the Hub and
    loaded; anything already loaded is kept, so repeated calls are cheap.

    Raises:
        ValueError: if the downloaded index is not 768-dimensional.
        Exception: any download or parse error is printed and re-raised.
    """
    global index, metadata

    _clear_stale_locks()

    if index is not None and metadata is not None:
        return

    print("🔄 正在加载所有资源...")

    # FAISS index (must be 768-dimensional)
    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            loaded_index = faiss.read_index(_download_from_hub("faiss_index.bin"))
            if loaded_index.d != 768:
                raise ValueError("❌ 索引维度错误:预期768维")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {str(e)}")
            raise
        # Publish only after validation: assigning the global first would
        # cache a wrong-dimension index and make later calls skip reloading.
        index = loaded_index
        print(f"✅ FAISS索引加载完成 | 维度: {index.d}")

    # Metadata rows, looked up positionally by FAISS result id in predict()
    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            metadata = pd.read_csv(_download_from_hub("metadata.csv"))
        except Exception as e:
            print(f"❌ 元数据加载失败: {str(e)}")
            raise
        print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")


# Load everything eagerly at import time so resources are ready before any
# API call; load_resources() re-raises on failure, which aborts startup.
load_resources()

def _is_number(value):
    """Return True for int/float values; bool is rejected even though it subclasses int."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _input_error(message):
    """Print *message* and build the error payload for a rejected query vector."""
    print(message)
    # "error" is the legacy key; "status"/"message" match the other error payloads.
    return {"status": "error", "message": message, "error": message}


def predict(vector, k=3):
    """Search the FAISS index with one 768-dim query vector and return the top-k sources.

    Args:
        vector: a flat list of 768 numbers (as sent by the API) or a nested
            list holding exactly one such row.
        k: number of nearest neighbours to request from the index (default 3).

    Returns:
        On success ``{"status": "success", "results": [...]}``. Each result is
        ``{"source": ..., "confidence": float}``, where confidence is
        ``1 / (1 + distance)`` rounded to 2 places. A result becomes
        ``{"error": ...}`` if its metadata row could not be read. Neighbour
        slots that FAISS leaves empty (id -1) are omitted.
        On invalid input ``{"status": "error", "message": msg, "error": msg}``.
        On unexpected failures ``{"status": "error", "message": ...}``.
    """
    start_time = time.time()
    print(f"输入向量类型: {type(vector)}, 长度: {len(vector) if hasattr(vector, '__len__') else 'N/A'}")

    try:
        # Normalise the accepted input shapes into a single-row matrix.
        if isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], list):
            # Already nested (2-D); the shape check below enforces 1x768.
            pass
        elif isinstance(vector, list) and all(_is_number(x) for x in vector):
            # Flat list from the API. JSON integers such as 0 are valid
            # components and must not be rejected.
            vector = [vector]
        else:
            return _input_error("错误:输入格式无效")

        if len(vector) != 1 or len(vector[0]) != 768:
            return _input_error(
                f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
            )

        query = np.array(vector, dtype=np.float32)
        distances, ids = index.search(query, k=k)

        results = []
        for rank, (dist, row_id) in enumerate(zip(distances[0], ids[0]), start=1):
            if row_id < 0:
                # FAISS pads with id -1 when fewer than k neighbours exist;
                # iloc[-1] would otherwise silently return the last metadata row.
                continue
            try:
                row = metadata.iloc[row_id]
                confidence = 1 / (1 + dist)
                results.append({
                    "source": row['source'],
                    "confidence": round(float(confidence), 2)
                })
            except Exception as e:
                # Best effort per neighbour: report the failed slot, keep the rest.
                print(f"结果处理错误: {str(e)}")
                results.append({"error": f"结果 {rank}: 数据获取失败"})

        print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")

        return {
            "status": "success",
            "results": results
        }

    except Exception as e:
        import traceback
        # Full details go to the server log only; the client gets a generic message.
        error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return {
            "status": "error",
            "message": "处理错误,请重试或联系管理员"
        }

# FastAPI application that serves the vector-search endpoint.
app = FastAPI()

# Cross-origin access: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed requests. /predict reads only the JSON body (no
# cookies/auth), so allow_credentials=False may suffice — confirm no client
# relies on it before changing.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

async def _read_vector(request):
    """Return the "vector" field of the request's JSON-object body, or None if unavailable."""
    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON (json.JSONDecodeError / UnicodeDecodeError) is a
        # client error and must map to 400, not 500.
        return None
    if not isinstance(data, dict):
        # e.g. a bare JSON list or number: there is no "vector" field to read.
        return None
    return data.get("vector")


@app.post("/predict")
async def api_predict(request: Request):
    """POST /predict: run ``predict`` on the ``"vector"`` field of a JSON object body.

    Returns HTTP 400 when the body is not valid JSON, is not an object, or has
    no non-empty ``"vector"`` list. Otherwise returns predict()'s payload with
    HTTP 200, including predict's own error payloads. Any other failure yields
    HTTP 500.
    """
    # Local import keeps the module header unchanged.
    from fastapi.concurrency import run_in_threadpool

    try:
        vector = await _read_vector(request)

        if not vector or not isinstance(vector, list):
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 需要向量数据"}
            )

        # predict() does blocking FAISS/pandas work; run it in a worker thread
        # so the event loop keeps serving other requests meanwhile.
        result = await run_in_threadpool(predict, vector)
        return JSONResponse(content=result)

    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": f"服务器内部错误: {str(e)}"
            }
        )

# Script entry point: report what was loaded, then start the HTTP server.
if __name__ == "__main__":
    # Startup banner confirming the index and metadata are in memory.
    _banner = "=" * 50
    print(_banner)
    print("Space启动完成 | 准备接收请求")
    print(f"索引维度: {index.d}")
    print(f"元数据记录数: {len(metadata)}")
    print(_banner)

    # Serve only the FastAPI app, on all interfaces at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")