csuer's picture
Update app.py
60d0c94 verified
raw
history blame
3.73 kB
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
from io import BytesIO
# Class names, ordered like the model's output scores: score i belongs to labels[i].
labels = [
    "Drawings",
    "Hentai",
    "Neutral",
    "Porn",
    "Sexy",
]
# Markdown shown at the top of the UI: explains the model and lists its classes.
# The NSFW acronym is expanded as "Not Safe For Work" (工作场所不宜) rather than
# the literal calque "非工作安全", which does not convey the meaning.
description = """### NSFW 图片分类器 🚀
该模型可以分类NSFW(Not Safe For Work,工作场所不宜)图片,支持以下类别:
- **Drawings**:绘画图像
- **Hentai**:色情漫画
- **Neutral**:中立图像
- **Porn**:成人内容
- **Sexy**:性感内容
上传一张图片,我们将为您预测其类别。
更多信息请访问 [GitHub](https://github.com/csuer411/nsfw_classify)"""
# Define the CNN model: a ResNet-18 backbone followed by a small fully connected head.
class Classifier(nn.Module):
    """ResNet-18 feature extractor plus an MLP classification head.

    The backbone's 1000-dim output is mapped through 1000 -> 512 -> 128 ->
    ``num_classes`` linear layers; ``forward`` returns raw (pre-softmax) scores.

    Args:
        weights: Backbone weights forwarded to ``torchvision.models.resnet18``.
            Defaults to ``ResNet18_Weights.DEFAULT`` (the original behaviour).
            Pass ``None`` when a full checkpoint is loaded right afterwards, so
            the ImageNet weights are not downloaded only to be overwritten.
        num_classes: Size of the final output layer (5 for this app's labels).
    """

    def __init__(self, *, weights=ResNet18_Weights.DEFAULT, num_classes: int = 5):
        super().__init__()
        self.cnn_layers = resnet18(weights=weights)
        # NOTE: there is no activation between the first two Linear layers.
        # It is left as-is on purpose: the Sequential indices are part of the
        # checkpoint's state_dict keys, so inserting a module would break loading.
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-softmax) class scores for the input batch ``x``."""
        features = self.cnn_layers(x)
        return self.fc_layers(features)


# Pre-processing applied to every image before inference.
# Resize(224) with an int rescales the shorter side to 224 and keeps the aspect
# ratio (no crop is applied); the tensor is then normalized per channel with the
# standard ImageNet mean/std.
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the fine-tuned checkpoint. weights=None skips downloading the ImageNet
# backbone weights: load_state_dict (strict by default) requires and overwrites
# every parameter and buffer anyway, so the resulting model is identical.
model = Classifier(weights=None)
# NOTE: torch.load unpickles the file; only ever load trusted checkpoints here.
model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
model.eval()  # inference behaviour: disables the head's Dropout
def _classify(img: "Image.Image") -> dict:
    """Return ``{label: probability}`` for a single PIL image.

    The image is converted to RGB, pre-processed, and run through the model as
    a batch of one; probabilities are a softmax over the model's class scores.
    """
    batch = preprocess(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair labels with outputs by position instead of a hard-coded range(5).
    return {label: float(p) for label, p in zip(labels, probs)}


def predict(image_path):
    """Classify the image stored at ``image_path`` (the UI's submit handler).

    Args:
        image_path: Path to an image file, as produced by
            ``gr.Image(type="filepath")``, or ``None`` if nothing was uploaded.

    Returns:
        dict mapping each label to its probability.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of an opaque AttributeError from ``Image.open(None)``.
    """
    if image_path is None:
        raise gr.Error("请先上传一张图片。")
    # Context manager releases the underlying file handle promptly.
    with Image.open(image_path) as img:
        return _classify(img)


# API prediction entry point.
def api_predict(image: bytes) -> dict:
    """Classify an image given as raw encoded file bytes (e.g. JPEG/PNG data).

    Args:
        image: The encoded image file contents.

    Returns:
        dict mapping each label to its probability.
    """
    with Image.open(BytesIO(image)) as img:
        return _classify(img)
# Custom CSS that styles the Gradio interface; it is passed to gr.Blocks(css=css).
# `footer {visibility: hidden}` hides the page footer.
# NOTE(review): the .gr-* selectors rely on the class names emitted by the
# installed Gradio version — confirm they still match the rendered markup.
css = """
.gradio-container {
font-family: 'Arial', sans-serif;
background-color: #f0f4f8;
}
.gr-button {
background-color: #4CAF50;
color: white;
font-weight: bold;
}
.gr-button:hover {
background-color: #45a049;
}
.gr-label {
font-size: 1.2em;
font-weight: bold;
}
.gr-image {
border: 2px solid #4CAF50;
border-radius: 8px;
}
.gr-row {
margin-bottom: 20px;
}
.gr-markdown {
color: #333;
}
.gr-input {
border-radius: 5px;
}
footer {visibility: hidden}
"""
# Build the interface with the Blocks API: image input + submit button on the
# left, top-2 predicted classes on the right, clickable examples below.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🌟 NSFW 图片分类器")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="📷 上传图片", elem_id="input-image")
            submit_btn = gr.Button("🚀 运行检测", variant="primary")
        with gr.Column():
            output_label = gr.Label(num_top_classes=2, label="📊 预测结果")
    gr.Markdown("#### 示例图片")
    examples = [
        "./example/anime.jpg",
        "./example/real.jpg",
    ]
    gr.Examples(examples=examples, inputs=image_input, label="点击上传示例图片")
    # api_name exposes this handler as the named API endpoint "predict"
    # (usable from gradio_client / the "Use via API" page). This replaces the
    # former `demo.add_api_route("/api/predict", ...)` call: gr.Blocks has no
    # add_api_route method (that is a FastAPI method), so the call raised
    # AttributeError and demo.launch() was never reached.
    submit_btn.click(predict, inputs=image_input, outputs=output_label, api_name="predict")

# Launch the web interface.
demo.launch()