import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch

# Load the model and processor.
# trust_remote_code=True allows the repository to ship custom model code;
# only enable it for models whose publishers you trust.
model_name = "Falconsai/nsfw_image_detection"
try:
    model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)
    processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True)
    print(f"Successfully loaded model and processor for {model_name}")
except Exception as e:
    print(f"Error loading model or processor: {e}")
    # Re-raise so the Space fails loudly at startup instead of serving a broken app.
    raise
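
# Explicitly put the model in evaluation mode. from_pretrained already returns
# an eval-mode model, so this is just a safeguard against dropout or batch-norm
# layers behaving as if training.
model.eval()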

# Define the prediction function
def predict_nsfw(image_pil):
    """
    Predict whether an image is NSFW or normal.

    Args:
        image_pil (PIL.Image.Image): The input image in PIL format.

    Returns:
        dict: Class labels mapped to their probabilities,
              e.g. {'nsfw': 0.95, 'normal': 0.05}.
    """
    if image_pil is None:
        return {"Error": "No image provided. Please upload an image."}
    try:
        # Ensure the image is in RGB format (the processor expects 3 channels)
        if image_pil.mode != "RGB":
            image_pil = image_pil.convert("RGB")
            print("Converted image to RGB")

        # Preprocess the image into model-ready tensors
        inputs = processor(images=image_pil, return_tensors="pt")
        print("Image processed")

        # Run inference without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
        print(f"Logits: {logits}")

        # Softmax over the class dimension turns logits into probabilities
        probabilities = torch.softmax(logits, dim=1)
        print(f"Probabilities: {probabilities}")
        # Map class indices to label names from the model's config.
        # For this model, id2label is {0: 'normal', 1: 'nsfw'}.
        labels = model.config.id2label
        if not labels:  # Fallback if id2label is missing from the config
            print("Warning: model.config.id2label not found. Using default labels ['normal', 'nsfw']")
            labels = {0: "normal", 1: "nsfw"}

        # Build a {label: probability} dictionary from the softmax output
        results = {labels[i]: prob.item() for i, prob in enumerate(probabilities[0])}
        print(f"Results: {results}")
        return results
    except Exception as e:
        print(f"Error during prediction: {e}")
        return {"Error": f"An error occurred during processing: {str(e)}"}

# Create the Gradio interface
title = "NSFW Image Detection (Falconsai)"
description = (
    "Upload an image to classify it as NSFW (Not Safe For Work) or Normal. "
    "This app uses the `Falconsai/nsfw_image_detection` model from Hugging Face."
)
article = (
    "<p style='text-align: center;'>"
    "<a href='https://huggingface.co/Falconsai/nsfw_image_detection' target='_blank'>"
    "Model Card: Falconsai/nsfw_image_detection"
    "</a>"
    "</p>"
)

# Define the input and output components for Gradio
image_input = gr.Image(type="pil", label="Upload Image")
output_label = gr.Label(num_top_classes=2, label="Prediction")  # Shows both classes with their scores
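# gr.Label renders a {label: confidence} mapping, which is exactly the dict
# predict_nsfw returns, so no adapter is needed between function and component.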

# Assemble the Gradio interface
iface = gr.Interface(
    fn=predict_nsfw,
    inputs=image_input,
    outputs=output_label,
    title=title,
    description=description,
    article=article,
    # Add paths to sample images bundled with the Space to populate the
    # examples gallery, e.g. examples=["examples/normal.jpg"].
    examples=None,
)

if __name__ == "__main__":
    # To test locally before pushing to Spaces:
    # iface.launch(share=True)  # share=True creates a temporary public link
    iface.launch()