Spaces:
Running
Running
#srlsy bruh... checkin the code?? | |
import html

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor
# Base checkpoint shared by the image processor and both classifiers.
_VIT_CHECKPOINT = "google/vit-base-patch16-224-in21k"

# Turns PIL images into the `pixel_values` tensor both classifiers consume.
processor = ViTImageProcessor.from_pretrained(_VIT_CHECKPOINT)


def _load_classifier(weights_path: str) -> "ViTForImageClassification":
    """Build a 2-class ViT on the base checkpoint and load fine-tuned weights.

    The base checkpoint is instantiated with a freshly initialised 2-way
    classification head, whose parameters (and the backbone's) are then
    overwritten by the state dict stored at *weights_path*.

    Args:
        weights_path: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        The model, loaded on CPU and switched to eval mode.
    """
    model = ViTForImageClassification.from_pretrained(_VIT_CHECKPOINT, num_labels=2)
    # weights_only=True limits unpickling to tensors/plain containers, so a
    # tampered .pth file cannot run arbitrary code when the Space starts.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Classifier used by predict() when the "Human" model type is selected.
human_model = _load_classifier("humanNsfw_Swf.pth")
# Classifier used by predict() for every other model type ("Anime").
anime_model = _load_classifier("animeCartoonNsfw_Sfw.pth")
def preprocess(image: Image.Image):
    """Convert a PIL image into the ``pixel_values`` tensor the ViT models take.

    Images that are not already RGB (for example RGBA PNGs with transparency,
    grayscale ``L`` or palette ``P`` images) are converted to RGB first. The
    ViT checkpoint is presumably configured for 3-channel input, and feeding
    it 1 or 4 channels makes normalisation/inference fail.

    Args:
        image: The uploaded PIL image.

    Returns:
        The ``pixel_values`` tensor produced by ``processor`` with
        ``return_tensors="pt"``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    return inputs["pixel_values"]
def predict(image, model_type):
    """Classify *image* as NSFW or SFW and render the result as an HTML card.

    Args:
        image: PIL image from the upload widget, or ``None`` when nothing is
            uploaded.
        model_type: Model selector value; ``"Human"`` uses ``human_model``,
            any other value uses ``anime_model``.

    Returns:
        An HTML snippet wrapped in ``<div class='result-box'>``. It shows the
        selected model, the predicted label and the softmax confidence of
        that label, or a prompt to upload an image when *image* is ``None``.
    """
    if image is None:
        return "<div class='result-box'>pls upload an img...</div>"

    pixel_values = preprocess(image)
    model = human_model if model_type == "Human" else anime_model
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits

    probs = F.softmax(logits, dim=1)
    pred_class = torch.argmax(probs, dim=1).item()
    # Single-image batch: row 0 holds this image's class probabilities.
    confidence = probs[0][pred_class].item()
    # NOTE(review): assumes class index 0 = NSFW in the fine-tuned heads —
    # confirm against the training label order.
    label = "NSFW" if pred_class == 0 else "SFW"

    # model_type arrives from the client and the output component renders raw
    # HTML, so escape it before interpolating (str() keeps None -> "None").
    safe_model_type = html.escape(str(model_type))
    return f"""
<div class='result-box'>
<strong>Model:</strong> {safe_model_type}<br>
<strong>Prediction:</strong> {label}<br>
<strong>Confidence:</strong> {confidence:.2%}
</div>
"""
custom_css = """ | |
.result-box { | |
position: relative; | |
background-color: black; | |
padding: 20px; | |
border-radius: 12px; | |
color: white; | |
font-size: 1.2rem; | |
text-align: center; | |
font-weight: bold; | |
width: 100%; | |
z-index: 1; | |
overflow: hidden; | |
box-shadow: 0 0 15px oklch(0.718 0.202 349.761); | |
} | |
.result-box::before { | |
content: ""; | |
position: absolute; | |
top: -4px; | |
left: -4px; | |
right: -4px; | |
bottom: -4px; | |
background: conic-gradient(from 0deg, oklch(0.718 0.202 349.761), transparent 40%, oklch(0.718 0.202 349.761)); | |
border-radius: 16px; | |
animation: spin 3s linear infinite; | |
z-index: -1; | |
filter: blur(8px); | |
} | |
@keyframes spin { | |
0% { transform: rotate(0deg); } | |
100% { transform: rotate(360deg); } | |
} | |
.disclaimer { | |
color: white; | |
font-size: 0.9rem; | |
text-align: center; | |
margin-top: 40px; | |
text-shadow: 0 0 10px oklch(0.718 0.202 349.761); | |
} | |
.gradio-container { | |
max-width: 900px; | |
margin: auto; | |
} | |
""" | |
# ui
# Two-column layout: model selector and image upload on the left, result card
# on the right, with a disclaimer underneath. predict() re-runs whenever
# either input changes.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## NSFW Detector (Human and Anime/Cartoon)")
    gr.Markdown(
        "Upload an img and select the appropriate model. The system will detect whether the content is NSFW or SFW."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Keep these choices in sync with predict(), which routes "Human"
            # to human_model and any other value to anime_model.
            model_choice = gr.Radio(["Human", "Anime"], label="Select Model Type", value="Human")
            # type="pil" makes the component hand predict() a PIL image.
            image_input = gr.Image(type="pil", label="Upload Image")
        with gr.Column(scale=1):
            output_box = gr.HTML("<div class='result-box'>Awaiting input...</div>")
    # Re-classify when a new image arrives and when the model type switches.
    image_input.change(fn=predict, inputs=[image_input, model_choice], outputs=output_box)
    model_choice.change(fn=predict, inputs=[image_input, model_choice], outputs=output_box)
    # Disclaimer with glow (styled by the .disclaimer rule in custom_css).
    gr.Markdown(
        "<div class='disclaimer'>This is a side project. Results are not guaranteed. No images are stored. For more info, pls check readme file</div>"
    )
if __name__ == "__main__": | |
demo.launch() | |