GOGO_rag / app.py
GOGO198's picture
Update app.py
75dc93c verified
raw
history blame
5.98 kB
import gradio as gr
import numpy as np
import os
import torch
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
import time
import sys
import json
# 创建安全缓存目录
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# 减少内存占用
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
# torch.set_num_threads(1)
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false" # 防止tokenizer内存泄漏
# 全局变量 - 移除了clip_model和clip_processor
index = None
metadata = None
def load_resources():
    """Download and load the FAISS index and the metadata table.

    Populates the module globals ``index`` and ``metadata``. A resource that
    is already loaded (not ``None``) is left untouched; the lock-file cleanup
    below runs on every call regardless. Both files are fetched with
    ``hf_hub_download`` from the ``GOGO198/GOGO_rag_index`` repo into
    ``CACHE_DIR``, using the ``HF_TOKEN`` environment variable as the token.

    Raises:
        ValueError: If the downloaded index is not 768-dimensional.
        Exception: Any other download or parsing error is logged and
            re-raised unchanged.
    """
    global index, metadata

    # Remove leftover ``*.lock`` files found directly inside CACHE_DIR.
    # Best-effort: a file that cannot be removed is reported and skipped.
    # Only OSError is caught so that KeyboardInterrupt / SystemExit still
    # propagate (the previous bare ``except`` swallowed those too).
    for lock_file in os.listdir(CACHE_DIR):
        if not lock_file.endswith('.lock'):
            continue
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 清理锁文件: {lock_file}")
        except OSError as e:
            print(f"⚠️ 无法清理锁文件 {lock_file}: {e}")

    # Nothing to do when both resources are already in memory.
    if index is not None and metadata is not None:
        return

    print("🔄 正在加载所有资源...")

    # FAISS index (must be 768-dimensional).
    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            index_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="faiss_index.bin",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            loaded_index = faiss.read_index(index_path)
            # Validate before publishing to the global: otherwise a failed
            # check would leave a wrong-dimension index in place and a later
            # call would treat it as successfully loaded.
            if loaded_index.d != 768:
                raise ValueError(f"❌ 索引维度错误:预期768维,实际{loaded_index.d}维")
            index = loaded_index
            print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {e}")
            raise

    # Metadata table (CSV read with pandas).
    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            metadata_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="metadata.csv",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            metadata = pd.read_csv(metadata_path)
            print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
        except Exception as e:
            print(f"❌ 元数据加载失败: {e}")
            raise
from functools import lru_cache
def predict(vector):
    """Return the three nearest index entries for one 768-dimensional vector.

    The ``lru_cache`` decorator that used to wrap this function was removed:
    the function is called with a (nested) list, lists are unhashable, and so
    the cache lookup raised ``TypeError`` on every call before the body ran.

    Args:
        vector: A list containing exactly one row of 768 numbers (shape
            1x768). In this file it is fed from a ``gr.Dataframe`` declared
            with ``type="array"`` and from the start-up warm-up call.

    Returns:
        str: On success, a JSON document ``{"results": [...]}`` with one text
        line per match. On invalid input, a specific error message. On any
        unexpected failure, a generic error message (the traceback is
        printed to the log, not returned).
    """
    start_time = time.time()
    try:
        # Validate the container before touching its contents.
        if not isinstance(vector, list) or len(vector) == 0:
            error_msg = "错误:输入格式无效"
            print(error_msg)
            return error_msg

        row_count = len(vector)
        first_row = vector[0]
        # A non-sequence first element counts as zero columns, so a flat list
        # gets the shape error instead of a TypeError from len().
        col_count = len(first_row) if isinstance(first_row, (list, tuple)) else 0
        if row_count != 1 or col_count != 768:
            error_msg = f"错误:需要1x768的二维数组,收到{row_count}x{col_count}"
            print(error_msg)
            return error_msg

        vector_array = np.array(vector, dtype=np.float32)
        print(f"输入向量维度: {vector_array.shape}")

        distances, indices = index.search(vector_array, k=3)

        results = []
        for rank, (distance, row_id) in enumerate(zip(distances[0], indices[0]), start=1):
            # FAISS pads missing neighbours with label -1; without this guard
            # ``iloc[-1]`` would silently return the last metadata row.
            if row_id < 0:
                results.append(f"结果 {rank}: 数据获取失败")
                continue
            try:
                result = metadata.iloc[row_id]
                # Map distance to (0, 1]: smaller distance, higher confidence.
                confidence = 1 / (1 + distance)
                results.append(f"匹配结果 {rank}: {result['source']} | 置信度: {confidence:.2f}")
            except Exception as e:
                # Best-effort per result: one bad row must not fail the query.
                print(f"结果处理错误: {e}")
                results.append(f"结果 {rank}: 数据获取失败")

        print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
        return json.dumps({
            "results": results
        })
    except Exception as e:
        import traceback
        error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return "处理错误,请重试或联系管理员"
# Minimal UI: one vector table, one text output, one button.
# NOTE(review): the nesting below was reconstructed — confirm that the
# textbox belongs inside the row and the button below it.
with gr.Blocks() as demo:
    gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
    gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")

    with gr.Row():
        vector_input = gr.Dataframe(
            value=[[0.1] * 768],  # default content: a single row of 768 values
            headers=["向量值"],
            type="array",
            label="输入向量 (768维)",
        )
        output = gr.Textbox(lines=5, label="智能回答")

    submit_btn = gr.Button("生成回答")
    # Clicking the button passes the table contents to predict and shows its
    # return value in the textbox.
    submit_btn.click(predict, vector_input, output)
# Script entry point.
if __name__ == "__main__":
    # Load the index and metadata before serving; abort start-up on failure.
    if index is None or metadata is None:
        print("🚀 启动前预加载资源...")
        try:
            load_resources()
        except Exception as e:
            print(f"⛔ 资源加载失败: {e}")
            sys.exit(1)

    # Warm-up query with a dummy vector. Best-effort: a failure must not
    # block start-up, but it is logged instead of being silently discarded
    # (the previous bare ``except: pass`` hid every error raised here).
    try:
        dummy_vector = [0.1] * 768
        predict([dummy_vector])
    except Exception as e:
        print(f"⚠️ 预热请求失败: {e}")

    print("=" * 50)
    print("Space启动完成 | 准备接收请求")
    print(f"索引维度: {index.d if index else '未加载'}")
    print(f"元数据记录: {len(metadata) if metadata is not None else 0}")
    print("=" * 50)

    # Listen on all interfaces on port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )