File size: 4,194 Bytes
6117b47
7c472b7
 
 
6117b47
7c472b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch

# Load the model and image processor once, at import time, so that every
# prediction request reuses the same objects.
#
# `trust_remote_code=True` is deliberately NOT passed: that flag allows
# `from_pretrained` to execute Python code shipped inside the Hub repository,
# and Hub repositories are community uploads, so it should only be enabled for
# a repository whose code has been reviewed.
# NOTE(review): this assumes the checkpoint uses an architecture that ships
# with transformers (so no repository code is needed) - confirm it loads.
model_name = "Falconsai/nsfw_image_detection"
try:
    model = AutoModelForImageClassification.from_pretrained(model_name)
    processor = AutoImageProcessor.from_pretrained(model_name)
except Exception as e:
    print(f"Error loading model or processor: {e}")
    # Re-raise: without a model the app cannot serve any request, so fail at
    # startup rather than on the first prediction.
    raise
else:
    print(f"Successfully loaded model and processor for {model_name}")

# Define the prediction function
def predict_nsfw(image_pil):
    """
    Classify an image with the module-level model and return per-class scores.

    The image is converted to RGB if needed, preprocessed with the module-level
    `processor`, run through `model` without gradient tracking, and the logits
    are turned into probabilities with a softmax over the class dimension.

    Args:
        image_pil (PIL.Image.Image): The input image in PIL format, or None
            when no image was supplied.

    Returns:
        dict: A dictionary with class labels as keys and probabilities
              (floats) as values.
              Example: {'nsfw': 0.95, 'normal': 0.05}

    Raises:
        gr.Error: If no image was provided, or if preprocessing or inference
            fails. Failures are raised rather than returned because the
            result feeds a `gr.Label`, whose dict form maps labels to numeric
            confidences; a dict carrying a message string is not a valid
            payload for it.
    """
    if image_pil is None:
        raise gr.Error("No image provided. Please upload an image.")

    try:
        # Ensure image is in RGB format (e.g. RGBA, palette or grayscale input).
        if image_pil.mode != "RGB":
            image_pil = image_pil.convert("RGB")

        # Preprocess the image into PyTorch tensors.
        inputs = processor(images=image_pil, return_tensors="pt")

        # Gradients are not needed for inference.
        with torch.no_grad():
            logits = model(**inputs).logits

        # Softmax over the class dimension; one image in, so take row 0 and
        # convert it to plain Python floats.
        probabilities = torch.softmax(logits, dim=1)[0].tolist()

        # Prefer the label names stored in the model config.
        labels = model.config.id2label
        if not labels:  # Fallback if id2label is not in config
            print("Warning: model.config.id2label not found. Using default labels ['normal', 'nsfw']")
            labels = {0: 'normal', 1: 'nsfw'}  # Common convention

        # Use a generic name for any class index missing from the label map
        # instead of failing with a KeyError.
        results = {
            labels.get(idx, f"LABEL_{idx}"): prob
            for idx, prob in enumerate(probabilities)
        }
    except Exception as e:
        print(f"Error during prediction: {e}")
        # Surface the failure in the UI; keep the original exception chained
        # so the server log shows the root cause.
        raise gr.Error(f"An error occurred during processing: {str(e)}") from e

    print(f"Results: {results}")
    return results

# Static text passed to the interface: title, description, and an HTML
# snippet linking to the model card.
title = "NSFW Image Detection (Falconsai)"
description = " ".join(
    [
        "Upload an image to classify it as NSFW (Not Safe For Work) or Normal.",
        "This app uses the `Falconsai/nsfw_image_detection` model from Hugging Face.",
    ]
)
article = "".join(
    [
        "<p style='text-align: center;'>",
        "<a href='https://huggingface.co/Falconsai/nsfw_image_detection' target='_blank'>",
        "Model Card: Falconsai/nsfw_image_detection",
        "</a>",
        "</p>",
    ]
)

# Input: an uploaded image handed to the prediction function as a PIL image.
image_input = gr.Image(label="Upload Image", type="pil")
# Output: a label widget limited to the two highest-scoring classes.
output_label = gr.Label(label="Prediction", num_top_classes=2)

# Build the app object. No example images are bundled; to add some, list their
# paths in `examples`, e.g. ["path/to/example_normal.jpg", "path/to/example_nsfw.jpg"].
iface = gr.Interface(
    title=title,
    description=description,
    article=article,
    fn=predict_nsfw,
    inputs=image_input,
    outputs=output_label,
    examples=[],
)

# Start the server only when this file is executed directly.
# For a temporary public link during local testing, use iface.launch(share=True).
if __name__ == "__main__":
    iface.launch()