# NOTE(review): the lines below are Hugging Face Spaces page-scrape residue
# (Space status, file size, git blame hashes, line-number gutter) — not
# program text. Commented out so the file parses; remove them entirely.
# Spaces: Sleeping / Sleeping
# File size: 4,178 Bytes
# (git blame hash column and 1..117 line-number gutter omitted)
import torch
import os
from PIL import Image
from transformers import AutoModelForImageClassification, SiglipImageProcessor
import gradio as gr
import pytesseract
# Directory holding the fine-tuned SigLIP classification checkpoint.
MODEL_PATH = "./model"

try:
    print(f"=== Loading model from: {MODEL_PATH} ===")
    print(f"Available files: {os.listdir(MODEL_PATH)}")

    # Load the classification model strictly from local files so a missing or
    # partial checkpoint fails fast instead of silently downloading.
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(
        MODEL_PATH, local_files_only=True
    )
    print("[OK] Model loaded successfully!")

    # Load only the image processor (not the full AutoProcessor): text is
    # handled separately via OCR, so image preprocessing is all we need.
    print("Loading image processor...")
    try:
        processor = SiglipImageProcessor.from_pretrained(
            MODEL_PATH, local_files_only=True
        )
        print("[OK] Image processor loaded from local files!")
    except Exception as e:
        # Fallback: the local checkpoint may lack preprocessor_config.json;
        # fetch the processor config from the base SigLIP model instead.
        print(f"[WARN] Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
        processor = SiglipImageProcessor.from_pretrained(
            "google/siglip-base-patch16-224"
        )
        print("[OK] Image processor loaded from base model!")

    # Resolve class labels from the model config; fall back to generic
    # class_<i> names when the config carries no id2label mapping.
    if getattr(model.config, "id2label", None):
        labels = model.config.id2label
        print(f"[OK] Found {len(labels)} labels in model config")
    else:
        num_labels = getattr(model.config, "num_labels", 2)
        labels = {i: f"class_{i}" for i in range(num_labels)}
        print(f"[OK] Created {len(labels)} generic labels")

    print("Model setup complete!")
except Exception as e:
    print(f"[ERROR] Error loading model: {e}")
    print("\n=== Debug Information ===")
    print(f"Files in model directory: {os.listdir(MODEL_PATH)}")
    # Re-raise so the Space fails visibly instead of serving a broken app.
    raise
def classify_meme(image: Image.Image):
    """Classify a meme image and extract its overlaid text via OCR.

    Args:
        image: Input meme as a PIL image (as delivered by gr.Image(type="pil")).

    Returns:
        A ``(predictions, text)`` tuple where ``predictions`` maps label name
        to probability, sorted by confidence descending, and ``text`` is the
        stripped OCR output. On any failure, returns ``({"Error": 1.0}, msg)``
        so the Gradio UI shows the error instead of crashing.
    """
    try:
        # OCR: pull any caption/overlay text out of the image.
        extracted_text = pytesseract.image_to_string(image)

        # Preprocess for the model and run a single no-grad forward pass.
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Map each class index to its label name and probability.
        # NOTE(review): assumes ``labels`` is keyed by int 0..n-1 — HF configs
        # deserialized from JSON sometimes carry string keys; confirm upstream.
        predictions = {
            labels.get(i, f"class_{i}"): float(probs[0][i])
            for i in range(len(labels))
        }
        sorted_predictions = dict(
            sorted(predictions.items(), key=lambda kv: kv[1], reverse=True)
        )

        # Debug output for the Space logs.
        print("=== Classification Results ===")
        print(f"Extracted Text: '{extracted_text.strip()}'")
        print("Top 3 Predictions:")
        top3 = list(sorted_predictions.items())[:3]
        for rank, (label, prob) in enumerate(top3, start=1):
            print(f" {rank}. {label}: {prob:.4f}")

        return sorted_predictions, extracted_text.strip()
    except Exception as e:
        # Surface the failure in both UI outputs rather than raising.
        error_msg = f"Error processing image: {str(e)}"
        print(f"[ERROR] {error_msg}")
        return {"Error": 1.0}, error_msg
# Build the Gradio UI: one image in, (label distribution, OCR text) out.
demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil", label="Upload Meme Image"),
    outputs=[
        gr.Label(num_top_classes=5, label="Meme Classification"),
        gr.Textbox(label="Extracted Text", lines=3),
    ],
    title="Meme Classifier with OCR",
    description="""
    Upload a meme image to:
    1. **Classify** its content using your trained SigLIP2_77 model
    2. **Extract text** using OCR (Optical Character Recognition)
    Your model was trained on meme data and will predict the category/sentiment of the uploaded meme.
    """,
    examples=None,
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favor of
    # flagging_mode — confirm against the installed Gradio version.
    allow_flagging="never",
)
if __name__ == "__main__":
    print("Starting Gradio interface...")
    # Bind on all interfaces on 7860 (the standard Hugging Face Spaces port);
    # no public share tunnel.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
# (trailing page-scrape residue removed)