Spaces:
Running
Running
File size: 5,803 Bytes
3e2ba9b ceae728 1373086 ceae728 3e2ba9b ceae728 3e2ba9b ceae728 58773de f2d9015 58773de f2d9015 ceae728 1373086 ceae728 38e331e ceae728 48f08ee ceae728 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
# Configuration
# Class names, in the same order as the classifier's output logits
# (predict() maps output index i to labels[i]).
labels = [
    "Drawings",
    "Hentai",
    "Neutral",
    "Porn",
    "Sexy",
]

# Primary accent colour (a muted purple), reused by the intro HTML and the CSS.
theme_color = "#6C5B7B"

# Intro HTML rendered in the page header; the heading is tinted with theme_color.
description = f"""<div style="padding: 20px; background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%); border-radius: 10px;">
<h2 style="color: {theme_color}; margin-bottom: 15px;">🎨 NSFW 图片分类器</h2>
<p>该模型使用深度神经网络对图片内容进行分类,支持以下类别:</p>
<ul style="list-style-type: circle; padding-left: 25px;">
<li><span style="color: #4B4453;">Drawings</span> - 艺术绘画作品</li>
<li><span style="color: #845EC2;">Hentai</span> - 二次元成人内容</li>
<li><span style="color: #008F7A;">Neutral</span> - 日常安全内容</li>
<li><span style="color: #D65DB1;">Porn</span> - 露骨成人内容</li>
<li><span style="color: #FF9671;">Sexy</span> - 性感但不露骨内容</li>
</ul>
<p style="margin-top: 15px;">🖼️ 请上传图片或点击下方示例体验</p>
</div>"""
# NOTE: the CNN model definition, the preprocessing pipeline and checkpoint
# loading appear after the CSS block below.
# Custom CSS passed to gr.Blocks(css=...) further down.
# This is an f-string: literal CSS braces are doubled ({{ }}) and {theme_color}
# is interpolated into the button, upload-box border and gradient rules.
# The class selectors are attached to components via elem_classes
# (header-section, result-card, custom-button, upload-box); .gradio-container
# targets Gradio's outer app container.
# NOTE(review): .example-card and .prob-bar are not referenced by any
# elem_classes in this file — possibly dead rules; confirm before removing.
advanced_css = f"""
.gradio-container {{
background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
min-height: 100vh;
}}
.header-section {{
background: white;
padding: 2rem;
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0,0,0,0.05);
margin-bottom: 2rem;
}}
.result-card {{
background: white !important;
padding: 1.5rem !important;
border-radius: 12px !important;
box-shadow: 0 2px 8px rgba(108,91,123,0.1) !important;
}}
.custom-button {{
background: {theme_color} !important;
color: white !important;
border: none !important;
padding: 12px 28px !important;
border-radius: 25px !important;
transition: all 0.3s ease !important;
}}
.custom-button:hover {{
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(108,91,123,0.3) !important;
}}
.upload-box {{
border: 2px dashed {theme_color} !important;
border-radius: 15px !important;
background: rgba(255,255,255,0.9) !important;
}}
.example-card {{
cursor: pointer;
transition: all 0.3s ease;
border-radius: 12px;
overflow: hidden;
}}
.example-card:hover {{
transform: scale(1.02);
box-shadow: 0 4px 12px rgba(108,91,123,0.2);
}}
.prob-bar {{
height: 8px;
border-radius: 4px;
background: linear-gradient(90deg, {theme_color} 0%, #C06C84 100%);
}}
"""
# Define CNN model
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully connected head.

    The backbone's 1000-dimensional output is mapped by ``fc_layers`` down to
    5 logits, one per entry in ``labels``.

    The attribute names ``cnn_layers`` / ``fc_layers`` and the layer order
    inside ``fc_layers`` determine the ``state_dict`` keys. They must stay as
    they are, or the saved checkpoint will no longer load.
    """

    def __init__(self):
        super().__init__()
        # DEFAULT selects torchvision's pretrained weights (which may be
        # downloaded on first use); here they are later overwritten by the
        # checkpoint passed to load_state_dict.
        self.cnn_layers = resnet18(weights=ResNet18_Weights.DEFAULT)
        head_layers = [
            nn.Linear(1000, 512),
            # NOTE(review): there is no activation between this Linear and the
            # next one (only Dropout). Kept as-is to match the trained weights.
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 5),
        ]
        self.fc_layers = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw (pre-softmax) class logits for the input batch ``x``."""
        backbone_out = self.cnn_layers(x)
        return self.fc_layers(backbone_out)
