import logging

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the model and processor once at start-up so every request reuses them.
# NOTE(review): trust_remote_code=True allows the model repository to execute
# arbitrary Python on this machine. Keep it only if this model really ships
# custom modeling code; otherwise drop the flag.
model_name = "Falconsai/nsfw_image_detection"
try:
    model = AutoModelForImageClassification.from_pretrained(
        model_name, trust_remote_code=True
    )
    processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True)
except Exception:
    # The app cannot serve anything without the model: log the traceback and
    # abort start-up so the failure is visible.
    logger.exception("Error loading model or processor for %s", model_name)
    raise
logger.info("Successfully loaded model and processor for %s", model_name)

# Labels used when the model config provides no id2label mapping.
_DEFAULT_LABELS = {0: "normal", 1: "nsfw"}


def predict_nsfw(image_pil):
    """Classify an image as NSFW or normal.

    Args:
        image_pil (PIL.Image.Image): The input image, as delivered by
            ``gr.Image(type="pil")``; ``None`` when nothing was uploaded.

    Returns:
        dict: Maps each class label to its probability, e.g.
        ``{'nsfw': 0.95, 'normal': 0.05}``. Label names are taken from
        ``model.config.id2label`` when available.

    Raises:
        gr.Error: If no image was provided or preprocessing/inference fails.
            Gradio displays the message to the user. Errors are raised rather
            than returned because ``gr.Label`` expects numeric confidences and
            cannot render a dict with string values.
    """
    if image_pil is None:
        raise gr.Error("No image provided. Please upload an image.")

    try:
        # The processor expects 3-channel input; uploads may be RGBA, L, P, ...
        if image_pil.mode != "RGB":
            image_pil = image_pil.convert("RGB")
            logger.debug("Converted image to RGB")

        inputs = processor(images=image_pil, return_tensors="pt")

        with torch.no_grad():  # inference only; gradients are not needed
            logits = model(**inputs).logits
        logger.debug("Logits: %s", logits)

        # Softmax over the class axis; [0] selects the single image in the batch.
        probabilities = torch.softmax(logits, dim=-1)[0]
        logger.debug("Probabilities: %s", probabilities)
    except Exception as exc:
        # Top-level request boundary: keep the traceback in the logs and show
        # a readable message in the UI.
        logger.exception("Error during prediction")
        raise gr.Error(f"An error occurred during processing: {exc}") from exc

    # getattr makes the fallback work even if the config has no id2label
    # attribute at all (direct attribute access would raise instead).
    labels = getattr(model.config, "id2label", None)
    if not labels:
        logger.warning(
            "model.config.id2label not found. Using default labels ['normal', 'nsfw']"
        )
        labels = _DEFAULT_LABELS

    # Unknown indices get a generic name instead of raising KeyError.
    results = {
        labels.get(i, f"LABEL_{i}"): prob
        for i, prob in enumerate(probabilities.tolist())
    }
    logger.debug("Results: %s", results)
    return results


# --- Gradio interface --------------------------------------------------------
title = "NSFW Image Detection (Falconsai)"
description = (
    "Upload an image to classify it as NSFW (Not Safe For Work) or Normal. "
    "This app uses the `Falconsai/nsfw_image_detection` model from Hugging Face."
)
# Footer rendered below the interface.
# TODO(review): confirm the intended footer content; restored here as a link
# to the model card.
article = (
    "<p style='text-align: center'>"
    f"<a href='https://huggingface.co/{model_name}' target='_blank'>"
    "Model card on Hugging Face</a></p>"
)

image_input = gr.Image(type="pil", label="Upload Image")
# Show both classes together with their scores.
output_label = gr.Label(num_top_classes=2, label="Prediction")

iface = gr.Interface(
    fn=predict_nsfw,
    inputs=image_input,
    outputs=output_label,
    title=title,
    description=description,
    article=article,
    # To show sample inputs, add one inner list per example (this interface
    # has a single input), e.g.:
    # examples=[["path/to/example_normal.jpg"], ["path/to/example_nsfw.jpg"]],
)

if __name__ == "__main__":
    # When running locally, iface.launch(share=True) creates a temporary
    # public link.
    iface.launch()