File size: 5,880 Bytes
7756dfe 868518d 1b59f0f 49bc88a cfb951e a9e075d 49bc88a 75dc93c 223da69 6183cb0 a9e075d 6183cb0 1b59f0f 49bc88a 252fd1e e931572 49bc88a 9e5ab6c 49bc88a a9e075d e931572 5a49e7f ffcb61e e931572 52111a1 5a49e7f e931572 5a49e7f e931572 5a49e7f e931572 5a49e7f 49bc88a c3d3c82 868518d a9e075d 931e174 dbdb4ab e931572 49bc88a 2e36b27 dbdb4ab c3d3c82 dbdb4ab 931e174 2e36b27 e2a605c d99c445 2e36b27 e931572 ffcb61e 931e174 e931572 931e174 2e36b27 931e174 2e36b27 e931572 931e174 dbdb4ab 2e36b27 a9e075d 49bc88a 931e174 a7ea7d8 2e36b27 223da69 2e36b27 868518d c3d3c82 49bc88a c3d3c82 a7ea7d8 c3d3c82 a7ea7d8 c3d3c82 223da69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import gradio as gr
import numpy as np
import os
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
import time
import json
import fastapi
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import threading
# Cache directory for Hugging Face Hub downloads; created up front so later
# os.listdir / hf_hub_download calls can rely on it existing.
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Reduce thread (and therefore memory) usage.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# NOTE: faiss and numpy are imported above, before these env vars are set, so
# their OpenMP runtime has most likely already read its environment. Set the
# FAISS thread count explicitly so the limit actually applies to index search.
faiss.omp_set_num_threads(2)

