# app.py (Hugging Face Space by Ateeqq)
# Last commit: "Update app.py" (c892189, verified), 7.39 kB
import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import AutoImageProcessor, SiglipForImageClassification
import os
import warnings
# --- Configuration ---
# Hugging Face Hub repository id used for both the image processor and the
# classification model.
MODEL_IDENTIFIER = "Ateeqq/ai-vs-human-image-detector"
# Use the GPU when PyTorch reports one is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# --- Suppress specific warnings ---
# Filters match on the start of the warning message (regex).
# NOTE(review): the EXIF message presumably comes from PIL while decoding
# uploads, and the "legacy behaviour" one from transformers — confirm.
warnings.filterwarnings(action="ignore", message="Possibly corrupt EXIF data.")
warnings.filterwarnings(action="ignore", message=".*You are using the default legacy behaviour.*")
# --- Load Model and Processor (Load once at startup) ---
# `processor` and `model` are created here, at import time, and reused by
# classify_image for every request.
print(f"Using device: {DEVICE}")
print(f"Loading processor from: {MODEL_IDENTIFIER}")
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_IDENTIFIER)
    print(f"Loading model from: {MODEL_IDENTIFIER}")
    # Place the weights on the chosen device, then switch to inference mode
    # (eval() changes the behaviour of layers such as dropout).
    model = SiglipForImageClassification.from_pretrained(MODEL_IDENTIFIER).to(DEVICE)
    model.eval()
    print("Model and processor loaded successfully.")
except Exception as load_error:
    print(f"FATAL: Error loading model or processor: {load_error}")
    # Abort startup: the app is useless without a model. The original
    # exception is chained as __cause__ for the traceback.
    # NOTE(review): gr.Error is meant for event handlers; here it only acts as
    # a generic exception that stops the import.
    raise gr.Error(f"Failed to load the model: {load_error}. Cannot start the application.") from load_error
# --- Prediction Function ---
def classify_image(image_pil):
    """Classify an image as AI-generated or human-made.

    Args:
        image_pil: A PIL image coming from the ``gr.Image(type="pil")`` input,
            or ``None`` when the input is empty.

    Returns:
        A ``{label: probability}`` dict suitable for ``gr.Label``: one entry
        per class in ``model.config.id2label``, probabilities from a softmax
        over the logits, rounded to 4 decimal places. An empty dict when no
        image is given.

    Raises:
        gr.Error: If preprocessing or inference fails. Gradio displays the
            message to the user. (Returning ``{"Error": "<text>"}`` instead
            would hand gr.Label a non-numeric confidence it cannot render.)
    """
    if image_pil is None:
        print("Warning: No image provided.")
        return {}
    print("Processing image...")
    try:
        # Normalise any mode (RGBA, palette, grayscale, ...) to 3-channel RGB.
        image = image_pil.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)
        print("Running inference...")
        with torch.no_grad():
            logits = model(**inputs).logits
        # One image was passed, so row 0 holds its class distribution.
        # A single .tolist() copies everything off the device at once instead
        # of one .item() sync per class.
        probabilities = torch.softmax(logits, dim=-1)[0].tolist()
        results = {
            model.config.id2label[i]: round(prob, 4)
            for i, prob in enumerate(probabilities)
        }
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise gr.Error("Processing failed. Please try again or use a different image.") from e
    print(f"Prediction results: {results}")
    return results
# --- Define Example Images ---
def find_example_images(directory, extensions=('.png', '.jpg', '.jpeg', '.webp')):
    """Return the image files directly inside *directory*, sorted by name.

    Args:
        directory: Folder to scan (not recursive).
        extensions: Lower-case filename suffixes that count as images; file
            names are lower-cased before comparison.

    Returns:
        A list of ``os.path.join(directory, name)`` paths for regular files
        with a matching suffix. Empty when *directory* is missing, is not a
        directory, or contains no matching files.
    """
    # isdir (not exists): a plain file called "examples" must not reach listdir.
    if not os.path.isdir(directory):
        return []
    suffixes = tuple(extensions)
    # os.listdir order is arbitrary; sort so the examples gallery is stable.
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(directory, name))
    ]


example_dir = "examples"
example_images = find_example_images(example_dir)
if example_images:
    print(f"Found examples: {example_images}")
elif os.path.isdir(example_dir) and os.listdir(example_dir):
    print("No valid image files found in 'examples' directory.")
else:
    print("No 'examples' directory found or it's empty. Examples will not be shown.")
