File size: 4,178 Bytes
b1779fd
6a6e076
b1779fd
47d0c84
b1779fd
131383f
 
47d0c84
 
cee92ec
6a6e076
47d0c84
 
6a6e076
47d0c84
6a6e076
 
47d0c84
6a6e076
47d0c84
 
23ad95f
47d0c84
 
 
 
 
 
 
 
 
6a6e076
47d0c84
6a6e076
 
47d0c84
6a6e076
 
47d0c84
6a6e076
47d0c84
6a6e076
47d0c84
6a6e076
 
47d0c84
 
 
6a6e076
b1779fd
 
6a6e076
 
 
 
 
 
 
47d0c84
6a6e076
 
47d0c84
6a6e076
 
 
 
47d0c84
6a6e076
47d0c84
 
 
 
 
 
6a6e076
47d0c84
 
 
 
 
 
6a6e076
47d0c84
6a6e076
 
47d0c84
 
 
6554f18
47d0c84
b1779fd
 
6a6e076
3711151
6a6e076
47d0c84
3711151
47d0c84
6a6e076
 
47d0c84
 
6a6e076
47d0c84
6a6e076
 
 
b1779fd
 
 
47d0c84
6a6e076
47d0c84
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import traceback

import gradio as gr
import pytesseract
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, SiglipImageProcessor

# Model path
MODEL_PATH = "./model"


def _describe_model_dir(path):
    """Return a printable listing of *path*, or why it could not be listed.

    Diagnostic helper: it never raises, so a missing or unreadable model
    directory cannot replace the original loading error with a new one.
    """
    try:
        return str(os.listdir(path))
    except OSError as err:
        return f"<cannot list {path!r}: {err}>"


# Load the classifier, its image processor and the label map once at import
# time. Any failure is reported with diagnostics and then re-raised so the app
# never starts with a half-initialised model.
try:
    print(f"=== Loading model from: {MODEL_PATH} ===")
    # os.listdir is called directly here (not via the helper) so that a missing
    # model directory fails fast with a clear FileNotFoundError.
    print(f"Available files: {os.listdir(MODEL_PATH)}")

    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("✅ Model loaded successfully!")

    # Only the image processor is needed (not the full AutoProcessor).
    print("Loading image processor...")
    try:
        processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
        print("✅ Image processor loaded from local files!")
    except Exception as e:
        print(f"⚠️  Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
        # Fallback loads the processor config from the base model online.
        # NOTE(review): confirm the base checkpoint's preprocessing (resize,
        # mean/std) matches what this model was trained with.
        processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        print("✅ Image processor loaded from base model!")

    # Map class index -> label name, preferring the labels stored in the config.
    id2label = getattr(model.config, "id2label", None)
    if id2label:
        labels = id2label
        print(f"✅ Found {len(labels)} labels in model config")
    else:
        # No usable label map: fall back to generic "class_<i>" names.
        num_labels = getattr(model.config, "num_labels", 2)
        labels = {i: f"class_{i}" for i in range(num_labels)}
        print(f"✅ Created {len(labels)} generic labels")

    print("🎉 Model setup complete!")

except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("\n=== Debug Information ===")
    # Must not raise: re-listing a missing directory would mask the real error.
    print(f"Files in model directory: {_describe_model_dir(MODEL_PATH)}")
    raise

def classify_meme(image: Image.Image):
    """
    Classify a meme image and extract its text with OCR.

    Args:
        image: The uploaded image as a PIL image. Gradio passes ``None`` when
            the user submits without uploading anything.

    Returns:
        A ``(predictions, text)`` tuple. ``predictions`` maps each label name
        to its softmax probability, ordered from most to least likely; ``text``
        is the OCR output with surrounding whitespace stripped. On failure,
        ``predictions`` is ``{"Error": 1.0}`` and ``text`` is the error message.
    """
    if image is None:
        error_msg = "Error processing image: no image was provided"
        print(f"❌ {error_msg}")
        return {"Error": 1.0}, error_msg

    try:
        # OCR runs on the image exactly as uploaded.
        extracted_text = pytesseract.image_to_string(image).strip()

        # Uploads can be RGBA, palette ("P") or grayscale ("L"); convert to RGB
        # so the image processor always receives 3-channel input.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        inputs = processor(images=rgb_image, return_tensors="pt")

        with torch.no_grad():
            logits = model(**inputs).logits
        # Single image in the batch -> take row 0 of the probability matrix.
        probs = torch.softmax(logits, dim=-1)[0]

        # Iterate over the model's real output width rather than len(labels),
        # so a label map that disagrees with the logits neither raises
        # IndexError nor silently drops classes.
        predictions = {
            labels.get(i, f"class_{i}"): float(p)
            for i, p in enumerate(probs.tolist())
        }

        # Most confident label first.
        sorted_predictions = dict(
            sorted(predictions.items(), key=lambda item: item[1], reverse=True)
        )

        # Debug prints
        print("=== Classification Results ===")
        print(f"Extracted Text: '{extracted_text}'")
        print("Top 3 Predictions:")
        for rank, (label, prob) in enumerate(list(sorted_predictions.items())[:3], start=1):
            print(f"  {rank}. {label}: {prob:.4f}")

        return sorted_predictions, extracted_text

    except Exception as e:
        # UI boundary: show the failure in the interface instead of crashing
        # the request, but keep the full traceback in the server log.
        traceback.print_exc()
        error_msg = f"Error processing image: {str(e)}"
        print(f"❌ {error_msg}")
        return {"Error": 1.0}, error_msg

# Create Gradio interface: one image input, and two outputs matching the
# (predictions, text) tuple returned by classify_meme.
demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil", label="Upload Meme Image"),
    outputs=[
        # Shows at most the 5 highest-probability labels.
        gr.Label(num_top_classes=5, label="Meme Classification"),
        gr.Textbox(label="Extracted Text", lines=3)
    ],
    title="🎭 Meme Classifier with OCR",
    description="""
    Upload a meme image to:
    1. **Classify** its content using your trained SigLIP2_77 model
    2. **Extract text** using OCR (Optical Character Recognition)
    
    Your model was trained on meme data and will predict the category/sentiment of the uploaded meme.
    """,
    examples=None,
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # verify against the installed Gradio version.
    allow_flagging="never"
)

if __name__ == "__main__":
    print("🚀 Starting Gradio interface...")
    # 0.0.0.0 listens on all network interfaces (needed inside containers);
    # share=False disables the public Gradio tunnel link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )