kaisex commited on
Commit
5b14e9a
·
verified ·
1 Parent(s): abb0840

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -28
app.py CHANGED
@@ -3,43 +3,46 @@ import torch
3
  import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
 
6
 
7
- # Load models
8
- human_model = torch.load("humanNsfw_Swf.pth", map_location=torch.device("cpu"))
 
 
 
 
 
 
9
  human_model.eval()
10
 
11
- anime_model = torch.load("animeCartoonNsfw_Sfw.pth", map_location=torch.device("cpu"))
 
 
 
 
12
  anime_model.eval()
13
 
14
- # Shared preprocessing
15
- preprocess = transforms.Compose([
16
- transforms.Resize((224, 224)),
17
- transforms.ToTensor(),
18
- ])
19
 
20
- # Prediction functions
21
  def predict(image, model_type):
22
  if image is None:
23
  return "<div class='result-box'>Please upload an image.</div>"
24
 
25
- input_tensor = preprocess(image).unsqueeze(0)
 
26
 
27
  with torch.no_grad():
28
- if model_type == "Human":
29
- output = human_model(input_tensor)
30
- else:
31
- output = anime_model(input_tensor)
32
-
33
- if output.shape[-1] == 1:
34
- prob = torch.sigmoid(output).item()
35
- label = "NSFW" if prob > 0.5 else "SFW"
36
- confidence = prob if prob > 0.5 else 1 - prob
37
- else:
38
- probs = F.softmax(output, dim=1).squeeze()
39
- label_index = torch.argmax(probs).item()
40
- label = "NSFW" if label_index == 1 else "SFW"
41
- confidence = probs[label_index].item()
42
 
 
43
  return f"""
44
  <div class='result-box'>
45
  <strong>Model:</strong> {model_type}<br>
@@ -48,7 +51,7 @@ def predict(image, model_type):
48
  </div>
49
  """
50
 
51
- # Custom glowing style
52
  custom_css = """
53
  .result-box {
54
  background-color: oklch(0.718 0.202 349.761);
@@ -64,12 +67,12 @@ custom_css = """
64
  .gradio-container { max-width: 900px; margin: auto; }
65
  """
66
 
67
- # Gradio Interface
68
  with gr.Blocks(css=custom_css) as demo:
69
  gr.Markdown("## NSFW Detector (Human + Anime/Cartoon)")
70
  gr.Markdown(
71
- "Upload an image and choose the model. The system will predict whether the content is NSFW or SFW. "
72
- "This is a side project. Results may vary. No images are stored."
73
  )
74
 
75
  with gr.Row():
 
3
  import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
+ from transformers import ViTForImageClassification, ViTImageProcessor
7
 
8
+ # Load processor (same for both)
9
+ processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
10
+
11
+ # Rebuild and load Human NSFW model
12
+ human_model = ViTForImageClassification.from_pretrained(
13
+ "google/vit-base-patch16-224-in21k", num_labels=2
14
+ )
15
+ human_model.load_state_dict(torch.load("humanNsfw_Swf.pth", map_location="cpu"))
16
  human_model.eval()
17
 
18
+ # Rebuild and load Anime NSFW model
19
+ anime_model = ViTForImageClassification.from_pretrained(
20
+ "google/vit-base-patch16-224-in21k", num_labels=2
21
+ )
22
+ anime_model.load_state_dict(torch.load("animeCartoonNsfw_Sfw.pth", map_location="cpu"))
23
  anime_model.eval()
24
 
25
+ # Preprocess function
26
+ def preprocess(image: Image.Image):
27
+ inputs = processor(images=image, return_tensors="pt")
28
+ return inputs["pixel_values"]
 
29
 
30
+ # Prediction function
31
  def predict(image, model_type):
32
  if image is None:
33
  return "<div class='result-box'>Please upload an image.</div>"
34
 
35
+ inputs = preprocess(image)
36
+ model = human_model if model_type == "Human" else anime_model
37
 
38
  with torch.no_grad():
39
+ outputs = model(pixel_values=inputs)
40
+ logits = outputs.logits
41
+ probs = F.softmax(logits, dim=1)
42
+ pred_class = torch.argmax(probs, dim=1).item()
43
+ confidence = probs[0][pred_class].item()
 
 
 
 
 
 
 
 
 
44
 
45
+ label = "NSFW" if pred_class == 1 else "SFW"
46
  return f"""
47
  <div class='result-box'>
48
  <strong>Model:</strong> {model_type}<br>
 
51
  </div>
52
  """
53
 
54
+ # Custom glow box CSS
55
  custom_css = """
56
  .result-box {
57
  background-color: oklch(0.718 0.202 349.761);
 
67
  .gradio-container { max-width: 900px; margin: auto; }
68
  """
69
 
70
+ # Gradio UI
71
  with gr.Blocks(css=custom_css) as demo:
72
  gr.Markdown("## NSFW Detector (Human + Anime/Cartoon)")
73
  gr.Markdown(
74
+ "Upload an image and select the appropriate model. The system will detect whether the content is NSFW or SFW. "
75
+ "This is a side project. Results are not guaranteed. No images are stored."
76
  )
77
 
78
  with gr.Row():