import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

model_name = "Falconsai/nsfw_image_detection"

# Load the classifier and its image processor once at startup; fail fast if
# the download or initialization fails.
try:
    model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)
    processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True)
    print(f"Successfully loaded model and processor for {model_name}")
except Exception as e:
    print(f"Error loading model or processor: {e}")
    raise
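
# Optional sketch (an assumption, not part of the original app): run inference
# on a GPU when one is available. If enabled, the processor outputs must be
# moved to the same device inside predict_nsfw, e.g.
#   inputs = {k: v.to(device) for k, v in inputs.items()}
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)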


def predict_nsfw(image_pil):
    """
    Predict whether an image is NSFW or normal.

    Args:
        image_pil (PIL.Image.Image): The input image in PIL format.

    Returns:
        dict: Class labels mapped to their probabilities,
              e.g. {'nsfw': 0.95, 'normal': 0.05}.
    """
    if image_pil is None:
        return {"Error": "No image provided. Please upload an image."}

    try:
        # The processor expects RGB input; convert other modes (e.g. RGBA, L).
        if image_pil.mode != "RGB":
            image_pil = image_pil.convert("RGB")
            print("Converted image to RGB")

        # Resize and normalize the image into the model's input tensors.
        inputs = processor(images=image_pil, return_tensors="pt")
        print("Image processed")

        # Forward pass with gradient tracking disabled (inference only).
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            print(f"Logits: {logits}")

        # Convert raw logits into class probabilities.
        probabilities = torch.softmax(logits, dim=1)
        print(f"Probabilities: {probabilities}")

        # Map class indices to human-readable labels from the model config.
        labels = model.config.id2label
        if not labels:
            print("Warning: model.config.id2label not found. Using default labels ['normal', 'nsfw']")
            labels = {0: "normal", 1: "nsfw"}

        # gr.Label expects a {label: probability} dictionary.
        results = {labels[i]: prob.item() for i, prob in enumerate(probabilities[0])}
        print(f"Results: {results}")
        return results

    except Exception as e:
        print(f"Error during prediction: {e}")
        return {"Error": f"An error occurred during processing: {e}"}


# UI copy for the Gradio interface.
title = "NSFW Image Detection (Falconsai)"
description = (
    "Upload an image to classify it as NSFW (Not Safe For Work) or Normal. "
    "This app uses the `Falconsai/nsfw_image_detection` model from Hugging Face."
)
article = (
    "<p style='text-align: center;'>"
    "<a href='https://huggingface.co/Falconsai/nsfw_image_detection' target='_blank'>"
    "Model Card: Falconsai/nsfw_image_detection"
    "</a>"
    "</p>"
)

# Gradio components: a PIL image in, the top-2 class probabilities out.
image_input = gr.Image(type="pil", label="Upload Image")
output_label = gr.Label(num_top_classes=2, label="Prediction")

iface = gr.Interface(
    fn=predict_nsfw,
    inputs=image_input,
    outputs=output_label,
    title=title,
    description=description,
    article=article,
)
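
# launch() serves the app locally by default; passing share=True would also
# create a temporary public URL (standard Gradio behavior).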

if __name__ == "__main__":
    iface.launch()