# Pre-process: PIL image -> normalised float tensor for the ResNet-18 backbone.
# Per-channel (R, G, B) mean / std; these are the commonly used ImageNet statistics.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    # An int size rescales the *shorter* edge to 224 and keeps the aspect
    # ratio, so the resulting tensor is generally not square.
    transforms.Resize(224),
    # HWC uint8 in [0, 255] -> CHW float in [0.0, 1.0].
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
# Load model: build the architecture, then replace every parameter/buffer with
# the trained checkpoint. map_location="cpu" means no GPU is required.
# NOTE(review): torch.load unpickles the file. That is fine for this local,
# trusted checkpoint; consider weights_only=True if the torch version allows it.
model = Classifier()
model.load_state_dict(
    torch.load("classify_nsfw_v3.0.pth", map_location="cpu")
)
# Same as model.eval(): switch to inference behaviour (e.g. disables Dropout).
model.train(False)
def predict(image_path):
    """Classify the image at ``image_path`` and return per-class probabilities.

    Args:
        image_path: Filesystem path of the uploaded image (the upload widget
            uses ``type="filepath"``). ``None`` when no image has been provided.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability. This is
        the format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a PIL traceback.
    """
    if image_path is None:
        raise gr.Error("请先上传图片")

    # The context manager closes the underlying file handle. convert() returns
    # a new, fully loaded image, so it stays valid after the file is closed.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    batch = preprocess(rgb_image).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        logits = model(batch)[0]
        probabilities = torch.nn.functional.softmax(logits, dim=0)

    # Pair probabilities with labels by position, instead of hard-coding the
    # number of classes.
    return dict(zip(labels, probabilities.tolist()))
with gr.Blocks(theme=gr.themes.Soft(), css=advanced_css) as demo:
    # Header: title plus the HTML intro describing the classes.
    with gr.Column(elem_classes="header-section"):
        gr.Markdown("# 🎭 智能内容识别系统", elem_id="main-title")
        gr.HTML(description)

    # Main area: image input on the left, results on the right.
    with gr.Row():
        # Input column
        with gr.Column(scale=2):
            upload_box = gr.Image(
                type="filepath",
                label="📤 上传图片",
                elem_id="upload-box",
                elem_classes="upload-box",
                height=400,
            )
            with gr.Row():
                submit_btn = gr.Button(
                    "✨ 开始分析",
                    elem_classes="custom-button",
                    size="lg",
                )
                clear_btn = gr.Button(
                    "🔄 重新上传",
                    variant="secondary",
                    size="lg",
                )

        # Output column
        with gr.Column(scale=1):
            with gr.Column(elem_classes="result-card"):
                gr.Markdown("### 🔍 分析结果")
                # gr.Label already displays the top-scoring class above the
                # bars. The former "最高概率类别" placeholder span was never
                # filled in, so it has been removed.
                result_display = gr.Label(
                    label="分类概率分布",
                    num_top_classes=3,
                    show_label=False,
                )

    # Examples: clicking one only fills the upload box (no fn/outputs given).
    with gr.Column():
        gr.Markdown("### 🖼️ 示例图片")
        examples = gr.Examples(
            examples=["./example/anime.jpg", "./example/real.jpg"],
            inputs=upload_box,
            examples_per_page=2,
            label="点击使用示例",
            elem_id="example-gallery",
        )

    # Event wiring
    # Reset both the uploaded image and the previous result, so stale
    # probabilities are not left on screen after clearing.
    clear_btn.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[upload_box, result_display],
    )
    submit_btn.click(
        fn=predict,
        inputs=upload_box,
        outputs=result_display,
        api_name="predict",
    )
# Launch the UI. The stray "|" page-scrape residue that followed this call
# was a syntax error and has been removed.
demo.launch()