csuer commited on
Commit
f2d9015
·
verified ·
1 Parent(s): ceae728

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py CHANGED
@@ -76,7 +76,45 @@ advanced_css = f"""
76
  background: linear-gradient(90deg, {theme_color} 0%, #C06C84 100%);
77
  }}
78
  """
 
 
 
 
 
 
 
 
 
 
 
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  with gr.Blocks(theme=gr.themes.Soft(), css=advanced_css) as demo:
81
  # 标题区
82
  with gr.Column(elem_classes="header-section"):
 
76
  background: linear-gradient(90deg, {theme_color} 0%, #C06C84 100%);
77
  }}
78
  """
79
+ # Define CNN model
80
+ class Classifier(nn.Module):
81
+ def __init__(self):
82
+ super(Classifier, self).__init__()
83
+ self.cnn_layers = resnet18(weights=ResNet18_Weights.DEFAULT)
84
+ self.fc_layers = nn.Sequential(
85
+ nn.Linear(1000, 512),
86
+ nn.Dropout(0.3),
87
+ nn.Linear(512, 128),
88
+ nn.ReLU(),
89
+ nn.Linear(128, 5),
90
+ )
91
 
92
+ def forward(self, x):
93
+ x = self.cnn_layers(x)
94
+ x = self.fc_layers(x)
95
+ return x
96
+
97
+ # Pre-process
98
+ preprocess = transforms.Compose([
99
+ transforms.Resize(224),
100
+ transforms.ToTensor(),
101
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
102
+ ])
103
+
104
+ # Load model
105
+ model = Classifier()
106
+ model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
107
+ model.eval()
108
+
109
+ def predict(image_path):
110
+ img = Image.open(image_path).convert("RGB")
111
+ img = preprocess(img).unsqueeze(0)
112
+
113
+ with torch.no_grad():
114
+ prediction = torch.nn.functional.softmax(model(img)[0], dim=0)
115
+
116
+ result = {labels[i]: float(prediction[i]) for i in range(5)}
117
+ return result
118
  with gr.Blocks(theme=gr.themes.Soft(), css=advanced_css) as demo:
119
  # 标题区
120
  with gr.Column(elem_classes="header-section"):