# --- Custom CSS for Dark Theme Adjustments ---
# Minimal CSS - let the dark theme handle most things
# Passed to gr.Blocks(css=...). Each id selector targets an `elem_id` set on a
# component in the layout: app-title, app-description, results-heading,
# input-column / output-column, prediction-label and app-footer.
# Nothing in this file loads the 'Inter' font; it only applies if the browser
# already has it, otherwise the sans-serif fallback is used.
# NOTE(review): the hard-coded light text colours (#c5f7dc, #2dd4bf, #60a5fa)
# and the dark borders assume a dark background; contrast may be poor if the
# page renders in light mode.
# NOTE(review): `.label-name` / `.confidence` are assumed to be class names in
# Gradio's Label markup — verify against the installed Gradio version.
css = """
body { font-family: 'Inter', sans-serif; }
/* Style the main title */
#app-title {
text-align: center;
font-weight: bold;
font-size: 2.5em;
margin-bottom: 5px;
/* color removed - let theme handle */
}
/* Style the description */
#app-description {
text-align: center;
font-size: 1.1em;
margin-bottom: 25px;
/* color removed - let theme handle */
}
#app-description code { /* Style model name - theme might handle this, but can force */
font-weight: bold;
background-color: rgba(255, 255, 255, 0.1); /* Slightly lighter background for code */
padding: 2px 5px;
border-radius: 4px;
color: #c5f7dc; /* Light green text for code block */
}
#app-description strong { /* Style device name */
color: #2dd4bf; /* Brighter teal/emerald for dark theme */
font-weight: bold;
}
/* Style the results heading */
#results-heading {
text-align: center;
font-size: 1.2em;
margin-bottom: 10px;
/* color removed - let theme handle */
}
/* Add some definition to input/output columns if needed */
#input-column, #output-column {
border: 1px solid #4b5563; /* Darker border for dark theme */
border-radius: 12px;
padding: 20px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); /* Subtle shadow, works on dark too */
/* background-color removed - let theme handle */
}
/* Ensure label text inside columns is readable */
#prediction-label .label-name { font-weight: bold; font-size: 1.1em; }
#prediction-label .confidence { font-size: 1em; }
/* Footer styling */
#app-footer {
margin-top: 40px;
padding-top: 20px;
border-top: 1px solid #374151; /* Darker border for footer */
text-align: center;
font-size: 0.9em;
/* color removed - let theme handle */
}
#app-footer a {
color: #60a5fa; /* Lighter blue for links */
text-decoration: none;
}
#app-footer a:hover {
text-decoration: underline;
}
"""
# --- Gradio Interface using Blocks and Theme ---
# Gradio has no "<name>/dark" theme variants. A theme string that is not a
# built-in name is looked up on the Hugging Face Hub; if that fails, Gradio
# only warns and falls back to the default theme. The previous "soft/dark"
# therefore never applied the Soft theme. Use the built-in Soft theme object
# instead. Light/dark mode follows the viewer's browser preference and can be
# forced with the `?__theme=dark` URL parameter.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as iface:
    # Title and description (styled through their elem_ids in `css`).
    gr.Markdown("# AI vs Human Image Detector", elem_id="app-title")
    gr.Markdown(
        f"Upload an image to classify if it was likely generated by AI or created by a human. "
        f"Uses the `{MODEL_IDENTIFIER}` model. Running on **{str(DEVICE).upper()}**.",
        elem_id="app-description",
    )
    # Main layout: image input and button on the left, prediction on the right.
    with gr.Row(variant='panel'):
        with gr.Column(scale=1, min_width=300, elem_id="input-column"):
            image_input = gr.Image(
                type="pil",  # classify_image receives a PIL image
                label="🖼️ Upload Your Image",
                sources=["upload", "webcam", "clipboard"],
                height=400,
            )
            submit_button = gr.Button("🔍 Classify Image", variant="primary")
        with gr.Column(scale=1, min_width=300, elem_id="output-column"):
            gr.Markdown("📊 **Prediction Results**", elem_id="results-heading")
            result_output = gr.Label(
                num_top_classes=2,
                label="Classification",
                elem_id="prediction-label",
            )
    # Examples section: only shown when example images were found on disk.
    if example_images:
        gr.Examples(
            examples=example_images,
            inputs=image_input,
            outputs=result_output,
            fn=classify_image,
            # Gradio runs classify_image on each example ahead of time.
            cache_examples=True,
            label="✨ Click an Example to Try!",
        )
    # Footer / article section. Blank lines separate the Markdown blocks;
    # without them each heading runs into the paragraph or list item above.
    gr.Markdown(f"""
---
**How it Works:**

This application uses a fine-tuned [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) vision model
specifically trained to differentiate between images generated by Artificial Intelligence and those created by humans.

**Model:**

* You can find the model card here: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>{MODEL_IDENTIFIER}</a>

**Training Code:**

Fine tuning code available at [https://exnrt.com/blog/ai/fine-tuning-siglip2/](https://exnrt.com/blog/ai/fine-tuning-siglip2/).
""",
        elem_id="app-footer",
    )
    # Connect events: classify on button click and whenever the image changes.
    # A cleared input may arrive as None, for which classify_image returns {}.
    submit_button.click(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_image_button")
    image_input.change(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_image_change")
# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # launch() blocks the main thread while the server runs, so the line after
    # it only executes once the server has shut down. Report that, rather than
    # claiming the interface was just launched.
    iface.launch()
    print("Gradio interface stopped.")