kaisex commited on
Commit
55b1acf
·
verified ·
1 Parent(s): bcf874d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -9
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn.functional as F
@@ -5,32 +7,28 @@ from torchvision import transforms
5
  from PIL import Image
6
  from transformers import ViTForImageClassification, ViTImageProcessor
7
 
8
- # Load processor (same for both)
9
  processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
10
 
11
- # Rebuild and load Human NSFW model
12
  human_model = ViTForImageClassification.from_pretrained(
13
  "google/vit-base-patch16-224-in21k", num_labels=2
14
  )
15
  human_model.load_state_dict(torch.load("humanNsfw_Swf.pth", map_location="cpu"))
16
  human_model.eval()
17
 
18
- # Rebuild and load Anime NSFW model
19
  anime_model = ViTForImageClassification.from_pretrained(
20
  "google/vit-base-patch16-224-in21k", num_labels=2
21
  )
22
  anime_model.load_state_dict(torch.load("animeCartoonNsfw_Sfw.pth", map_location="cpu"))
23
  anime_model.eval()
24
 
25
- # Preprocess function
26
  def preprocess(image: Image.Image):
27
  inputs = processor(images=image, return_tensors="pt")
28
  return inputs["pixel_values"]
29
 
30
- # Prediction function
31
  def predict(image, model_type):
32
  if image is None:
33
- return "<div class='result-box'>Please upload an image.</div>"
34
 
35
  inputs = preprocess(image)
36
  model = human_model if model_type == "Human" else anime_model
@@ -50,8 +48,6 @@ def predict(image, model_type):
50
  <strong>Confidence:</strong> {confidence:.2%}
51
  </div>
52
  """
53
-
54
- # Custom glow box CSS
55
  custom_css = """
56
  .result-box {
57
  position: relative;
@@ -102,7 +98,7 @@ custom_css = """
102
  """
103
 
104
 
105
- # Gradio UI
106
  with gr.Blocks(css=custom_css) as demo:
107
  gr.Markdown("## NSFW Detector (Human and Anime/Cartoon)")
108
  gr.Markdown(
 
1
+ #srlsy bruh... checkin the code??
2
+
3
  import gradio as gr
4
  import torch
5
  import torch.nn.functional as F
 
7
  from PIL import Image
8
  from transformers import ViTForImageClassification, ViTImageProcessor
9
 
 
10
  processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
11
 
 
12
  human_model = ViTForImageClassification.from_pretrained(
13
  "google/vit-base-patch16-224-in21k", num_labels=2
14
  )
15
  human_model.load_state_dict(torch.load("humanNsfw_Swf.pth", map_location="cpu"))
16
  human_model.eval()
17
 
 
18
  anime_model = ViTForImageClassification.from_pretrained(
19
  "google/vit-base-patch16-224-in21k", num_labels=2
20
  )
21
  anime_model.load_state_dict(torch.load("animeCartoonNsfw_Sfw.pth", map_location="cpu"))
22
  anime_model.eval()
23
 
24
+
25
  def preprocess(image: Image.Image):
26
  inputs = processor(images=image, return_tensors="pt")
27
  return inputs["pixel_values"]
28
 
 
29
  def predict(image, model_type):
30
  if image is None:
31
+ return "<div class='result-box'>pls upload an img.</div>"
32
 
33
  inputs = preprocess(image)
34
  model = human_model if model_type == "Human" else anime_model
 
48
  <strong>Confidence:</strong> {confidence:.2%}
49
  </div>
50
  """
 
 
51
  custom_css = """
52
  .result-box {
53
  position: relative;
 
98
  """
99
 
100
 
101
+ # ui
102
  with gr.Blocks(css=custom_css) as demo:
103
  gr.Markdown("## NSFW Detector (Human and Anime/Cartoon)")
104
  gr.Markdown(