#srlsy bruh... checkin the code??
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
# Hugging Face checkpoint that supplies both the image processor and the ViT
# backbone architecture; the fine-tuned weights come from the local .pth files.
_BASE_CHECKPOINT = "google/vit-base-patch16-224-in21k"

# Shared processor: converts PIL images into the ``pixel_values`` tensor the
# ViT models below consume (see ``preprocess``).
processor = ViTImageProcessor.from_pretrained(_BASE_CHECKPOINT)


def _load_classifier(weights_path):
    """Build a 2-label ViT classifier and load fine-tuned weights into it.

    The architecture/config is taken from ``_BASE_CHECKPOINT`` with
    ``num_labels=2``; the state dict at *weights_path* is loaded onto the CPU,
    and the model is switched to eval mode before being returned.

    NOTE: ``torch.load`` unpickles the file — only load checkpoints you trust.
    """
    model = ViTForImageClassification.from_pretrained(_BASE_CHECKPOINT, num_labels=2)
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    model.eval()  # inference only: disable dropout etc.
    return model


# One classifier per content domain; ``predict`` picks between them.
human_model = _load_classifier("humanNsfw_Swf.pth")
anime_model = _load_classifier("animeCartoonNsfw_Sfw.pth")
def preprocess(image: Image.Image):
    """Return the ``pixel_values`` tensor for *image*, ready to feed a ViT model.

    Thin wrapper around the module-level ``processor``, requesting PyTorch
    tensors (``return_tensors="pt"``).
    """
    encoded = processor(images=image, return_tensors="pt")
    pixel_values = encoded["pixel_values"]
    return pixel_values
def predict(image, model_type):
if image is None:
return "
pls upload an img...
"
inputs = preprocess(image)
model = human_model if model_type == "Human" else anime_model
with torch.no_grad():
outputs = model(pixel_values=inputs)
logits = outputs.logits
probs = F.softmax(logits, dim=1)
pred_class = torch.argmax(probs, dim=1).item()
confidence = probs[0][pred_class].item()
label = "NSFW" if pred_class == 0 else "SFW"
return f"""
Model: {model_type}
Prediction: {label}
Confidence: {confidence:.2%}
"""
# Stylesheet passed to gr.Blocks(css=...) for the UI.
#   .result-box          - black, rounded, centered, bold panel with a pink
#                          (oklch) box-shadow glow; presumably meant for the
#                          prediction output.
#   .result-box::before  - blurred conic-gradient layer extending 4px past the
#                          box on every side, rotated continuously by `spin`.
#                          NOTE(review): .result-box has position:relative and
#                          z-index:1, so it forms a stacking context and this
#                          z-index:-1 layer paints above the box's black
#                          background (but below its text). Combined with
#                          overflow:hidden, the rotating gradient may fill the
#                          whole panel rather than only its edge; confirm this
#                          is the intended look.
#   @keyframes spin      - one full 360deg rotation per animation cycle.
#   .disclaimer          - small, centered white text with a pink text-shadow.
#   .gradio-container    - caps the app width at 900px and centers it.
# The class rules only take effect on components whose HTML uses those names.
custom_css = """
.result-box {
position: relative;
background-color: black;
padding: 20px;
border-radius: 12px;
color: white;
font-size: 1.2rem;
text-align: center;
font-weight: bold;
width: 100%;
z-index: 1;
overflow: hidden;
box-shadow: 0 0 15px oklch(0.718 0.202 349.761);
}
.result-box::before {
content: "";
position: absolute;
top: -4px;
left: -4px;
right: -4px;
bottom: -4px;
background: conic-gradient(from 0deg, oklch(0.718 0.202 349.761), transparent 40%, oklch(0.718 0.202 349.761));
border-radius: 16px;
animation: spin 3s linear infinite;
z-index: -1;
filter: blur(8px);
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.disclaimer {
color: white;
font-size: 0.9rem;
text-align: center;
margin-top: 40px;
text-shadow: 0 0 10px oklch(0.718 0.202 349.761);
}
.gradio-container {
max-width: 900px;
margin: auto;
}
"""
# ui: model selector + image upload on the left, styled result on the right.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## NSFW Detector (Human and Anime/Cartoon)")
    gr.Markdown(
        "Upload an img and select the appropriate model. The system will detect whether the content is NSFW or SFW."
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Radio(["Human", "Anime"], label="Select Model Type", value="Human")
            image_input = gr.Image(type="pil", label="Upload Image")
        with gr.Column(scale=1):
            # Placeholder uses the same .result-box styling as predict()'s output.
            output_box = gr.HTML('<div class="result-box">Awaiting input...</div>')
    # Re-run the classifier whenever the image or the selected model changes.
    image_input.change(fn=predict, inputs=[image_input, model_choice], outputs=output_box)
    model_choice.change(fn=predict, inputs=[image_input, model_choice], outputs=output_box)
    # Disclaimer with glow (the text-shadow comes from the .disclaimer CSS class).
    # NOTE(review): Gradio itself may keep uploaded files in its temp/cache
    # directory; confirm the "No images are stored" claim for this deployment.
    gr.HTML(
        '<div class="disclaimer">This is a side project. Results are not guaranteed. '
        "No images are stored. For more info, pls check readme file</div>"
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()