File size: 3,734 Bytes
3e2ba9b
 
 
 
 
 
48f08ee
3e2ba9b
 
1373086
 
58773de
 
 
 
 
 
3e2ba9b
58773de
 
 
1373086
3e2ba9b
 
 
1373086
3e2ba9b
 
 
 
 
 
 
 
 
 
 
 
 
1373086
 
 
 
 
 
3e2ba9b
1373086
3e2ba9b
 
 
 
1373086
 
 
3e2ba9b
 
1373086
 
3e2ba9b
 
 
48f08ee
 
 
 
 
 
 
 
 
 
 
1373086
58773de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60d0c94
 
1373086
 
 
 
 
58773de
1373086
 
 
 
58773de
 
 
 
 
38e331e
 
1373086
3e2ba9b
60d0c94
 
48f08ee
1373086
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
from io import BytesIO

# Class names in model-output order: output index i corresponds to labels[i].
labels = [
    "Drawings",
    "Hentai",
    "Neutral",
    "Porn",
    "Sexy",
]
# Markdown blurb rendered under the page title. The trailing double spaces at
# line ends are intentional: in Markdown they force a hard line break.
description = """### NSFW 图片分类器 🚀  
该模型可以分类NSFW(非工作安全)图片,支持以下类别:  
- **Drawings**:绘画图像  
- **Hentai**:色情漫画  
- **Neutral**:中立图像  
- **Porn**:成人内容  
- **Sexy**:性感内容  

上传一张图片,我们将为您预测其类别。  
更多信息请访问 [GitHub](https://github.com/csuer411/nsfw_classify)"""
  
# Define CNN model
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully connected head with 5 outputs."""

    def __init__(self, weights=ResNet18_Weights.DEFAULT):
        """Build the backbone and the classification head.

        Args:
            weights: torchvision weights enum forwarded to ``resnet18``. The
                default (``ResNet18_Weights.DEFAULT``) keeps the original
                behaviour of initialising the backbone from pretrained weights,
                which torchvision downloads on first use. Pass ``None`` to skip
                that download when a full checkpoint is loaded right afterwards
                and would overwrite those weights anyway.
        """
        super().__init__()
        self.cnn_layers = resnet18(weights=weights)
        # The backbone emits 1000 features; the head maps them to 5 logits.
        # Layer order must stay exactly as-is: nn.Sequential names parameters
        # by position, and a strict load_state_dict needs those keys to match
        # the saved checkpoint. (There is no activation between the first two
        # Linear layers, only Dropout; changing that would break the keys.)
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 5),
        )

    def forward(self, x):
        """Return raw (pre-softmax) logits of shape ``(N, 5)``.

        ``x`` is an image batch fed to ResNet-18; it is assumed to be
        ``(N, 3, H, W)`` normalised floats. TODO: confirm against callers.
        """
        features = self.cnn_layers(x)
        return self.fc_layers(features)

# Pre-processing applied to every input image before inference:
#   1. Resize(224) with an int scales the *shorter* side to 224 px and keeps
#      the aspect ratio. There is no crop, so the spatial size of the tensor
#      follows the image's proportions.
#   2. ToTensor turns the PIL image into a CHW float tensor in [0, 1].
#   3. Normalize with the standard ImageNet per-channel mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load model: build the network, restore the trained weights onto the CPU,
# and switch to inference mode (eval() disables the Dropout in the head).
# NOTE(review): torch.load unpickles the file. That is acceptable for this
# bundled local checkpoint; consider weights_only=True if the installed torch
# version supports it.
model = Classifier()
model.load_state_dict(
    torch.load(
        "classify_nsfw_v3.0.pth",
        map_location="cpu",
    )
)
model.eval()

def _classify(img: Image.Image) -> dict:
    """Run the classifier on a PIL image and return ``{label: probability}``.

    Shared by the UI handler and the raw-bytes API helper so that both use
    exactly the same inference path.
    """
    batch = preprocess(img.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair scores with label names directly instead of indexing range(5),
    # so the mapping cannot drift from the `labels` list.
    return {label: float(p) for label, p in zip(labels, probs.tolist())}


def predict(image_path):
    """Classify the image stored at ``image_path`` (Gradio UI handler).

    Args:
        image_path: filesystem path provided by ``gr.Image(type="filepath")``,
            or ``None`` when the user submits without uploading an image.

    Returns:
        dict mapping each label to its softmax probability.

    Raises:
        gr.Error: if no image was provided. Gradio shows this as a readable
            message instead of an opaque crash.
    """
    if image_path is None:
        raise gr.Error("请先上传图片")
    # Context manager closes the underlying file handle; convert() inside
    # _classify forces the pixel data to load before the file is closed.
    with Image.open(image_path) as img:
        return _classify(img)


# API prediction helper
def api_predict(image: bytes) -> dict:
    """Classify an image given as raw encoded bytes (e.g. JPEG/PNG file contents).

    The argument is fed to ``BytesIO``, so it is bytes, not a ``gr.Image``
    component.
    """
    with Image.open(BytesIO(image)) as img:
        return _classify(img)

# Custom CSS for the Gradio UI: container font and background, green bold
# buttons with a hover shade, bold result label, bordered rounded image, row
# spacing, markdown text colour, rounded inputs, and a hidden default footer.
# NOTE(review): the .gr-* class selectors may only match elements in some
# Gradio versions — confirm they still apply with the installed release.
css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
        background-color: #f0f4f8;
    }
    .gr-button {
        background-color: #4CAF50;
        color: white;
        font-weight: bold;
    }
    .gr-button:hover {
        background-color: #45a049;
    }
    .gr-label {
        font-size: 1.2em;
        font-weight: bold;
    }
    .gr-image {
        border: 2px solid #4CAF50;
        border-radius: 8px;
    }
    .gr-row {
        margin-bottom: 20px;
    }
    .gr-markdown {
        color: #333;
    }
    .gr-input {
        border-radius: 5px;
    }
    footer {visibility: hidden}
"""

# Build the interface with the Blocks API:
# left column = upload + run button, right column = results + examples.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🌟 NSFW 图片分类器")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="📷 上传图片", elem_id="input-image")
            submit_btn = gr.Button("🚀 运行检测", variant="primary")

        with gr.Column():
            output_label = gr.Label(num_top_classes=2, label="📊 预测结果")
            gr.Markdown("#### 示例图片")
            examples = [
                "./example/anime.jpg",
                "./example/real.jpg",
            ]
            gr.Examples(examples=examples, inputs=image_input, label="点击上传示例图片")

    # Expose the classifier through Gradio's own API under the name
    # "predict" (usable from gradio_client / the "Use via API" page).
    # gr.Blocks has no add_api_route() method; that is a FastAPI method. A
    # demo.add_api_route(...) call therefore raises AttributeError while the
    # UI is being built, and the app never launches.
    submit_btn.click(
        predict,
        inputs=image_input,
        outputs=output_label,
        api_name="predict",
    )

# Launch the web interface
demo.launch()