# Global resources, populated by load_resources().
index = None
metadata = None
def load_resources():
    """Load the 768-dim FAISS index and the metadata table into module globals.

    Downloads ``faiss_index.bin`` and ``metadata.csv`` from the
    ``GOGO198/GOGO_rag_index`` Hub repo into ``CACHE_DIR`` (passing the
    ``HF_TOKEN`` environment variable as the token). It assigns them to the
    globals ``index`` and ``metadata``. Anything already loaded is kept as-is.
    Leftover ``*.lock`` files directly inside ``CACHE_DIR`` are removed first.

    Raises:
        ValueError: if the downloaded index is not 768-dimensional.
        Exception: any download or parsing error is logged and re-raised.
    """
    global index, metadata
    repo_id = "GOGO198/GOGO_rag_index"

    # Remove leftover lock files (presumably from an interrupted earlier run).
    for lock_file in [f for f in os.listdir(CACHE_DIR) if f.endswith('.lock')]:
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 清理锁文件: {lock_file}")
        except OSError as e:
            # Best-effort cleanup: an undeletable lock must not abort startup,
            # but the failure should be visible rather than silently ignored.
            print(f"⚠️ 无法清理锁文件 {lock_file}: {e}")

    if index is not None and metadata is not None:
        return
    print("🔄 正在加载所有资源...")

    # Load the FAISS index (must be 768-dimensional).
    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            index_path = hf_hub_download(
                repo_id=repo_id,
                filename="faiss_index.bin",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            loaded_index = faiss.read_index(index_path)
            # Validate before publishing to the global: assigning first would
            # leave a wrong-dimension index cached, so a retry would skip it.
            if loaded_index.d != 768:
                raise ValueError("❌ 索引维度错误:预期768维")
            index = loaded_index
            print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {str(e)}")
            raise

    # Load the metadata table.
    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            metadata_path = hf_hub_download(
                repo_id=repo_id,
                filename="metadata.csv",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            metadata = pd.read_csv(metadata_path)
            print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
        except Exception as e:
            print(f"❌ 元数据加载失败: {str(e)}")
            raise
# Load resources eagerly at import time so they are ready before any API call.
# load_resources() re-raises on failure, so a failed load aborts startup here.
load_resources()
def _is_number(value):
    """Return True for int/float scalars, excluding bool (a subclass of int)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def predict(vector, top_k=3):
    """Search the FAISS index with one 768-dim query vector.

    Args:
        vector: Either a flat list of numbers (the JSON API form) or a nested
            list holding exactly one row of 768 values.
        top_k: Number of nearest neighbours to request from the index
            (defaults to 3, the previously hard-coded value).

    Returns:
        dict: ``{"status": "success", "results": [...]}``. Each result is
        either ``{"source": ..., "confidence": ...}`` or ``{"error": ...}``.
        Malformed input returns ``{"status": "error", "error": ...}``.
        Unexpected processing failures return
        ``{"status": "error", "message": ...}``.
    """
    start_time = time.time()
    print(f"输入向量类型: {type(vector)}, 长度: {len(vector) if hasattr(vector, '__len__') else 'N/A'}")
    try:
        # Normalise the accepted input shapes to a 2-D list.
        if isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], list):
            pass  # already a 2-D list
        elif isinstance(vector, list) and all(_is_number(x) for x in vector):
            # Flat list from the API. JSON integers such as 0 decode to int,
            # so both int and float elements must be accepted.
            vector = [vector]
        else:
            error_msg = "错误:输入格式无效"
            print(error_msg)
            return {"status": "error", "error": error_msg}

        if len(vector) != 1 or len(vector[0]) != 768:
            error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
            print(error_msg)
            return {"status": "error", "error": error_msg}

        vector_array = np.array(vector, dtype=np.float32)
        distances, ids = index.search(vector_array, k=top_k)

        results = []
        for rank, (dist, idx) in enumerate(zip(distances[0], ids[0]), start=1):
            # FAISS pads with id -1 when fewer than k neighbours exist.
            # metadata.iloc[-1] would then silently return the last row.
            if idx < 0:
                results.append({"error": f"结果 {rank}: 数据获取失败"})
                continue
            try:
                row = metadata.iloc[idx]
                # Map the distance to (0, 1]: smaller distance -> higher score.
                confidence = 1 / (1 + dist)
                results.append({
                    "source": row['source'],
                    "confidence": round(float(confidence), 2),
                })
            except Exception as e:
                # Per-result best effort: one bad row must not fail the request.
                print(f"结果处理错误: {str(e)}")
                results.append({"error": f"结果 {rank}: 数据获取失败"})

        print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
        return {
            "status": "success",
            "results": results,
        }
    except Exception as e:
        import traceback
        # Log full details server-side; return only a generic message.
        error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return {
            "status": "error",
            "message": "处理错误,请重试或联系管理员",
        }
# Create the FastAPI application.
app = FastAPI()

# Enable CORS so browser clients on any origin can call the API.
# allow_credentials is False: the FastAPI/Starlette docs state that credentials
# cannot be combined with wildcard origins/methods/headers. With "*" origins
# and credentials enabled, Starlette may reflect the caller's Origin, letting
# any site make credentialed requests. The /predict endpoint below reads only
# the JSON body, so no cookies or auth headers are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/predict")
async def api_predict(request: Request):
    """POST /predict: run ``predict`` on the ``vector`` field of a JSON body.

    Expects a JSON object such as ``{"vector": [...]}``. Returns HTTP 400
    when the body is not valid JSON, is not a JSON object, or lacks a
    non-empty list ``vector``. Otherwise returns ``predict``'s dict with
    HTTP 200, including predict's own input-validation errors. Any other
    unexpected failure returns HTTP 500.
    """
    try:
        try:
            data = await request.json()
        except ValueError:
            # Malformed JSON (JSONDecodeError/UnicodeDecodeError are both
            # ValueError subclasses) is a client error, not a server fault.
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 请求体不是合法的JSON"}
            )

        # A valid JSON body that is not an object (e.g. a bare list) has no
        # .get(); treat it the same as a missing vector.
        vector = data.get("vector") if isinstance(data, dict) else None
        if not vector or not isinstance(vector, list):
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 需要向量数据"}
            )

        result = predict(vector)
        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": f"服务器内部错误: {str(e)}"
            }
        )
# Entry point: log a resource summary, then serve the FastAPI app.
if __name__ == "__main__":
    # Resources were loaded at import time; report what was loaded.
    print("=" * 50)
    print("Space启动完成 | 准备接收请求")
    print(f"索引维度: {index.d}")
    print(f"元数据记录数: {len(metadata)}")
    print("=" * 50)
    # Serve only the FastAPI app.
    uvicorn.run(app, host="0.0.0.0", port=7860)