Spaces:
Running
Running
import torch | |
import os | |
from PIL import Image | |
from transformers import AutoModelForImageClassification, SiglipImageProcessor | |
import gradio as gr | |
import pytesseract | |
# Directory holding the fine-tuned checkpoint. The model is loaded with
# local_files_only=True, so it must already be present on disk.
MODEL_PATH = "./model"


def _describe_model_dir(path):
    """Return a printable listing of *path* for diagnostics; never raises.

    Listing failures (for example a missing directory) are reported in the
    returned string instead of being raised. This keeps the startup banner and
    the debug output in the ``except`` handler below from failing themselves
    and masking the real model-loading error.
    """
    try:
        return str(os.listdir(path))
    except OSError as err:
        return f"<unavailable: {err}>"


# NOTE(review): the leading "β" / "β οΈ" / "π" characters in the log strings
# below look like emoji (e.g. ✅ / ⚠️ / ❌) mangled by a text-encoding round
# trip; restore them from the original file if that matters.
try:
    print(f"=== Loading model from: {MODEL_PATH} ===")
    print(f"Available files: {_describe_model_dir(MODEL_PATH)}")

    # Load the fine-tuned classification model from local files only.
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("β Model loaded successfully!")

    # Only the image processor is needed (not the full AutoProcessor): the
    # classifier is fed pixel inputs only.
    print("Loading image processor...")
    try:
        processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
        print("β Image processor loaded from local files!")
    except Exception as e:
        # Deliberate best-effort fallback when the checkpoint ships no
        # preprocessor config.
        # NOTE(review): this downloads the SigLIP (v1) base processor; confirm
        # its resolution/normalisation match what the checkpoint was trained with.
        print(f"β οΈ Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
        processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        print("β Image processor loaded from base model!")

    # Prefer the label names stored in the model config; otherwise fall back to
    # generic "class_<i>" names (2 classes if the config has no num_labels).
    id2label = getattr(model.config, "id2label", None)
    if id2label:
        labels = id2label
        print(f"β Found {len(labels)} labels in model config")
    else:
        num_labels = getattr(model.config, "num_labels", 2)
        labels = {i: f"class_{i}" for i in range(num_labels)}
        print(f"β Created {len(labels)} generic labels")
    print("π Model setup complete!")
except Exception as e:
    # Print diagnostics, then re-raise the original error so the app fails
    # loudly at startup. The listing helper cannot raise, so the bare `raise`
    # still refers to `e`.
    print(f"β Error loading model: {e}")
    print("\n=== Debug Information ===")
    print(f"Files in model directory: {_describe_model_dir(MODEL_PATH)}")
    raise
def _extract_text(image):
    """Return the Tesseract OCR text of *image*, stripped of surrounding whitespace.

    OCR is best-effort: if the Tesseract binary is missing or the OCR call
    fails, the problem is logged and a short placeholder is returned so that
    classification can still be reported.
    """
    try:
        return pytesseract.image_to_string(image).strip()
    except (OSError, RuntimeError) as err:
        # pytesseract signals a missing binary with an EnvironmentError
        # subclass and OCR failures/timeouts with RuntimeError subclasses.
        print(f"OCR failed: {err}")
        return f"[OCR unavailable: {err}]"


def _rank_labels(image):
    """Return ``{label: probability}`` for *image*, most probable first.

    Uses the module-level ``processor``, ``model`` and ``labels``.
    """
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # A single image was processed, so row 0 holds its class distribution.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()
    # Enumerate the model's actual outputs rather than range(len(labels)).
    # A label map longer than the logits would otherwise raise IndexError,
    # and a shorter one would silently drop classes. Indices without a
    # mapped name get a generic "class_<i>" label.
    predictions = {labels.get(i, f"class_{i}"): p for i, p in enumerate(probs)}
    return dict(sorted(predictions.items(), key=lambda kv: kv[1], reverse=True))


def classify_meme(image: Image.Image):
    """Classify a meme image and extract any text it contains via OCR.

    Args:
        image: The uploaded PIL image, or ``None`` if nothing was uploaded.

    Returns:
        A ``(predictions, text)`` pair. ``predictions`` maps each label to its
        softmax probability, ordered from most to least likely. ``text`` is
        the stripped OCR output. On failure, returns
        ``({"Error": 1.0}, message)`` so the UI shows the error instead of
        the request crashing.
    """
    if image is None:
        error_msg = "Error processing image: no image provided"
        print(f"β {error_msg}")
        return {"Error": 1.0}, error_msg
    try:
        # The processor normalises three colour channels; convert RGBA,
        # palette or grayscale uploads so they don't fail preprocessing.
        if image.mode != "RGB":
            image = image.convert("RGB")

        extracted_text = _extract_text(image)
        sorted_predictions = _rank_labels(image)

        # Debug output for the Space logs.
        print("=== Classification Results ===")
        print(f"Extracted Text: '{extracted_text}'")
        print("Top 3 Predictions:")
        top_three = list(sorted_predictions.items())[:3]
        for rank, (label, prob) in enumerate(top_three, start=1):
            print(f" {rank}. {label}: {prob:.4f}")
        return sorted_predictions, extracted_text
    except Exception as e:
        # UI boundary: report any failure to the user rather than raising.
        error_msg = f"Error processing image: {str(e)}"
        print(f"β {error_msg}")
        return {"Error": 1.0}, error_msg
# Gradio UI: one PIL image input and two outputs — a label widget showing the
# top 5 classes and a textbox for the OCR text. The outputs are listed in the
# same order as the (predictions, text) pair that classify_meme returns.
_meme_input = gr.Image(type="pil", label="Upload Meme Image")
_meme_outputs = [
    gr.Label(num_top_classes=5, label="Meme Classification"),
    gr.Textbox(label="Extracted Text", lines=3),
]
_meme_description = """
Upload a meme image to:
1. **Classify** its content using your trained SigLIP2_77 model
2. **Extract text** using OCR (Optical Character Recognition)
Your model was trained on meme data and will predict the category/sentiment of the uploaded meme.
"""

demo = gr.Interface(
    fn=classify_meme,
    inputs=_meme_input,
    outputs=_meme_outputs,
    title="π Meme Classifier with OCR",
    description=_meme_description,
    examples=None,
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode`; confirm against the installed Gradio version.
    allow_flagging="never",
)
if __name__ == "__main__":
    # Listen on every interface at port 7860 with no public share link
    # (presumably what the hosting Space expects — confirm).
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    print("π Starting Gradio interface...")
    demo.launch(**launch_options)