File size: 5,880 Bytes
7756dfe 868518d 1b59f0f 49bc88a cfb951e a9e075d 49bc88a 75dc93c 223da69 6183cb0 a9e075d 6183cb0 1b59f0f 49bc88a 252fd1e e931572 49bc88a 9e5ab6c 49bc88a a9e075d e931572 5a49e7f ffcb61e e931572 52111a1 5a49e7f e931572 5a49e7f e931572 5a49e7f e931572 5a49e7f 49bc88a c3d3c82 868518d a9e075d 931e174 dbdb4ab e931572 49bc88a 2e36b27 dbdb4ab c3d3c82 dbdb4ab 931e174 2e36b27 e2a605c d99c445 2e36b27 e931572 ffcb61e 931e174 e931572 931e174 2e36b27 931e174 2e36b27 e931572 931e174 dbdb4ab 2e36b27 a9e075d 49bc88a 931e174 a7ea7d8 2e36b27 223da69 2e36b27 868518d c3d3c82 49bc88a c3d3c82 a7ea7d8 c3d3c82 a7ea7d8 c3d3c82 223da69 |
|
import gradio as gr
import numpy as np
import os
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
import time
import json
import fastapi
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import threading
# Writable cache directory used for Hugging Face Hub downloads.
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Limit CPU threads to keep resource usage down.
# NOTE: OpenMP/BLAS runtimes read OMP_NUM_THREADS when their shared library is
# loaded; numpy and faiss are imported above, so the env var alone no longer
# affects them. It is still set for anything initialised later, and faiss's
# OpenMP thread count is set explicitly through its own API.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
faiss.omp_set_num_threads(2)

# Populated by load_resources().
index = None  # FAISS index; load_resources() requires dimension 768
metadata = None  # pandas DataFrame read from metadata.csv
def _download_index_file(filename):
    """Download *filename* from the GOGO198/GOGO_rag_index Hub repo into CACHE_DIR.

    Passes the HF_TOKEN environment variable as the token (None when unset) and
    returns the local path reported by ``hf_hub_download``.
    """
    return hf_hub_download(
        repo_id="GOGO198/GOGO_rag_index",
        filename=filename,
        cache_dir=CACHE_DIR,
        token=os.getenv("HF_TOKEN"),
    )


def _clear_stale_locks():
    """Best-effort removal of ``*.lock`` files located directly in CACHE_DIR."""
    for name in os.listdir(CACHE_DIR):
        if not name.endswith('.lock'):
            continue
        try:
            os.remove(os.path.join(CACHE_DIR, name))
            print(f"🧹 清理锁文件: {name}")
        except OSError as e:
            # Not fatal: report it and continue with the remaining files.
            print(f"⚠️ 无法清理锁文件 {name}: {e}")


def load_resources():
    """Load the FAISS index and the metadata table into module globals (768-dim only).

    Leftover ``*.lock`` files in CACHE_DIR are removed first. Each resource is
    downloaded and loaded only while its global is still ``None``. Repeated
    calls therefore do no downloading once both are loaded.

    Raises:
        ValueError: if the downloaded index is not 768-dimensional.
        Exception: any download/parse error is logged and re-raised unchanged.
    """
    global index, metadata

    _clear_stale_locks()

    if index is not None and metadata is not None:
        return

    print("🔄 正在加载所有资源...")

    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            loaded_index = faiss.read_index(_download_index_file("faiss_index.bin"))
            if loaded_index.d != 768:
                raise ValueError(f"❌ 索引维度错误:预期768维,实际{loaded_index.d}维")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {str(e)}")
            raise
        # Publish only after validation. Otherwise a wrong-dimension index would
        # stay in the global and later calls would skip reloading it.
        index = loaded_index
        print(f"✅ FAISS索引加载完成 | 维度: {index.d}")

    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            loaded_metadata = pd.read_csv(_download_index_file("metadata.csv"))
        except Exception as e:
            print(f"❌ 元数据加载失败: {str(e)}")
            raise
        metadata = loaded_metadata
        print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
# Load the FAISS index and metadata eagerly at import time so they are ready
# before the API handles any request. load_resources() re-raises on failure,
# so a load error aborts startup.
load_resources()
def _input_error(message):
    """Log an input-validation failure and build predict()'s error payload.

    The ``"error"`` key is kept for existing callers. ``"status"`` and
    ``"message"`` mirror the shape of predict()'s other error response.
    """
    print(message)
    return {"status": "error", "error": message, "message": message}


def _is_number(x):
    """Return True for int/float values, excluding bool (JSON whole numbers decode to int)."""
    return isinstance(x, (int, float)) and not isinstance(x, bool)


