File size: 7,387 Bytes
43655b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27b4f7a
43655b0
 
 
2b933d2
43655b0
 
 
 
 
 
 
 
 
 
 
2b933d2
43655b0
 
 
2b933d2
43655b0
 
 
 
 
27b4f7a
43655b0
727b540
43655b0
 
2b933d2
43655b0
 
2b933d2
27b4f7a
 
 
 
43655b0
727b540
 
2b933d2
 
727b540
2b933d2
727b540
 
 
 
 
2b933d2
 
 
727b540
 
 
 
 
 
2b933d2
 
727b540
2b933d2
727b540
2b933d2
727b540
 
2b933d2
727b540
 
2b933d2
 
727b540
 
27b4f7a
 
 
2b933d2
 
 
27b4f7a
 
2b933d2
727b540
2b933d2
 
727b540
2b933d2
 
727b540
 
2b933d2
 
 
 
 
727b540
 
 
 
2b933d2
727b540
 
2b933d2
 
 
 
 
 
 
 
727b540
 
 
 
c892189
 
 
2b933d2
727b540
 
43655b0
727b540
 
 
 
2b933d2
c892189
727b540
 
 
 
 
2b933d2
727b540
2b933d2
727b540
 
27b4f7a
727b540
 
 
 
 
 
 
2b933d2
727b540
 
 
 
 
2b933d2
727b540
 
 
 
2b933d2
727b540
c892189
727b540
 
 
c892189
 
727b540
c892189
27b4f7a
2b933d2
727b540
 
 
2b933d2
27b4f7a
 
727b540
43655b0
 
 
2b933d2
43655b0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import AutoImageProcessor, SiglipForImageClassification
import os
import warnings

# --- Configuration ---
# Hugging Face Hub repo id; used for both the image processor and the model weights.
MODEL_IDENTIFIER = "Ateeqq/ai-vs-human-image-detector"
# Prefer CUDA whenever torch can see a GPU; fall back to the CPU otherwise.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Suppress specific warnings ---
# `message` is a regex that must match the start of the warning text.
warnings.filterwarnings(action="ignore", message="Possibly corrupt EXIF data.")
warnings.filterwarnings(
    action="ignore",
    message=".*You are using the default legacy behaviour.*",
)

# --- Load Model and Processor (Load once at startup) ---
# Runs at import time, so every request handled by classify_image reuses the
# same processor and weights instead of reloading them.
print(f"Using device: {DEVICE}")
print(f"Loading processor from: {MODEL_IDENTIFIER}")
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_IDENTIFIER)
    print(f"Loading model from: {MODEL_IDENTIFIER}")
    model = SiglipForImageClassification.from_pretrained(MODEL_IDENTIFIER)
    model.to(DEVICE)
    model.eval()  # inference only: switch layers such as dropout to eval behaviour
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"FATAL: Error loading model or processor: {e}")
    # gr.Error is meant to be raised inside event handlers so Gradio can show it
    # in the browser. At import time no UI exists yet, so abort startup with a
    # plain exception that keeps the original cause chained.
    raise RuntimeError(f"Failed to load the model: {e}. Cannot start the application.") from e

# --- Prediction Function ---
def classify_image(image_pil):
    """Classify an image as AI-generated or human-made.

    Args:
        image_pil: The image from the input component (configured with
            ``type="pil"``), or ``None`` when the input is empty or cleared.

    Returns:
        A ``{label: probability}`` dict with one entry per class index of the
        model's softmax output (labels from ``model.config.id2label``),
        probabilities rounded to 4 decimals. Returns an empty dict when no
        image is provided.

    Raises:
        gr.Error: If preprocessing or inference fails. Gradio displays the
            message to the user. Returning an error dict instead would be fed
            to ``gr.Label``, which expects float confidences.
    """
    if image_pil is None:
        print("Warning: No image provided.")
        return {}

    print("Processing image...")
    try:
        # Normalize the mode (e.g. RGBA, grayscale or palette uploads) to RGB.
        image = image_pil.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)

        print("Running inference...")
        with torch.no_grad():
            logits = model(**inputs).logits

        # One image per call, so take row 0. tolist() copies the whole row to
        # Python floats in one step instead of one .item() call per class.
        probabilities = torch.softmax(logits, dim=-1)[0].tolist()
        results = {
            model.config.id2label[i]: round(prob, 4)
            for i, prob in enumerate(probabilities)
        }
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise gr.Error("Processing failed. Please try again or use a different image.") from e

    print(f"Prediction results: {results}")
    return results

# --- Define Example Images ---
# File extensions (lower-case) accepted as example images.
EXAMPLE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')


def find_example_images(directory):
    """Return sorted paths of the example image files directly inside *directory*.

    Only regular files whose extension (case-insensitive) is listed in
    ``EXAMPLE_EXTENSIONS`` are kept, and the scan is not recursive. A status
    line is printed in every case. A missing path, a non-directory path, or an
    unreadable directory yields an empty list instead of raising, because
    examples are optional.

    Args:
        directory: Path of the folder to scan.

    Returns:
        A list of ``os.path.join(directory, name)`` paths sorted by file name,
        so the examples gallery order is stable across runs.
    """
    try:
        names = sorted(os.listdir(directory))
    except OSError:
        # Missing path, a regular file with that name, or no permission.
        names = []
    if not names:
        print(f"No '{directory}' directory found or it's empty. Examples will not be shown.")
        return []

    images = [
        os.path.join(directory, name)
        for name in names
        if name.lower().endswith(EXAMPLE_EXTENSIONS)
        and os.path.isfile(os.path.join(directory, name))
    ]
    if images:
        print(f"Found examples: {images}")
    else:
        print(f"No valid image files found in '{directory}' directory.")
    return images


