csuer commited on
Commit
ceae728
·
verified ·
1 Parent(s): 88589c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -101
app.py CHANGED
@@ -5,112 +5,137 @@ from torchvision import transforms
5
  from torchvision.models import resnet18, ResNet18_Weights
6
  from PIL import Image
7
 
8
- # number convert to label
9
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
10
- description = """### NSFW 图片分类器 🚀
11
- 该模型可以分类NSFW(非工作安全)图片,支持以下类别:
12
- - **Drawings**:绘画图像
13
- - **Hentai**:色情漫画
14
- - **Neutral**:中立图像
15
- - **Porn**:成人内容
16
- - **Sexy**:性感内容
 
 
 
 
 
 
17
 
18
- 上传一张图片,我们将为您预测其类别。
19
- 更多信息请访问 [GitHub](https://github.com/csuer411/nsfw_classify)"""
20
-
21
- # Define CNN model
22
- class Classifier(nn.Module):
23
- def __init__(self):
24
- super(Classifier, self).__init__()
25
- self.cnn_layers = resnet18(weights=ResNet18_Weights.DEFAULT)
26
- self.fc_layers = nn.Sequential(
27
- nn.Linear(1000, 512),
28
- nn.Dropout(0.3),
29
- nn.Linear(512, 128),
30
- nn.ReLU(),
31
- nn.Linear(128, 5),
32
- )
33
 
34
- def forward(self, x):
35
- x = self.cnn_layers(x)
36
- x = self.fc_layers(x)
37
- return x
38
-
39
- # Pre-process
40
- preprocess = transforms.Compose([
41
- transforms.Resize(224),
42
- transforms.ToTensor(),
43
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
44
- ])
45
-
46
- # Load model
47
- model = Classifier()
48
- model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
49
- model.eval()
50
-
51
- def predict(image_path):
52
- img = Image.open(image_path).convert("RGB")
53
- img = preprocess(img).unsqueeze(0)
54
-
55
- with torch.no_grad():
56
- prediction = torch.nn.functional.softmax(model(img)[0], dim=0)
57
-
58
- result = {labels[i]: float(prediction[i]) for i in range(5)}
59
- return result
60
-
61
- # 美化 Gradio 界面
62
- css = """
63
- .gradio-container {
64
- font-family: 'Arial', sans-serif;
65
- background-color: #f0f4f8;
66
- }
67
- .gr-button {
68
- background-color: #4CAF50;
69
- color: white;
70
- font-weight: bold;
71
- }
72
- .gr-button:hover {
73
- background-color: #45a049;
74
- }
75
- .gr-label {
76
- font-size: 1.2em;
77
- font-weight: bold;
78
- }
79
- .gr-image {
80
- border: 2px solid #4CAF50;
81
- border-radius: 8px;
82
- }
83
- .gr-row {
84
- margin-bottom: 20px;
85
- }
86
- .gr-markdown {
87
- color: #333;
88
- }
89
- .gr-input {
90
- border-radius: 5px;
91
- }
92
- footer {visibility: hidden}
93
  """
94
 
95
- with gr.Blocks(theme="huggingface", css=css) as demo:
96
- gr.Markdown("# 🌟 NSFW 图片分类器")
97
- gr.Markdown(description)
98
-
 
 
 
99
  with gr.Row():
100
- with gr.Column():
101
- image_input = gr.Image(type="filepath", label="📷 上传图片", elem_id="input-image")
102
- submit_btn = gr.Button("🚀 运行检测", variant="primary")
103
-
104
- with gr.Column():
105
- output_label = gr.Label(num_top_classes=2, label="📊 预测结果")
106
- gr.Markdown("#### 示例图片")
107
- examples = [
108
- "./example/anime.jpg",
109
- "./example/real.jpg"
110
- ]
111
- gr.Examples(examples=examples, inputs=image_input, label="点击上传示例图片")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- submit_btn.click(predict, inputs=image_input, outputs=output_label, api_name="predict")
 
 
 
 
 
 
 
114
 
115
- # 启动 Web 界面
116
- demo.launch()
 
5
  from torchvision.models import resnet18, ResNet18_Weights
6
  from PIL import Image
7
 
8
+ # 配置参数
9
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
10
+ theme_color = "#6C5B7B" # 主色调改为优雅的紫色
11
+ description = """<div style="padding: 20px; background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%); border-radius: 10px;">
12
+ <h2 style="color: {color}; margin-bottom: 15px;">🎨 NSFW 图片分类器</h2>
13
+ <p>该模型使用深度神经网络对图片内容进行分类,支持以下类别:</p>
14
+ <ul style="list-style-type: circle; padding-left: 25px;">
15
+ <li><span style="color: #4B4453;">Drawings</span> - 艺术绘画作品</li>
16
+ <li><span style="color: #845EC2;">Hentai</span> - 二次元成人内容</li>
17
+ <li><span style="color: #008F7A;">Neutral</span> - 日常安全内容</li>
18
+ <li><span style="color: #D65DB1;">Porn</span> - 露骨成人内容</li>
19
+ <li><span style="color: #FF9671;">Sexy</span> - 性感但不露骨内容</li>
20
+ </ul>
21
+ <p style="margin-top: 15px;">🖼️ 请上传图片或点击下方示例体验</p>
22
+ </div>""".format(color=theme_color)
23
 
