Spaces:
Running
Running
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
from torchvision.models import resnet18, ResNet18_Weights | |
from PIL import Image | |
from io import BytesIO | |
# Class names indexed by the model's output position (logit i -> labels[i]).
# NOTE(review): the order must match the label order used when the checkpoint
# was trained (it is alphabetical here, as torchvision's ImageFolder would
# produce) — confirm against the training code before reordering.
labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]

# Markdown shown under the page title. Blank lines separate the intro, the
# bullet list and the closing paragraph; without them Markdown treats the
# closing sentences as a lazy continuation of the last bullet item.
description = """### NSFW 图片分类器 🚀

该模型可以分类NSFW(非工作安全)图片,支持以下类别:

- **Drawings**:绘画图像
- **Hentai**:色情漫画
- **Neutral**:中立图像
- **Porn**:成人内容
- **Sexy**:性感内容

上传一张图片,我们将为您预测其类别。
更多信息请访问 [GitHub](https://github.com/csuer411/nsfw_classify)"""
# Define CNN model
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully connected classification head.

    The attribute names ``cnn_layers`` / ``fc_layers`` and the order of the
    layers inside ``fc_layers`` determine the ``state_dict`` keys, so they must
    stay unchanged for existing checkpoints to load.

    Args:
        pretrained: If True (the default, and the original behaviour), the
            backbone is initialised with torchvision's default ImageNet
            weights, which may require a download. Pass False when a full
            checkpoint is loaded right afterwards: a strict ``load_state_dict``
            overwrites every backbone parameter and buffer anyway.
        num_classes: Number of output logits (5 for the labels used by this app).
    """

    def __init__(self, pretrained: bool = True, num_classes: int = 5):
        super().__init__()
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        # The stock ResNet-18 keeps its 1000-way fc layer; the head below
        # consumes those 1000 features.
        self.cnn_layers = resnet18(weights=weights)
        # NOTE: there is no activation between the first two Linear layers.
        # Presumably this mirrors how the checkpoint was trained; changing the
        # head would alter predictions (and shift the Sequential indices used
        # as state_dict keys), so it is kept exactly as-is.
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw (un-normalised) class logits.

        ``x`` is presumably a float image batch of shape (N, 3, H, W) as
        produced by ``preprocess`` plus ``unsqueeze(0)`` — confirm for other
        callers.
        """
        features = self.cnn_layers(x)
        return self.fc_layers(features)
# Pre-processing applied to every image before inference: resize so the
# shorter side is 224 px (aspect ratio kept, no crop), convert to a float
# tensor in [0, 1], then normalise with the usual ImageNet channel statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Load model: build the architecture, restore the fine-tuned weights onto the
# CPU, and switch to inference mode (disables Dropout, uses stored BatchNorm
# statistics).
model = Classifier()
# The checkpoint is passed straight to load_state_dict, so it only needs
# tensors and plain containers. weights_only=True restricts unpickling to
# exactly that (no arbitrary code execution) and matches the default of
# recent PyTorch releases.
model.load_state_dict(
    torch.load("classify_nsfw_v3.0.pth", map_location="cpu", weights_only=True)
)
model.eval()
def _classify(img):
    """Classify an opened PIL image and return a ``{label: probability}`` dict.

    The image is converted to RGB, pre-processed, and run through the model
    without gradient tracking. Softmax turns the logits into probabilities
    that sum to 1 across ``labels``.
    """
    batch = preprocess(img.convert("RGB")).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair labels with probabilities by position instead of a hard-coded range(5).
    return {label: float(p) for label, p in zip(labels, probs.tolist())}


def predict(image_path):
    """Gradio click handler: classify the image stored at ``image_path``.

    Args:
        image_path: File path supplied by ``gr.Image(type="filepath")``;
            ``None`` when the user submits without uploading an image.

    Returns:
        dict mapping each label to its predicted probability.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a generic failure.
    """
    if image_path is None:
        raise gr.Error("请先上传一张图片。")
    # The context manager closes the file handle; convert() inside _classify
    # reads the pixel data before the handle is released.
    with Image.open(image_path) as img:
        return _classify(img)


# API prediction endpoint
def api_predict(image: bytes) -> dict:
    """Classify an image given as raw encoded bytes (e.g. an uploaded JPEG/PNG).

    Args:
        image: The encoded image file contents.

    Returns:
        dict mapping each label to its predicted probability.
    """
    with Image.open(BytesIO(image)) as img:
        return _classify(img)
# Custom CSS that restyles the Gradio page (font, background, button colours,
# label/image styling, row spacing) and hides the default Gradio footer.
# NOTE(review): confirm that the installed Gradio version actually renders
# elements with these `.gr-*` class names; if it does not, those rules have
# no effect.
css = """
.gradio-container {
font-family: 'Arial', sans-serif;
background-color: #f0f4f8;
}
.gr-button {
background-color: #4CAF50;
color: white;
font-weight: bold;
}
.gr-button:hover {
background-color: #45a049;
}
.gr-label {
font-size: 1.2em;
font-weight: bold;
}
.gr-image {
border: 2px solid #4CAF50;
border-radius: 8px;
}
.gr-row {
margin-bottom: 20px;
}
.gr-markdown {
color: #333;
}
.gr-input {
border-radius: 5px;
}
footer {visibility: hidden}
"""
# Build the interface with the Blocks API (components render in creation order).
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🌟 NSFW 图片分类器")
    gr.Markdown(description)
    # Two columns: image upload + submit button on the left, results on the right.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="📷 上传图片", elem_id="input-image")
            submit_btn = gr.Button("🚀 运行检测", variant="primary")
        with gr.Column():
            # Shows the two highest-probability labels returned by `predict`.
            output_label = gr.Label(num_top_classes=2, label="📊 预测结果")
    # NOTE(review): the source's indentation was lost, so placing this examples
    # section at page level (below the row) is a reconstruction — confirm
    # against the deployed layout.
    gr.Markdown("#### 示例图片")
    examples = [
        "./example/anime.jpg",
        "./example/real.jpg"
    ]
    # Example files that can be loaded into `image_input` with one click.
    gr.Examples(examples=examples, inputs=image_input, label="点击上传示例图片")
    # Wire the button: the uploaded file path goes to `predict`, and the result
    # dict is displayed in `output_label`.
    submit_btn.click(predict, inputs=image_input, outputs=output_label)
# Manually register a route for the API.
# NOTE(review): confirm that `gr.Blocks` in the installed Gradio version
# provides `add_api_route` (the name matches FastAPI's router method), and how
# the request body is mapped to `api_predict`'s `image` argument.
demo.add_api_route("/api/predict", api_predict, methods=["POST"])
# Start the web interface.
demo.launch()