File size: 5,446 Bytes
b1779fd
6a6e076
b1779fd
47d0c84
b1779fd
7949bfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131383f
47d0c84
 
cee92ec
6a6e076
47d0c84
 
6a6e076
7949bfb
6a6e076
 
47d0c84
6a6e076
7949bfb
47d0c84
23ad95f
47d0c84
 
 
 
 
 
 
6a6e076
7949bfb
6a6e076
 
47d0c84
6a6e076
47d0c84
6a6e076
47d0c84
6a6e076
47d0c84
6a6e076
 
47d0c84
 
6a6e076
b1779fd
7949bfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1779fd
6a6e076
7949bfb
6a6e076
 
7949bfb
 
 
 
 
6a6e076
7949bfb
6a6e076
 
47d0c84
6a6e076
 
 
 
47d0c84
6a6e076
47d0c84
 
 
 
 
 
6a6e076
47d0c84
 
 
 
 
 
6a6e076
47d0c84
6a6e076
 
47d0c84
 
 
6554f18
47d0c84
b1779fd
 
6a6e076
3711151
6a6e076
47d0c84
3711151
7949bfb
 
 
 
 
6a6e076
7949bfb
6a6e076
 
 
b1779fd
 
 
47d0c84
6a6e076
47d0c84
 
 
7949bfb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import os
from PIL import Image
from transformers import AutoModelForImageClassification, SiglipImageProcessor
import gradio as gr

# Alternative OCR using transformers
def setup_alternative_ocr(model_name="microsoft/trocr-base-printed"):
    """Load a TrOCR processor/model pair for text extraction.

    Loading is best-effort: any failure (``transformers`` not importable,
    checkpoint not found, download error, ...) is reported on stdout and
    turned into a "not available" result instead of aborting the app.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained`` for
            both the processor and the encoder-decoder model. Defaults to
            ``"microsoft/trocr-base-printed"``.

    Returns:
        ``(processor, model, True)`` on success, ``(None, None, False)`` on
        any failure.
    """
    try:
        # Imported lazily inside the try so a missing/broken transformers
        # install only disables OCR rather than crashing at import time.
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        print("Setting up TrOCR for text extraction...")

        # Processor and model are always loaded from the same checkpoint.
        ocr_processor = TrOCRProcessor.from_pretrained(model_name)
        ocr_model = VisionEncoderDecoderModel.from_pretrained(model_name)

        print("βœ… TrOCR model loaded successfully!")
        return ocr_processor, ocr_model, True
    except Exception as e:
        # Deliberately broad: OCR is an optional feature, so any load failure
        # simply switches it off.
        print(f"⚠️  Could not load TrOCR: {e}")
        return None, None, False

def _describe_model_dir(path):
    """Return ``os.listdir(path)`` for diagnostics, or a note if it cannot be read.

    Never raises, so it is safe to call from an error handler: a missing or
    unreadable directory must not replace the real loading error with a
    second ``FileNotFoundError``.
    """
    try:
        return os.listdir(path)
    except OSError as err:
        return f"<unavailable: {err}>"


# Try to setup OCR
OCR_PROCESSOR, OCR_MODEL, OCR_AVAILABLE = setup_alternative_ocr()

# Model path
MODEL_PATH = "./model"

