import gradio as gr
import numpy as np
import os
import torch
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
import time
import sys
import json
import traceback
from functools import lru_cache

# Writable cache directory for Hugging Face Hub downloads.
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Keep memory / thread usage low.
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
# torch.set_num_threads(1)
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # avoid tokenizer memory leaks

# Hub repository holding both the FAISS index and its metadata table.
REPO_ID = "GOGO198/GOGO_rag_index"
# Dimension the index and every query vector must share.
VECTOR_DIM = 768
# Number of nearest neighbours requested per query.
TOP_K = 3

# Global state (the CLIP model/processor were removed); filled by load_resources().
index = None
metadata = None


def _clear_stale_locks():
    """Best-effort removal of leftover ``*.lock`` files directly inside CACHE_DIR.

    Only the top level of CACHE_DIR is scanned.
    NOTE(review): recent huggingface_hub versions keep their locks under a
    ``.locks/`` subdirectory, which this scan does not reach -- confirm.
    """
    for name in os.listdir(CACHE_DIR):
        if not name.endswith(".lock"):
            continue
        try:
            os.remove(os.path.join(CACHE_DIR, name))
            print(f"🧹 清理锁文件: {name}")
        except OSError:
            # Best effort: a lock we cannot remove is simply left in place.
            pass


def load_resources():
    """Download and load the FAISS index and metadata table into module globals.

    Each resource is only loaded while its global is still ``None``, so
    repeated calls are cheap once both are present. The index is published
    to the global only after its dimension has been validated, so a bad
    download is retried on the next call instead of being kept.

    Raises:
        ValueError: the downloaded index is not VECTOR_DIM-dimensional.
        Exception: any download/read error is printed and re-raised.
    """
    global index, metadata

    if index is not None and metadata is not None:
        return

    _clear_stale_locks()
    print("🔄 正在加载所有资源...")

    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            index_path = hf_hub_download(
                repo_id=REPO_ID,
                filename="faiss_index.bin",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            loaded_index = faiss.read_index(index_path)
            # Validate before assigning the global so a wrong index never sticks.
            if loaded_index.d != VECTOR_DIM:
                raise ValueError(
                    f"❌ 索引维度错误:预期{VECTOR_DIM}维,实际{loaded_index.d}维"
                )
            index = loaded_index
            print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {str(e)}")
            raise

    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            metadata_path = hf_hub_download(
                repo_id=REPO_ID,
                filename="metadata.csv",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            metadata = pd.read_csv(metadata_path)
            print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
        except Exception as e:
            print(f"❌ 元数据加载失败: {str(e)}")
            raise


@lru_cache(maxsize=100)
def _search_and_format(query):
    """Search the index for one query and format the matches as a JSON string.

    ``query`` is a tuple of floats (not the nested list Gradio delivers)
    because ``lru_cache`` hashes its arguments and lists are unhashable.
    Exceptions propagate uncached so predict() can report them.
    """
    query_array = np.asarray([query], dtype=np.float32)
    distances, ids = index.search(query_array, k=TOP_K)

    results = []
    for rank, (dist, row_id) in enumerate(zip(distances[0], ids[0]), start=1):
        if row_id < 0:
            # FAISS pads with -1 when fewer than TOP_K neighbours exist;
            # iloc[-1] would silently return the last metadata row instead.
            continue
        try:
            row = metadata.iloc[int(row_id)]
            confidence = 1 / (1 + dist)
            results.append(f"匹配结果 {rank}: {row['source']} | 置信度: {confidence:.2f}")
        except Exception as e:
            print(f"结果处理错误: {str(e)}")
            results.append(f"结果 {rank}: 数据获取失败")

    # ensure_ascii=False keeps the Chinese text readable in the output
    # textbox instead of \uXXXX escapes; the JSON content is unchanged.
    return json.dumps({"results": results}, ensure_ascii=False)


def predict(vector):
    """Answer one 768-dim query vector with its nearest metadata matches.

    Args:
        vector: A 1 x 768 nested list, as produced by
            ``gr.Dataframe(type="array")``. A NumPy array of that shape is
            accepted as well.

    Returns:
        On success, a JSON string ``{"results": [...]}`` with one line per
        neighbour found. On malformed input, a validation error message.
        On any other failure, a generic retry message (details are printed).
    """
    start_time = time.time()
    try:
        if isinstance(vector, np.ndarray):
            vector = vector.tolist()

        # Validate input format.
        if not isinstance(vector, list) or len(vector) == 0:
            error_msg = "错误:输入格式无效"
            print(error_msg)
            return error_msg

        first_row = vector[0]
        row_len = len(first_row) if isinstance(first_row, (list, tuple)) else 0
        if len(vector) != 1 or row_len != VECTOR_DIM:
            error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{row_len}"
            print(error_msg)
            return error_msg

        print(f"输入向量维度: {np.array(vector).shape}")

        # Load lazily so predict() also works when this module is imported
        # rather than started through __main__.
        if index is None or metadata is None:
            load_resources()

        payload = _search_and_format(tuple(float(x) for x in first_row))
        print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
        return payload

    except Exception as e:
        error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return "处理错误,请重试或联系管理员"


# Simplified UI.
with gr.Blocks() as demo:
    gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
    gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")

    with gr.Row():
        vector_input = gr.Dataframe(
            headers=["向量值"],
            type="array",
            label="输入向量 (768维)",
            value=[[0.1] * VECTOR_DIM],  # 768-dim default value
        )
        output = gr.Textbox(label="智能回答", lines=5)

    submit_btn = gr.Button("生成回答")
    submit_btn.click(
        fn=predict,
        inputs=vector_input,
        outputs=output,
    )


if __name__ == "__main__":
    # Preload resources so the first request does not pay the download cost.
    if index is None or metadata is None:
        print("🚀 启动前预加载资源...")
        try:
            load_resources()
        except Exception as e:
            print(f"⛔ 资源加载失败: {str(e)}")
            sys.exit(1)

    # Warm-up query. predict() already reports its own errors, so anything
    # escaping here is unexpected and is logged rather than silently dropped.
    try:
        predict([[0.1] * VECTOR_DIM])
    except Exception as e:
        print(f"⚠️ 预热请求失败: {e}")

    print("=" * 50)
    print("Space启动完成 | 准备接收请求")
    print(f"索引维度: {index.d if index is not None else '未加载'}")
    print(f"元数据记录: {len(metadata) if metadata is not None else 0}")
    print("=" * 50)

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )