File size: 5,803 Bytes
3e2ba9b
 
 
 
 
 
 
ceae728
1373086
ceae728
 
 
 
 
 
 
 
 
 
 
 
 
3e2ba9b
ceae728
 
3e2ba9b
ceae728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58773de
f2d9015
 
 
 
 
 
 
 
 
 
 
 
58773de
f2d9015
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ceae728
 
 
 
 
 
 
1373086
ceae728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38e331e
ceae728
 
 
 
 
 
 
 
48f08ee
ceae728
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image

# App configuration.
# `labels` is indexed positionally against the model's output probabilities,
# so its order must match the order of the model's output units.
labels = [
    "Drawings",
    "Hentai",
    "Neutral",
    "Porn",
    "Sexy",
]
theme_color = "#6C5B7B"  # primary accent colour (a muted purple)

# HTML intro rendered under the page title; the heading uses the accent colour.
description = f"""<div style="padding: 20px; background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%); border-radius: 10px;">
<h2 style="color: {theme_color}; margin-bottom: 15px;">🎨 NSFW 图片分类器</h2>
<p>该模型使用深度神经网络对图片内容进行分类,支持以下类别:</p>
<ul style="list-style-type: circle; padding-left: 25px;">
  <li><span style="color: #4B4453;">Drawings</span> - 艺术绘画作品</li>
  <li><span style="color: #845EC2;">Hentai</span> - 二次元成人内容</li>
  <li><span style="color: #008F7A;">Neutral</span> - 日常安全内容</li>
  <li><span style="color: #D65DB1;">Porn</span> - 露骨成人内容</li>
  <li><span style="color: #FF9671;">Sexy</span> - 性感但不露骨内容</li>
</ul>
<p style="margin-top: 15px;">🖼️ 请上传图片或点击下方示例体验</p>
</div>"""

# The model definition, preprocessing pipeline and weight loading are defined
# after the CSS below.

# Custom CSS passed to gr.Blocks(css=...).
# This is an f-string so that theme_color can be interpolated; every literal
# CSS brace is therefore doubled ({{ / }}). The .header-section, .result-card,
# .custom-button and .upload-box classes are attached to components via
# elem_classes in the UI layout.
# NOTE(review): .example-card and .prob-bar are not referenced by any
# elem_classes in the visible code — confirm whether they are still needed.
advanced_css = f"""
.gradio-container {{
    background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
    min-height: 100vh;
}}
.header-section {{
    background: white;
    padding: 2rem;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.05);
    margin-bottom: 2rem;
}}
.result-card {{
    background: white !important;
    padding: 1.5rem !important;
    border-radius: 12px !important;
    box-shadow: 0 2px 8px rgba(108,91,123,0.1) !important;
}}
.custom-button {{
    background: {theme_color} !important;
    color: white !important;
    border: none !important;
    padding: 12px 28px !important;
    border-radius: 25px !important;
    transition: all 0.3s ease !important;
}}
.custom-button:hover {{
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(108,91,123,0.3) !important;
}}
.upload-box {{
    border: 2px dashed {theme_color} !important;
    border-radius: 15px !important;
    background: rgba(255,255,255,0.9) !important;
}}
.example-card {{
    cursor: pointer;
    transition: all 0.3s ease;
    border-radius: 12px;
    overflow: hidden;
}}
.example-card:hover {{
    transform: scale(1.02);
    box-shadow: 0 4px 12px rgba(108,91,123,0.2);
}}
.prob-bar {{
    height: 8px;
    border-radius: 4px;
    background: linear-gradient(90deg, {theme_color} 0%, #C06C84 100%);
}}
"""
# Define CNN model
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully-connected head.

    The backbone keeps its own 1000-way output layer. The head maps those
    1000 features through 512 -> 128 -> ``num_classes`` raw logits.

    Args:
        num_classes: Length of the output logit vector. The default of 5
            matches the five labels this app uses.
        pretrained: If True (the default, identical to the original
            behaviour), the backbone is initialised from torchvision's
            ``ResNet18_Weights.DEFAULT``, downloading them if not cached.
            Pass False when a full checkpoint will be loaded right after
            construction, to skip that redundant download.
    """

    def __init__(self, *, num_classes: int = 5, pretrained: bool = True):
        super().__init__()
        backbone_weights = ResNet18_Weights.DEFAULT if pretrained else None
        self.cnn_layers = resnet18(weights=backbone_weights)
        # Keep this exact layer order: saved state_dict keys (fc_layers.0,
        # fc_layers.2, fc_layers.4) are tied to the Sequential indices.
        # NOTE(review): there is no activation between the first two Linear
        # layers (only Dropout), so at inference they compose to one affine
        # map. It is left unchanged because adding a layer would invalidate
        # existing checkpoints.
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return unnormalised logits; presumably x is (batch, 3, H, W) — confirm."""
        features = self.cnn_layers(x)
        return self.fc_layers(features)

# Inference-time preprocessing pipeline.
# Resize with an int scales the image's shorter edge to 224 px while keeping
# the aspect ratio (there is no crop), so the result is not necessarily
# 224x224. ToTensor produces a float CHW tensor, and Normalize then applies
# per-channel (x - mean) / std using the standard ImageNet statistics.
preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load model: build the network, restore the fine-tuned checkpoint (mapped to
# CPU so no GPU is required), then switch to eval mode so Dropout is disabled
# for inference.
# NOTE(review): Classifier() first loads torchvision's pretrained backbone
# weights, which the strict load_state_dict below then overwrites. Also,
# torch.load unpickles the file — only load checkpoints from a trusted source
# (consider weights_only=True on torch versions that support it).
model = Classifier()
model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
model.eval()

def predict(image_path):
    """Classify the image at ``image_path`` into the categories in ``labels``.

    Args:
        image_path: Filesystem path to the image, as supplied by the
            ``gr.Image(type="filepath")`` input. It is ``None`` when the user
            submits without uploading anything.

    Returns:
        dict mapping each label name to its softmax probability as a float.

    Raises:
        gr.Error: if no image was provided. Gradio shows this as a readable
            message instead of a traceback.
    """
    if image_path is None:
        raise gr.Error("请先上传图片")

    # Use a context manager so the underlying file handle is closed. convert()
    # returns a fully loaded copy that does not depend on the open file.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    batch = preprocess(rgb_image).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        logits = model(batch)[0]
        probabilities = torch.nn.functional.softmax(logits, dim=0)

    # Pair labels with probabilities positionally instead of a hard-coded
    # range(5), so the mapping follows the `labels` list.
    return dict(zip(labels, probabilities.tolist()))
with gr.Blocks(theme=gr.themes.Soft(), css=advanced_css) as demo:
    # Header: page title plus the HTML category overview.
    with gr.Column(elem_classes="header-section"):
        gr.Markdown("# 🎭 智能内容识别系统", elem_id="main-title")
        gr.HTML(description)

    # Main area: input column (wider) on the left, results on the right.
    with gr.Row():
        # Input column: image upload and the action buttons.
        with gr.Column(scale=2):
            upload_box = gr.Image(
                type="filepath",
                label="📤 上传图片",
                elem_id="upload-box",
                elem_classes="upload-box",
                height=400
            )
            with gr.Row():
                submit_btn = gr.Button(
                    "✨ 开始分析",
                    elem_classes="custom-button",
                    size="lg"
                )
                clear_btn = gr.Button(
                    "🔄 重新上传",
                    variant="secondary",
                    size="lg"
                )

        # Output column: classification result card and examples.
        with gr.Column(scale=1):
            with gr.Column(elem_classes="result-card"):
                gr.Markdown("### 🔍 分析结果")
                result_display = gr.Label(
                    label="分类概率分布",
                    num_top_classes=3,
                    show_label=False
                )
                # NOTE(review): no code in this file writes into the
                # 'top-class' span, so this line is never filled in.
                gr.Markdown("**最高概率类别**: <span id='top-class'></span>", elem_id="dynamic-text")

            # Example images the user can click to load into the input.
            with gr.Column():
                gr.Markdown("### 🖼️ 示例图片")
                examples = gr.Examples(
                    examples=["./example/anime.jpg", "./example/real.jpg"],
                    inputs=upload_box,
                    examples_per_page=2,
                    label="点击使用示例",
                    elem_id="example-gallery"
                )

    # Event wiring.
    # The reset button clears both the uploaded image and the previous
    # result, so a stale classification is not left beside an empty input.
    clear_btn.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[upload_box, result_display]
    )
    submit_btn.click(
        fn=predict,
        inputs=upload_box,
        outputs=result_display,
        api_name="predict"
    )

# Launch the interface.
demo.launch()