example_dir = "examples"
example_images = find_example_images(example_dir)

# --- Custom CSS for Dark Theme Adjustments ---
# Stylesheet passed to gr.Blocks(css=...). Most colors are deliberately left to
# the theme. The explicit colors that remain (code text, device name, column
# borders, footer border and links) are light-on-dark values meant for a dark
# background, per the inline CSS comments.
# The `#...` selectors target elem_id values set on components in the layout:
# app-title, app-description, results-heading, input-column, output-column,
# prediction-label and app-footer.
# NOTE(review): '.label-name' and '.confidence' are assumed to be class names in
# gr.Label's rendered markup; confirm against the installed Gradio version.
css = """
body { font-family: 'Inter', sans-serif; }

/* Style the main title */
#app-title {
    text-align: center;
    font-weight: bold;
    font-size: 2.5em;
    margin-bottom: 5px;
    /* color removed - let theme handle */
}

/* Style the description */
#app-description {
    text-align: center;
    font-size: 1.1em;
    margin-bottom: 25px;
    /* color removed - let theme handle */
}
#app-description code { /* Style model name - theme might handle this, but can force */
    font-weight: bold;
    background-color: rgba(255, 255, 255, 0.1); /* Slightly lighter background for code */
    padding: 2px 5px;
    border-radius: 4px;
    color: #c5f7dc; /* Light green text for code block */
}
#app-description strong { /* Style device name */
    color: #2dd4bf; /* Brighter teal/emerald for dark theme */
    font-weight: bold;
}

/* Style the results heading */
#results-heading {
    text-align: center;
    font-size: 1.2em;
    margin-bottom: 10px;
    /* color removed - let theme handle */
}

/* Add some definition to input/output columns if needed */
#input-column, #output-column {
    border: 1px solid #4b5563; /* Darker border for dark theme */
    border-radius: 12px;
    padding: 20px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow, works on dark too */
    /* background-color removed - let theme handle */
}

/* Ensure label text inside columns is readable */
#prediction-label .label-name { font-weight: bold; font-size: 1.1em; }
#prediction-label .confidence { font-size: 1em; }


/* Footer styling */
#app-footer {
    margin-top: 40px;
    padding-top: 20px;
    border-top: 1px solid #374151; /* Darker border for footer */
    text-align: center;
    font-size: 0.9em;
    /* color removed - let theme handle */
}
#app-footer a {
    color: #60a5fa; /* Lighter blue for links */
    text-decoration: none;
}
#app-footer a:hover {
    text-decoration: underline;
}
"""

# --- Gradio Interface using Blocks and Theme ---
# A string `theme` must name a built-in Gradio theme (e.g. "soft") or a Hugging
# Face Hub theme repo. "soft/dark" is not a built-in name, so Gradio tried to
# fetch it from the Hub and, on failure, warned and fell back to the default
# theme. Pass the built-in Soft theme object instead.
# Light/dark mode is not part of the theme object: Gradio picks it from the
# browser preference or the `__theme` URL query parameter. The custom CSS uses
# colors meant for a dark background, so this page-load script pins
# `?__theme=dark`. It reloads only when the parameter is missing, so it cannot loop.
FORCE_DARK_MODE_JS = """
function forceDarkMode() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.href = url.href;
    }
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css, js=FORCE_DARK_MODE_JS) as iface:
    # Title and description
    gr.Markdown("# AI vs Human Image Detector", elem_id="app-title")
    gr.Markdown(
        "Upload an image to classify if it was likely generated by AI or created by a human. "
        f"Uses the `{MODEL_IDENTIFIER}` model. Running on **{str(DEVICE).upper()}**.",
        elem_id="app-description"
    )

    # Main layout: input column on the left, prediction column on the right.
    with gr.Row(variant='panel'):
        with gr.Column(scale=1, min_width=300, elem_id="input-column"):
            image_input = gr.Image(
                type="pil",  # classify_image receives a PIL image (or None)
                label="🖼️ Upload Your Image",
                sources=["upload", "webcam", "clipboard"],
                height=400,
            )
            submit_button = gr.Button("🔍 Classify Image", variant="primary")

        with gr.Column(scale=1, min_width=300, elem_id="output-column"):
            gr.Markdown("📊 **Prediction Results**", elem_id="results-heading")
            result_output = gr.Label(
                num_top_classes=2,
                label="Classification",
                elem_id="prediction-label"
            )

    # Examples section, shown only when image files were found on disk.
    # cache_examples=True runs classify_image on every example ahead of time.
    if example_images:
        gr.Examples(
            examples=example_images,
            inputs=image_input,
            outputs=result_output,
            fn=classify_image,
            cache_examples=True,
            label="✨ Click an Example to Try!"
        )

    # Footer / article section
    gr.Markdown(f"""
       ---
       **How it Works:**
       This application uses a fine-tuned [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) vision model
       specifically trained to differentiate between images generated by Artificial Intelligence and those created by humans.

       **Model:**
       *   You can find the model card here: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>{MODEL_IDENTIFIER}</a>

       **Training Code:**
       Fine tuning code available at [https://exnrt.com/blog/ai/fine-tuning-siglip2/](https://exnrt.com/blog/ai/fine-tuning-siglip2/).
       """,
        elem_id="app-footer"
    )

    # Classify on button press and whenever the input image changes. Clearing the
    # input passes None, which classify_image answers with an empty dict.
    submit_button.click(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_image_button")
    image_input.change(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_image_change")

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # When run as a script, launch() blocks the main thread while the server is
    # running (prevent_thread_lock defaults to False). Execution only continues
    # past this call after the server shuts down.
    iface.launch()
    print("Gradio interface closed.")