File size: 1,694 Bytes
fa95a21 1b8280b fa95a21 b01567b 1b8280b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import gradio as gr
from huggingface_hub import hf_hub_download
import pickle
import os
# 1) 下载并加载模型(pickle 文件)
def load_pickled_object(repo_id: str, filename: str):
# 如果已经下载过就不重复下载
cache_dir = os.path.join(".cache", repo_id.replace("/", "_"))
os.makedirs(cache_dir, exist_ok=True)
local_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
cache_dir=cache_dir,
force_download=False # 如果本地已有就不重新下
)
# 反序列化
with open(local_path, "rb") as f:
obj = pickle.load(f)
return obj
# 在 Space 启动时就加载一次
# 请替换成你自己的 repo id 和 pickle 文件名
MODEL_REPO = "szk2024/test"
MODEL_FILE = "evil_model.pkl"
try:
model = load_pickled_object(MODEL_REPO, MODEL_FILE)
except Exception as e:
# 如果出错可以打印日志
print("❌ 模型加载失败:", e)
model = None
# 2) 定义预测函数(根据你的 pickle 对象改写)
def predict(text: str):
if model is None:
return "模型加载失败,请检查日志"
# 假设你的 pickle 对象有一个 predict 方法
try:
res = model.predict([text])
return str(res)
except Exception as e:
return f"预测失败:{e}"
# 3) 用 Gradio 搭个简单的文本接口
iface = gr.Interface(
fn=predict,
inputs=gr.Textbox(lines=2, placeholder="在此输入内容…"),
outputs="text",
title="Pickle 模型调用示例",
description="从 Hugging Face Hub 下载 pickle 并反序列化后预测"
)
if __name__ == "__main__":
iface.launch(server_name="0.0.0.0", server_port=7860)
|