24
+ # 模型定义和预处理(保持不变)
25
+ # ... [保持原有模型代码不变] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # 高级 CSS 样式
28
+ advanced_css = f"""
29
+ .gradio-container {{
30
+ background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
31
+ min-height: 100vh;
32
+ }}
33
+ .header-section {{
34
+ background: white;
35
+ padding: 2rem;
36
+ border-radius: 15px;
37
+ box-shadow: 0 4px 6px rgba(0,0,0,0.05);
38
+ margin-bottom: 2rem;
39
+ }}
40
+ .result-card {{
41
+ background: white !important;
42
+ padding: 1.5rem !important;
43
+ border-radius: 12px !important;
44
+ box-shadow: 0 2px 8px rgba(108,91,123,0.1) !important;
45
+ }}
46
+ .custom-button {{
47
+ background: {theme_color} !important;
48
+ color: white !important;
49
+ border: none !important;
50
+ padding: 12px 28px !important;
51
+ border-radius: 25px !important;
52
+ transition: all 0.3s ease !important;
53
+ }}
54
+ .custom-button:hover {{
55
+ transform: translateY(-2px);
56
+ box-shadow: 0 4px 12px rgba(108,91,123,0.3) !important;
57
+ }}
58
+ .upload-box {{
59
+ border: 2px dashed {theme_color} !important;
60
+ border-radius: 15px !important;
61
+ background: rgba(255,255,255,0.9) !important;
62
+ }}
63
+ .example-card {{
64
+ cursor: pointer;
65
+ transition: all 0.3s ease;
66
+ border-radius: 12px;
67
+ overflow: hidden;
68
+ }}
69
+ .example-card:hover {{
70
+ transform: scale(1.02);
71
+ box-shadow: 0 4px 12px rgba(108,91,123,0.2);
72
+ }}
73
+ .prob-bar {{
74
+ height: 8px;
75
+ border-radius: 4px;
76
+ background: linear-gradient(90deg, {theme_color} 0%, #C06C84 100%);
77
+ }}
 
 
 
 
 
 
 
 
78
  """
79
 
80
+ with gr.Blocks(theme=gr.themes.Soft(), css=advanced_css) as demo:
81
+ # 标题区
82
+ with gr.Column(elem_classes="header-section"):
83
+ gr.Markdown("# 🎭 智能内容识别系统", elem_id="main-title")
84
+ gr.HTML(description)
85
+
86
+ # 主功能区
87
  with gr.Row():
88
+ # 输入列
89
+ with gr.Column(scale=2):
90
+ upload_box = gr.Image(
91
+ type="filepath",
92
+ label="📤 上传图片",
93
+ elem_id="upload-box",
94
+ elem_classes="upload-box",
95
+ height=400
96
+ )
97
+ with gr.Row():
98
+ submit_btn = gr.Button(
99
+ "✨ 开始分析",
100
+ elem_classes="custom-button",
101
+ size="lg"
102
+ )
103
+ clear_btn = gr.Button(
104
+ "🔄 重新上传",
105
+ variant="secondary",
106
+ size="lg"
107
+ )
108
+
109
+ # 输出列
110
+ with gr.Column(scale=1):
111
+ with gr.Column(elem_classes="result-card"):
112
+ gr.Markdown("### 🔍 分析结果")
113
+ result_display = gr.Label(
114
+ label="分类概率分布",
115
+ num_top_classes=3,
116
+ show_label=False
117
+ )
118
+ gr.Markdown("**最高概率类别**: <span id='top-class'></span>", elem_id="dynamic-text")
119
+
120
+ # 示例区
121
+ with gr.Column():
122
+ gr.Markdown("### 🖼️ 示例图片")
123
+ examples = gr.Examples(
124
+ examples=["./example/anime.jpg", "./example/real.jpg"],
125
+ inputs=upload_box,
126
+ examples_per_page=2,
127
+ label="点击使用示例",
128
+ elem_id="example-gallery"
129
+ )
130
 
131
+ # 交互逻辑
132
+ clear_btn.click(fn=lambda: None, inputs=None, outputs=upload_box)
133
+ submit_btn.click(
134
+ fn=predict,
135
+ inputs=upload_box,
136
+ outputs=result_display,
137
+ api_name="predict"
138
+ )
139
 
140
+ # 启动界面
141
+ demo.launch()