def predict(vector, top_k=3):
    """Search the FAISS index with one 768-dim query vector and return the top matches.

    Args:
        vector: a flat list of 768 numbers, or a nested ``[[...]]`` list holding
            exactly one row of 768 numbers.
        top_k: number of neighbours to request from the index (default 3).

    Returns:
        On success, ``{"status": "success", "results": [...]}``. Each result is
        ``{"source": ..., "confidence": round(1 / (1 + distance), 2)}``, or
        ``{"error": ...}`` when that row's metadata lookup fails. Neighbour ids
        of -1 (FAISS found no neighbour for that slot) are skipped, so fewer than
        ``top_k`` results may be returned.
        On invalid input, ``{"status": "error", "error": msg, "message": msg}``.
        On any other failure, ``{"status": "error", "message": <generic text>}``;
        the traceback is printed.
    """
    start_time = time.time()
    print(f"输入向量类型: {type(vector)}, 长度: {len(vector) if hasattr(vector, '__len__') else 'N/A'}")
    try:
        if isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], list):
            # Already in the 2-D [[...]] form.
            pass
        elif isinstance(vector, list) and all(_is_number(x) for x in vector):
            # Flat list (typical API input): wrap it as a single-row batch.
            vector = [vector]
        else:
            return _input_error("错误:输入格式无效")

        if len(vector) != 1 or len(vector[0]) != 768:
            return _input_error(f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}")

        vector_array = np.array(vector, dtype=np.float32)
        # NaN/Inf (including values that overflow float32) make distances meaningless.
        if not np.isfinite(vector_array).all():
            return _input_error("错误:向量包含非有限数值(NaN或Inf)")

        distances, ids = index.search(vector_array, k=top_k)

        results = []
        for rank, (distance, row_id) in enumerate(zip(distances[0], ids[0]), start=1):
            # FAISS pads with id -1 when it has no neighbour for a slot.
            # metadata.iloc[-1] would then silently return the LAST row.
            if row_id < 0:
                continue
            try:
                row = metadata.iloc[int(row_id)]
                results.append({
                    "source": row['source'],
                    "confidence": round(float(1 / (1 + distance)), 2),
                })
            except (IndexError, KeyError) as e:
                # Id outside the metadata table, or no 'source' column.
                print(f"结果处理错误: {str(e)}")
                results.append({"error": f"结果 {rank}: 数据获取失败"})

        print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
        return {
            "status": "success",
            "results": results,
        }
    except Exception as e:
        import traceback
        error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return {
            "status": "error",
            "message": "处理错误,请重试或联系管理员"
        }
# FastAPI application serving the /predict endpoint.
app = FastAPI()

# CORS: allow browser calls from any origin.
# allow_credentials is False because credentialed requests cannot be combined
# with a wildcard origin; the FastAPI CORS docs forbid ["*"] together with
# allow_credentials=True. The /predict handler reads only the JSON body, so no
# cookies or auth headers are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/predict")
async def api_predict(request: Request):
    """Handle POST /predict: validate the JSON body and delegate to predict().

    Expected body: a JSON object containing a non-empty ``"vector"`` list.

    Status codes:
        400: the body is not valid JSON, is not an object, or lacks a non-empty
            ``"vector"`` list. Also returned when predict() rejects the input
            (its payload then carries an ``"error"`` key).
        500: predict() reports ``"status": "error"``, or an unexpected exception
            occurs. Details are printed server-side and not sent to the client.
        200: predict()'s successful result.
    """
    import traceback

    try:
        try:
            data = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 请求体必须是JSON"},
            )

        # A JSON array or scalar body has no .get(); treat it as a missing vector.
        vector = data.get("vector") if isinstance(data, dict) else None
        if not vector or not isinstance(vector, list):
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 需要向量数据"}
            )

        # NOTE(review): predict() is synchronous and runs on the event-loop thread.
        result = predict(vector)
        if "error" in result:
            status_code = 400
        elif result.get("status") == "error":
            status_code = 500
        else:
            status_code = 200
        return JSONResponse(status_code=status_code, content=result)
    except Exception:
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": "服务器内部错误"
            }
        )
# Script entry point: print a startup summary, then serve the FastAPI app.
if __name__ == "__main__":
    # Report the resources already held in memory (loaded at import time).
    print("="*50)
    print("Space启动完成 | 准备接收请求")
    print(f"索引维度: {index.d}")
    print(f"元数据记录数: {len(metadata)}")
    print("="*50)
    # Only the FastAPI app is served. PORT may override the default 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))