|
import gradio as gr |
|
import numpy as np |
|
import os |
|
import torch |
|
import pandas as pd |
|
import faiss |
|
from huggingface_hub import hf_hub_download |
|
import time |
|
import sys |
|
import json |
|
|
|
|
|
# Local directory used as the Hugging Face Hub download cache (see load_resources).
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Thread-count limits.
# NOTE: OpenMP runtimes typically read OMP_NUM_THREADS when they are first
# loaded, and numpy/torch/faiss are already imported above this point, so the
# environment variable alone may not take effect for them. Cap FAISS's OpenMP
# pool explicitly as well; the env vars are kept for any code that reads them later.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
faiss.omp_set_num_threads(2)

# Populated by load_resources(): the FAISS index, and the metadata table whose
# rows are looked up by the ids that index.search() returns.
index = None
metadata = None
|
|
|
def _remove_stale_locks():
    """Best-effort removal of ``*.lock`` files located directly inside CACHE_DIR."""
    # NOTE(review): presumably meant to clear locks left behind by an
    # interrupted download; deleting a lock another process is holding could
    # race with that download — confirm this is only run at startup.
    for lock_file in os.listdir(CACHE_DIR):
        if not lock_file.endswith('.lock'):
            continue
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 清理锁文件: {lock_file}")
        except OSError as e:
            # Best-effort: report and keep going rather than silently ignoring.
            print(f"⚠️ 无法清理锁文件 {lock_file}: {e}")


def _download_from_hub(filename):
    """Download *filename* from the RAG-index Hub repo into CACHE_DIR and return its local path."""
    return hf_hub_download(
        repo_id="GOGO198/GOGO_rag_index",
        filename=filename,
        cache_dir=CACHE_DIR,
        token=os.getenv("HF_TOKEN"),
    )


def load_resources():
    """Load the FAISS index and metadata table into module globals (768-dim only).

    Stale lock files in CACHE_DIR are removed first (best-effort). Each
    resource is downloaded only while its global is still ``None``, so calls
    made after a successful load do no downloading.

    Raises:
        ValueError: if the downloaded FAISS index is not 768-dimensional.
        Exception: any download or parse error is logged and re-raised.
    """
    global index, metadata

    _remove_stale_locks()

    if index is None or metadata is None:
        print("🔄 正在加载所有资源...")

    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            loaded_index = faiss.read_index(_download_from_hub("faiss_index.bin"))
            if loaded_index.d != 768:
                raise ValueError(f"❌ 索引维度错误:预期768维,实际{loaded_index.d}维")
            # Publish to the global only after validation. Otherwise a
            # wrong-dimension index would stay cached and be reused by later
            # calls even though this call raised.
            index = loaded_index
            print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {str(e)}")
            raise

    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            metadata = pd.read_csv(_download_from_hub("metadata.csv"))
            print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
        except Exception as e:
            print(f"❌ 元数据加载失败: {str(e)}")
            raise
|
|
|
from functools import lru_cache

# Number of nearest neighbours returned for every query.
_TOP_K = 3


@lru_cache(maxsize=100)
def _search_cached(query):
    """Search the FAISS index for one query vector and format the matches.

    Args:
        query: tuple of floats. This is a hashable copy of the request vector,
            so that ``lru_cache`` can use it as a key. A list, as Gradio
            supplies it, would raise ``TypeError: unhashable type``.

    Returns:
        JSON string ``{"results": [...]}`` with one human-readable line per neighbour.
    """
    vector_array = np.asarray([query], dtype=np.float32)  # FAISS search takes a 2-D batch
    D, I = index.search(vector_array, k=_TOP_K)

    results = []
    for rank, (distance, row_id) in enumerate(zip(D[0], I[0]), start=1):
        # FAISS pads with id -1 when the index holds fewer than k vectors.
        # metadata.iloc[-1] would silently return the last row instead.
        if row_id < 0:
            results.append(f"结果 {rank}: 数据获取失败")
            continue
        try:
            result = metadata.iloc[int(row_id)]
            # Maps a distance into (0, 1]. This assumes non-negative (L2-style)
            # distances — TODO confirm the index metric.
            confidence = 1 / (1 + distance)
            results.append(f"匹配结果 {rank}: {result['source']} | 置信度: {confidence:.2f}")
        except Exception as e:
            print(f"结果处理错误: {str(e)}")
            results.append(f"结果 {rank}: 数据获取失败")

    # ensure_ascii=False keeps the Chinese text readable in the output textbox.
    return json.dumps({"results": results}, ensure_ascii=False)


def predict(vector):
    """Answer a query given as a single 768-dimensional vector.

    Args:
        vector: a 1x768 nested list (one row of 768 numbers).

    Returns:
        On success, a JSON string ``{"results": [...]}`` describing the top
        matches. On invalid input, an error-message string. On any other
        failure, a generic retry message (details are printed to the log).
    """
    start_time = time.time()

    try:
        if not isinstance(vector, list) or len(vector) == 0:
            error_msg = "错误:输入格式无效"
            print(error_msg)
            return error_msg

        if len(vector) != 1 or len(vector[0]) != 768:
            error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0])}"
            print(error_msg)
            return error_msg

        # Logged after validation and inside the try, so that odd input
        # cannot raise outside the error handler.
        print(f"输入向量维度: {np.array(vector).shape}")

        # Load on demand in case the module was imported without running the
        # __main__ preload.
        if index is None or metadata is None:
            load_resources()

        answer = _search_cached(tuple(float(x) for x in vector[0]))

        print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
        return answer

    except Exception as e:
        import traceback
        error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return "处理错误,请重试或联系管理员"
|
|
|
|
|
# Gradio UI: a vector-input table and an answer textbox side by side, plus a
# button that sends the table's contents to predict().
with gr.Blocks() as demo:
    gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
    gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")

    with gr.Row():
        # Input table, pre-filled with a single row of 768 values (0.1 each).
        # NOTE(review): `headers` names one column while the default value has
        # 768 columns — confirm how this Gradio version handles the mismatch.
        # type="array" presumably delivers the table to predict() as a list of
        # row lists, which is the shape predict() validates (1x768).
        vector_input = gr.Dataframe(
            headers=["向量值"],
            type="array",
            label="输入向量 (768维)",
            value=[[0.1]*768]
        )
        # Shows whatever string predict() returns (results JSON or an error message).
        output = gr.Textbox(label="智能回答", lines=5)

    submit_btn = gr.Button("生成回答")
    # Wire the button: table value -> predict() -> textbox.
    submit_btn.click(
        fn=predict,
        inputs=vector_input,
        outputs=output
    )
|
|
|
|
|
if __name__ == "__main__":
    # Load eagerly so that a missing or invalid index aborts startup with a
    # non-zero exit code instead of failing on the first request.
    if index is None or metadata is None:
        print("🚀 启动前预加载资源...")
        try:
            load_resources()
        except Exception as e:
            print(f"⛔ 资源加载失败: {str(e)}")
            sys.exit(1)

    # Warm-up query; the result is discarded. This is best-effort, but
    # failures are logged rather than silently swallowed by a bare except.
    try:
        dummy_vector = [0.1] * 768
        predict([dummy_vector])
    except Exception as e:
        print(f"⚠️ 预热请求失败: {e}")

    print("=" * 50)
    print("Space启动完成 | 准备接收请求")
    # Explicit None checks: the truthiness of a SWIG-wrapped index object is
    # not a reliable "is loaded" test.
    print(f"索引维度: {index.d if index is not None else '未加载'}")
    print(f"元数据记录: {len(metadata) if metadata is not None else 0}")
    print("=" * 50)

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
|
|