try:
    print(f"=== Loading model from: {MODEL_PATH} ===")
    print(f"Available files: {_describe_model_dir(MODEL_PATH)}")

    # Load the model (local files only: never reach out to the Hub for it)
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("βœ… Model loaded successfully!")

    # Load image processor; fall back to the base SigLIP preprocessing config
    # when the local directory does not ship one.
    print("Loading image processor...")
    try:
        processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
        print("βœ… Image processor loaded from local files!")
    except Exception as e:
        print(f"⚠️  Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
        processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        print("βœ… Image processor loaded from base model!")

    # Get labels: prefer the names stored in the model config, otherwise
    # synthesize "class_<i>" names for each output.
    if hasattr(model.config, 'id2label') and model.config.id2label:
        labels = model.config.id2label
        print(f"βœ… Found {len(labels)} labels in model config")
    else:
        num_labels = getattr(model.config, 'num_labels', 2)
        labels = {i: f"class_{i}" for i in range(num_labels)}
        print(f"βœ… Created {len(labels)} generic labels")

    print("πŸŽ‰ Model setup complete!")

except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Diagnostic listing must not raise, or it would mask the error above.
    print(f"Files in model directory: {_describe_model_dir(MODEL_PATH)}")
    raise

def extract_text_alternative(image):
    """Run the module-level TrOCR model over *image* and return its text.

    Returns the literal ``"OCR not available"`` when TrOCR failed to load at
    startup. Any error raised while preprocessing, generating or decoding is
    reported as an ``"OCR error: ..."`` string instead of propagating.
    """
    if not OCR_AVAILABLE:
        return "OCR not available"

    try:
        # The TrOCR processor is always handed a 3-channel RGB image.
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')

        pixels = OCR_PROCESSOR(rgb_image, return_tensors="pt").pixel_values
        token_ids = OCR_MODEL.generate(pixels)
        decoded = OCR_PROCESSOR.batch_decode(token_ids, skip_special_tokens=True)
        # Single image in, so the first decoded sequence is the answer.
        return decoded[0]
    except Exception as exc:
        return f"OCR error: {str(exc)}"

def classify_meme(image: Image.Image):
    """
    Classify a meme image and extract its text.

    Args:
        image: The uploaded image as produced by the ``gr.Image(type="pil")``
            input. May be ``None`` if the form is submitted without an upload.

    Returns:
        A ``(predictions, text)`` pair for the Label and Textbox outputs.
        ``predictions`` maps each label name to its softmax probability,
        ordered from most to least confident; ``text`` is the OCR result (or
        an OCR status message) with surrounding whitespace removed. Any
        failure is reported as ``({"Error": 1.0}, "Error processing image: ...")``
        rather than raised.
    """
    try:
        if image is None:
            # Give a clear message instead of an AttributeError from deep
            # inside OCR or preprocessing.
            raise ValueError("no image was provided")

        # Normalise palette / RGBA / greyscale uploads (common for PNG memes)
        # to 3-channel RGB before both OCR and classification; otherwise the
        # classifier's processor may receive an unexpected channel count.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Extract text using alternative OCR
        if OCR_AVAILABLE:
            extracted_text = extract_text_alternative(image)
        else:
            extracted_text = "OCR not available in this environment"
        extracted_text = extracted_text.strip()

        # Process image for classification
        inputs = processor(images=image, return_tensors="pt")

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # One entry per model output for the single image in the batch; any
        # class index missing from the label map gets a generic name.
        class_probs = probs[0].tolist()
        predictions = {
            labels.get(i, f"class_{i}"): prob for i, prob in enumerate(class_probs)
        }

        # Sort predictions by confidence (highest first)
        sorted_predictions = dict(
            sorted(predictions.items(), key=lambda item: item[1], reverse=True)
        )

        # Debug prints
        print("=== Classification Results ===")
        print(f"Extracted Text: '{extracted_text}'")
        print("Top 3 Predictions:")
        top_three = list(sorted_predictions.items())[:3]
        for rank, (label, prob) in enumerate(top_three, start=1):
            print(f"  {rank}. {label}: {prob:.4f}")

        return sorted_predictions, extracted_text

    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        print(f"❌ {error_msg}")
        return {"Error": 1.0}, error_msg

# Create Gradio interface
# One image input and two outputs: classify_meme returns a
# (label -> probability dict, extracted text) pair, which feeds the Label and
# Textbox components below in that order.
demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil", label="Upload Meme Image"),
    outputs=[
        # Displays at most the five most probable labels.
        gr.Label(num_top_classes=5, label="Meme Classification"),
        gr.Textbox(label="Extracted Text", lines=3)
    ],
    # Title and description mention TrOCR only when it loaded at startup.
    title="🎭 Meme Classifier" + (" with TrOCR" if OCR_AVAILABLE else ""),
    description=f"""
    Upload a meme image to **classify** its content using your trained SigLIP2_77 model.
    
    {'βœ… **Text extraction** available via TrOCR (Microsoft Transformer OCR)' if OCR_AVAILABLE else '⚠️  **Text extraction** not available'}
    
    Your model will predict the category/sentiment of the uploaded meme.
    """,
    examples=None,
    # NOTE(review): newer Gradio releases replace `allow_flagging` with
    # `flagging_mode`; confirm this keyword against the pinned Gradio version.
    allow_flagging="never"
)

if __name__ == "__main__":
    # Listen on every network interface at port 7860; share=False means no
    # public Gradio share link is created.
    launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    print("πŸš€ Starting Gradio interface...")
    demo.launch(**launch_kwargs)