Spaces:
Running
Running
File size: 3,734 Bytes
3e2ba9b 48f08ee 3e2ba9b 1373086 58773de 3e2ba9b 58773de 1373086 3e2ba9b 1373086 3e2ba9b 1373086 3e2ba9b 1373086 3e2ba9b 1373086 3e2ba9b 1373086 3e2ba9b 48f08ee 1373086 58773de 60d0c94 1373086 58773de 1373086 58773de 38e331e 1373086 3e2ba9b 60d0c94 48f08ee 1373086 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
from io import BytesIO
# Class names, in the same order as the classifier's five output logits;
# predict/api_predict map probability i to labels[i].
labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
# Markdown shown at the top of the UI (Chinese): lists the five categories
# and links to the project repository.
description = """### NSFW 图片分类器 🚀
该模型可以分类NSFW(非工作安全)图片,支持以下类别:
- **Drawings**:绘画图像
- **Hentai**:色情漫画
- **Neutral**:中立图像
- **Porn**:成人内容
- **Sexy**:性感内容
上传一张图片,我们将为您预测其类别。
更多信息请访问 [GitHub](https://github.com/csuer411/nsfw_classify)"""
# CNN classifier: ResNet-18 backbone followed by a small fully-connected head.
class Classifier(nn.Module):
    """ResNet-18 backbone plus an MLP head that outputs one logit per class.

    The backbone's 1000-way output is fed into the head
    (1000 -> 512 -> 128 -> num_classes).

    Args:
        weights: weights argument forwarded to torchvision's ``resnet18``.
            The default ``"DEFAULT"`` is resolved by torchvision to
            ``ResNet18_Weights.DEFAULT`` (ImageNet-pretrained, which may
            trigger a download). Pass ``None`` to build an uninitialised
            backbone when a full state dict is loaded afterwards anyway.
        num_classes: size of the final output layer. It must match the
            checkpoint being loaded (5 for the bundled model).
    """

    def __init__(self, weights="DEFAULT", num_classes=5):
        super().__init__()
        self.cnn_layers = resnet18(weights=weights)
        # NOTE: there is no activation between the first Linear and the
        # Dropout, so in eval mode the first two Linear layers compose into a
        # single linear map. This is kept as-is: nn.Sequential names its
        # children by index (0, 2, 4 hold parameters), and load_state_dict is
        # strict about those keys.
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-softmax) class logits for a batch of images."""
        features = self.cnn_layers(x)
        return self.fc_layers(features)
# Per-channel statistics used by the Normalize step below
# (the widely used ImageNet mean/std values).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Input pipeline applied to every PIL image before inference:
#   1. Resize with an int size: the shorter edge becomes 224 px and the
#      aspect ratio is kept (no crop, so the tensor need not be square).
#   2. Convert to a float tensor.
#   3. Normalize each channel with the statistics above.
preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def _load_model(checkpoint_path):
    """Build a Classifier, restore its trained weights and switch to eval mode.

    The checkpoint is mapped to CPU, so no GPU is required.
    """
    # NOTE(review): torch.load unpickles the file. That is acceptable for this
    # bundled, trusted checkpoint; never point it at user-supplied files.
    net = Classifier()
    net.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    # eval() disables the Dropout layer in the classifier head.
    net.eval()
    return net


# Module-level model shared by the prediction functions.
model = _load_model("classify_nsfw_v3.0.pth")
def predict(image_path):
    """Classify the image at ``image_path`` for the Gradio UI.

    Args:
        image_path: filesystem path of the uploaded image (the input
            component uses ``type="filepath"``). Gradio passes ``None`` when
            no image has been uploaded.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability
        (a Python float), the format ``gr.Label`` renders.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable
            message instead of an opaque internal error.
    """
    if image_path is None:
        raise gr.Error("请先上传一张图片")
    # The context manager closes the underlying file handle, which PIL may
    # otherwise keep open (e.g. for multi-frame images).
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    batch = preprocess(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # zip pairs every label with its probability without hard-coding the
    # number of classes.
    return {label: float(p) for label, p in zip(labels, probs.tolist())}
# Prediction entry point for the raw HTTP route "/api/predict".
def api_predict(image: bytes) -> dict:
    """Classify an image supplied as encoded file contents (e.g. JPEG/PNG bytes).

    Args:
        image: the encoded image bytes (``bytes``, ``bytearray`` or
            ``memoryview``). A filesystem path or binary file-like object is
            also accepted and handed directly to ``Image.open``.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability.
    """
    # NOTE(review): when registered as a FastAPI route, a plain ``bytes``
    # parameter is not a multipart file upload; confirm how clients actually
    # send the image to this endpoint.
    if isinstance(image, (bytes, bytearray, memoryview)):
        source = BytesIO(image)
    else:
        source = image
    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(source) as raw:
        img = raw.convert("RGB")
    batch = preprocess(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    return {label: float(p) for label, p in zip(labels, probs.tolist())}
# Custom CSS that styles the Gradio interface (passed to gr.Blocks(css=...)).
# The final rule hides Gradio's page footer.
css = """
.gradio-container {
font-family: 'Arial', sans-serif;
background-color: #f0f4f8;
}
.gr-button {
background-color: #4CAF50;
color: white;
font-weight: bold;
}
.gr-button:hover {
background-color: #45a049;
}
.gr-label {
font-size: 1.2em;
font-weight: bold;
}
.gr-image {
border: 2px solid #4CAF50;
border-radius: 8px;
}
.gr-row {
margin-bottom: 20px;
}
.gr-markdown {
color: #333;
}
.gr-input {
border-radius: 5px;
}
footer {visibility: hidden}
"""
# Build the interface with the Blocks API; layout follows statement order.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🌟 NSFW 图片分类器")
    gr.Markdown(description)
    with gr.Row():
        # Left column: image upload (delivered to handlers as a file path)
        # and the button that triggers classification.
        with gr.Column():
            image_input = gr.Image(type="filepath", label="📷 上传图片", elem_id="input-image")
            submit_btn = gr.Button("🚀 运行检测", variant="primary")
        # Right column: shows the two most probable labels.
        with gr.Column():
            output_label = gr.Label(num_top_classes=2, label="📊 预测结果")
    # Example images; selecting one loads it into image_input.
    gr.Markdown("#### 示例图片")
    examples = [
        "./example/anime.jpg",
        "./example/real.jpg"
    ]
    gr.Examples(examples=examples, inputs=image_input, label="点击上传示例图片")
    # Run predict on the uploaded file path and display the result dict.
    submit_btn.click(predict, inputs=image_input, outputs=output_label)
# Manually register an extra HTTP route that serves api_predict.
# NOTE(review): confirm the installed gradio version actually provides
# Blocks.add_api_route; Gradio presumably already exposes the click handler
# above through its own API endpoints.
demo.add_api_route("/api/predict", api_predict, methods=["POST"])
# Start the web server.
demo.launch()
|