File size: 4,801 Bytes
7756dfe 868518d 1b59f0f 49bc88a cfb951e a9e075d 49bc88a 6183cb0 a9e075d 6183cb0 1b59f0f 49bc88a 015c037 49bc88a 5a49e7f 49bc88a a9e075d 5a49e7f 7260359 5a49e7f a9e075d 5a49e7f 52111a1 5a49e7f a9e075d 5a49e7f 49bc88a 015c037 868518d a9e075d 40ef0e2 49bc88a 1ce6f7d 49bc88a 1ce6f7d a9e075d 49bc88a a7ea7d8 868518d 49bc88a 080637e a9e075d dbbaf71 080637e 49bc88a 080637e a9e075d 080637e a9e075d 080637e 868518d dbbaf71 49bc88a a9e075d 5a49e7f 49bc88a de12dbe a7ea7d8 de12dbe 49bc88a a9e075d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import gradio as gr
import numpy as np
import os
import torch
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
import time
# 创建安全缓存目录
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# 减少内存占用
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
# torch.set_num_threads(1)
os.environ["OMP_NUM_THREADS"] = "8"
# 全局变量 - 移除了clip_model和clip_processor
index = None
metadata = None
def _clear_stale_locks(cache_dir):
    """Best-effort removal of ``*.lock`` files located directly in *cache_dir*."""
    stale = [name for name in os.listdir(cache_dir) if name.endswith('.lock')]
    for lock_file in stale:
        try:
            os.remove(os.path.join(cache_dir, lock_file))
        except OSError as e:
            # Cleanup is best-effort: report the failure and keep going.
            print(f"⚠️ 无法清理锁文件 {lock_file}: {e}")
        else:
            print(f"🧹 清理锁文件: {lock_file}")


def load_resources():
    """Load the FAISS index and the metadata table into the module globals.

    Leftover lock files in ``CACHE_DIR`` are removed first. Each resource is
    then downloaded from the ``GOGO198/GOGO_rag_index`` repository only if
    its global (``index`` / ``metadata``) is still ``None``, so calling this
    function again after a successful load does no further downloads.

    Raises:
        ValueError: if the downloaded FAISS index is not 768-dimensional.
        Exception: any error raised while downloading or reading a resource
            is logged and re-raised unchanged.
    """
    global index, metadata

    expected_dim = 768

    _clear_stale_locks(CACHE_DIR)

    # Nothing to do once both resources are in place.
    if index is not None and metadata is not None:
        return

    print("🔄 正在加载所有资源...")

    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            index_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="faiss_index.bin",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            candidate = faiss.read_index(index_path)
            # Validate BEFORE publishing to the global: otherwise a
            # wrong-dimension index would stay installed after the error and
            # later calls would skip reloading it.
            if candidate.d != expected_dim:
                raise ValueError(
                    f"❌ 索引维度错误:预期768维,实际{candidate.d}维"
                )
            index = candidate
            print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {str(e)}")
            raise

    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            metadata_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="metadata.csv",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            metadata = pd.read_csv(metadata_path)
            print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
        except Exception as e:
            print(f"❌ 元数据加载失败: {str(e)}")
            raise
from functools import lru_cache
import traceback

_VECTOR_DIM = 768  # dimensionality required of the query vector
_TOP_K = 3         # number of neighbours requested from the index


@lru_cache(maxsize=1000)
def _search_cached(query):
    """Search the global index for *query* (a tuple of floats) and format hits.

    The argument is a tuple so that it is hashable and usable as an
    ``lru_cache`` key. Exceptions propagate to the caller and are not cached.
    """
    distances, labels = index.search(
        np.array([query], dtype=np.float32), k=_TOP_K
    )
    lines = []
    for dist, row in zip(distances[0], labels[0]):
        # FAISS pads missing neighbours with label -1; ``iloc[-1]`` would
        # silently return the LAST metadata row, so skip those slots.
        if row < 0:
            continue
        lines.append(
            f"匹配结果 {len(lines)+1}: {metadata.iloc[row]['source']} | 置信度: {1/(1+dist):.2f}"
        )
    return "\n".join(lines)


def predict(vector):
    """Answer a query given as a one-row table holding a 768-dim vector.

    Args:
        vector: sequence of rows; only the first row is used as the query.
            (The UI passes a list of lists, which is unhashable, so caching
            is done on an internal tuple copy rather than on this function.)

    Returns:
        One line per match (up to three) with the metadata ``source`` value
        and a score computed as ``1 / (1 + distance)``; a fixed message when
        the first row is not 768 values long; or, on any other failure, an
        error string that includes the traceback.
    """
    try:
        print(f"输入向量维度: {np.array(vector).shape}")
        # Reject incomplete input early.
        if len(vector[0]) != _VECTOR_DIM:
            return "维度错误:需要768维向量"
        if index is None or metadata is None:
            raise RuntimeError("索引或元数据尚未加载")
        query = tuple(float(v) for v in vector[0])
        return _search_cached(query)
    except Exception as e:
        # UI boundary: never raise, always return a detailed error message.
        error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
# Minimal UI: a table holding the query vector, a text box for the answer,
# and a button that sends the table contents to predict().
with gr.Blocks() as demo:
    gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
    gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")

    with gr.Row():
        vector_input = gr.Dataframe(
            value=[[0.1] * 768],  # default: one row of 768 values
            label="输入向量 (768维)",
            type="array",
            headers=["向量值"],
        )
        output = gr.Textbox(lines=5, label="智能回答")

    submit_btn = gr.Button("生成回答")
    submit_btn.click(predict, vector_input, output)
def main():
    """Preload resources, run one warm-up query, print a summary, launch the UI.

    A failure while loading resources or warming up is logged and does not
    prevent the server from starting.
    """
    # Preload resources before serving.
    if index is None or metadata is None:
        print("🚀 启动前预加载资源...")
        try:
            load_resources()
        except Exception as e:
            print(f"⛔ 资源加载失败: {str(e)}")

    # Warm-up query; pointless (and guaranteed to fail) without resources.
    if index is not None and metadata is not None:
        try:
            dummy_vector = [0.1] * 768
            predict([dummy_vector])
        except Exception as e:
            # Report instead of silently swallowing, but still start the app.
            print(f"⚠️ 预热请求失败: {e}")

    print("="*50)
    print("Space启动完成 | 准备接收请求")
    # Explicit None checks: do not rely on the truthiness of the index object.
    print(f"索引维度: {index.d if index is not None else '未加载'}")
    print(f"元数据记录: {len(metadata) if metadata is not None else 0}")
    print("="*50)

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


# Start the application.
if __name__ == "__main__":
    main()
|