kaisex commited on
Commit
933b48e
·
verified ·
1 Parent(s): 0a2eaba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -28
app.py CHANGED
@@ -1,27 +1,54 @@
1
  import gradio as gr
 
 
 
2
  from PIL import Image
3
 
4
- from model_human import predict_human_nsfw
5
- from model_anime import predict_anime_nsfw
 
6
 
7
- def nsfw_detector(image: Image.Image, model_type: str):
 
 
 
 
 
 
 
 
 
 
8
  if image is None:
9
- return "<div class='result-box'>Pls upload an img</div>"
 
 
10
 
11
- if model_type == "Human":
12
- label, confidence = predict_human_nsfw(image)
 
 
 
 
 
 
 
 
13
  else:
14
- label, confidence = predict_anime_nsfw(image)
 
 
 
15
 
16
- result = (
17
- f"<div class='result-box'>"
18
- f"<strong>Model:</strong> {model_type}<br>"
19
- f"<strong>Prediction:</strong> {label}<br>"
20
- f"<strong>Confidence:</strong> {confidence:.2%}"
21
- f"</div>"
22
- )
23
- return result
24
 
 
25
  custom_css = """
26
  .result-box {
27
  background-color: oklch(0.718 0.202 349.761);
@@ -37,30 +64,23 @@ custom_css = """
37
  .gradio-container { max-width: 900px; margin: auto; }
38
  """
39
 
 
40
  with gr.Blocks(css=custom_css) as demo:
41
  gr.Markdown("## NSFW Detector (Human + Anime/Cartoon)")
42
  gr.Markdown(
43
- "Upload an image and select the appropriate model for detection. "
44
- "No data is stored. This is a side project results may not be fully accurate."
45
  )
46
 
47
  with gr.Row():
48
  with gr.Column(scale=1):
49
- model_choice = gr.Radio(
50
- ["Human", "Anime"],
51
- label="Select Model Type",
52
- value="Human"
53
- )
54
  image_input = gr.Image(type="pil", label="Upload Image")
55
  with gr.Column(scale=1):
56
  output_box = gr.HTML("<div class='result-box'>Awaiting input...</div>")
57
 
58
- # Trigger detection when image is uploaded
59
- image_input.change(
60
- fn=nsfw_detector,
61
- inputs=[image_input, model_choice],
62
- outputs=output_box
63
- )
64
 
65
  if __name__ == "__main__":
66
  demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
  from PIL import Image
6
 
7
+ # Load models
8
+ human_model = torch.load("humanNsfw_Swf.pth", map_location=torch.device("cpu"))
9
+ human_model.eval()
10
 
11
+ anime_model = torch.load("animeCartoonNsfw_Sfw.pth", map_location=torch.device("cpu"))
12
+ anime_model.eval()
13
+
14
+ # Shared preprocessing
15
+ preprocess = transforms.Compose([
16
+ transforms.Resize((224, 224)),
17
+ transforms.ToTensor(),
18
+ ])
19
+
20
+ # Prediction functions
21
+ def predict(image, model_type):
22
  if image is None:
23
+ return "<div class='result-box'>Please upload an image.</div>"
24
+
25
+ input_tensor = preprocess(image).unsqueeze(0)
26
 
27
+ with torch.no_grad():
28
+ if model_type == "Human":
29
+ output = human_model(input_tensor)
30
+ else:
31
+ output = anime_model(input_tensor)
32
+
33
+ if output.shape[-1] == 1:
34
+ prob = torch.sigmoid(output).item()
35
+ label = "NSFW" if prob > 0.5 else "SFW"
36
+ confidence = prob if prob > 0.5 else 1 - prob
37
  else:
38
+ probs = F.softmax(output, dim=1).squeeze()
39
+ label_index = torch.argmax(probs).item()
40
+ label = "NSFW" if label_index == 1 else "SFW"
41
+ confidence = probs[label_index].item()
42
 
43
+ return f"""
44
+ <div class='result-box'>
45
+ <strong>Model:</strong> {model_type}<br>
46
+ <strong>Prediction:</strong> {label}<br>
47
+ <strong>Confidence:</strong> {confidence:.2%}
48
+ </div>
49
+ """
 
50
 
51
+ # Custom glowing style
52
  custom_css = """
53
  .result-box {
54
  background-color: oklch(0.718 0.202 349.761);
 
64
  .gradio-container { max-width: 900px; margin: auto; }
65
  """
66
 
67
+ # Gradio Interface
68
  with gr.Blocks(css=custom_css) as demo:
69
  gr.Markdown("## NSFW Detector (Human + Anime/Cartoon)")
70
  gr.Markdown(
71
+ "Upload an image and choose the model. The system will predict whether the content is NSFW or SFW. "
72
+ "This is a side project. Results may vary. No images are stored."
73
  )
74
 
75
  with gr.Row():
76
  with gr.Column(scale=1):
77
+ model_choice = gr.Radio(["Human", "Anime"], label="Select Model Type", value="Human")
 
 
 
 
78
  image_input = gr.Image(type="pil", label="Upload Image")
79
  with gr.Column(scale=1):
80
  output_box = gr.HTML("<div class='result-box'>Awaiting input...</div>")
81
 
82
+ image_input.change(fn=predict, inputs=[image_input, model_choice], outputs=output_box)
83
+ model_choice.change(fn=predict, inputs=[image_input, model_choice], outputs=output_box)
 
 
 
 
84
 
85
  if __name__ == "__main__":
86
  demo.launch()