Spaces:
Sleeping
Sleeping
File size: 5,446 Bytes
b1779fd 6a6e076 b1779fd 47d0c84 b1779fd 7949bfb 131383f 47d0c84 cee92ec 6a6e076 47d0c84 6a6e076 7949bfb 6a6e076 47d0c84 6a6e076 7949bfb 47d0c84 23ad95f 47d0c84 6a6e076 7949bfb 6a6e076 47d0c84 6a6e076 47d0c84 6a6e076 47d0c84 6a6e076 47d0c84 6a6e076 47d0c84 6a6e076 b1779fd 7949bfb b1779fd 6a6e076 7949bfb 6a6e076 7949bfb 6a6e076 7949bfb 6a6e076 47d0c84 6a6e076 47d0c84 6a6e076 47d0c84 6a6e076 47d0c84 6a6e076 47d0c84 6a6e076 47d0c84 6554f18 47d0c84 b1779fd 6a6e076 3711151 6a6e076 47d0c84 3711151 7949bfb 6a6e076 7949bfb 6a6e076 b1779fd 47d0c84 6a6e076 47d0c84 7949bfb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch
import os
from PIL import Image
from transformers import AutoModelForImageClassification, SiglipImageProcessor
import gradio as gr
# Alternative OCR using transformers (TrOCR). Loaded best-effort so the app
# still starts, without text extraction, when the OCR model cannot be loaded.
def setup_alternative_ocr(model_name="microsoft/trocr-base-printed"):
    """Load a TrOCR processor/model pair for optional text extraction.

    The ``transformers`` import is done inside the ``try`` so that a missing
    package, a failed download, or any other load error disables OCR instead
    of crashing the application at import time.

    Args:
        model_name: Hugging Face Hub id (or local path) of the TrOCR
            checkpoint used for both the processor and the model. Defaults to
            ``"microsoft/trocr-base-printed"``.

    Returns:
        A ``(processor, model, available)`` tuple. On any failure the error is
        printed and ``(None, None, False)`` is returned.
    """
    try:
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        print("Setting up TrOCR for text extraction...")
        ocr_processor = TrOCRProcessor.from_pretrained(model_name)
        ocr_model = VisionEncoderDecoderModel.from_pretrained(model_name)
        print("✅ TrOCR model loaded successfully!")
        return ocr_processor, ocr_model, True
    except Exception as e:  # deliberate best-effort: any failure disables OCR
        print(f"⚠️ Could not load TrOCR: {e}")
        return None, None, False
# Initialise OCR once at import time. OCR_AVAILABLE lets the rest of the app
# degrade gracefully (no text extraction) when TrOCR could not be loaded.
(OCR_PROCESSOR, OCR_MODEL, OCR_AVAILABLE) = setup_alternative_ocr()
# Directory holding the fine-tuned classifier; loaded with local files only.
MODEL_PATH = "./model"


def _describe_model_dir(path):
    """Return the listing of *path* for diagnostics, never raising.

    A missing or unreadable directory yields a short note instead of an
    exception, so it cannot mask the load error that is being reported.
    """
    try:
        return os.listdir(path)
    except OSError as err:
        return f"<unable to list {path!r}: {err}>"


try:
    print(f"=== Loading model from: {MODEL_PATH} ===")
    print(f"Available files: {_describe_model_dir(MODEL_PATH)}")

    # Load the classifier strictly from local files (no Hub access).
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(
        MODEL_PATH, local_files_only=True
    )
    print("✅ Model loaded successfully!")

    # Prefer the processor config saved alongside the model; otherwise fall
    # back to the base SigLIP processor.
    # NOTE(review): the fallback assumes the base model's resize/normalise
    # settings match what the fine-tuned model was trained with - confirm.
    print("Loading image processor...")
    try:
        processor = SiglipImageProcessor.from_pretrained(
            MODEL_PATH, local_files_only=True
        )
        print("✅ Image processor loaded from local files!")
    except Exception as e:
        print(f"⚠️ Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
        processor = SiglipImageProcessor.from_pretrained(
            "google/siglip-base-patch16-224"
        )
        print("✅ Image processor loaded from base model!")

    # Map class index -> label name; synthesize generic names when the config
    # carries no (or an empty) id2label.
    if getattr(model.config, "id2label", None):
        labels = model.config.id2label
        print(f"✅ Found {len(labels)} labels in model config")
    else:
        num_labels = getattr(model.config, "num_labels", 2)
        labels = {i: f"class_{i}" for i in range(num_labels)}
        print(f"✅ Created {len(labels)} generic labels")

    print("🎉 Model setup complete!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Diagnostics must not raise here, or the re-raise below would be lost.
    print(f"Files in model directory: {_describe_model_dir(MODEL_PATH)}")
    raise
def extract_text_alternative(image):
    """Run TrOCR on *image* and return the recognised text.

    Args:
        image: A PIL image; converted to RGB first if it is in another mode.

    Returns:
        The decoded text of the first generated sequence.
        Returns ``"OCR not available"`` when TrOCR failed to load.
        Returns ``"OCR error: <message>"`` if any step raises.

    NOTE(review): TrOCR checkpoints are usually single-line text recognisers;
    multi-line meme captions may come back partial - confirm on real samples.
    """
    # Guard clause: nothing to do without a loaded OCR model.
    if not OCR_AVAILABLE:
        return "OCR not available"

    try:
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        encoded = OCR_PROCESSOR(rgb_image, return_tensors="pt")
        token_ids = OCR_MODEL.generate(encoded.pixel_values)
        decoded = OCR_PROCESSOR.batch_decode(token_ids, skip_special_tokens=True)
        first_text = decoded[0]
        return first_text
    except Exception as exc:
        return f"OCR error: {exc!s}"
def classify_meme(image: "Image.Image"):
    """Classify a meme image and extract any text it contains.

    Args:
        image: The uploaded PIL image. Gradio passes ``None`` when the form is
            submitted without an image.

    Returns:
        A ``(predictions, text)`` tuple:

        - ``predictions`` maps each label name to its softmax probability,
          ordered from most to least likely.
        - ``text`` is the stripped OCR output, or a notice that OCR is
          unavailable.

        On failure ``predictions`` is ``{"Error": 1.0}`` and ``text`` holds
        the error message.
    """
    if image is None:
        error_msg = "Error processing image: no image was provided"
        print(f"❌ {error_msg}")
        return {"Error": 1.0}, error_msg

    try:
        # The classifier expects 3-channel input. RGBA, grayscale or palette
        # images would otherwise fail in the image processor or the model.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Extract text using alternative OCR
        if OCR_AVAILABLE:
            extracted_text = extract_text_alternative(image)
        else:
            extracted_text = "OCR not available in this environment"
        extracted_text = extracted_text.strip()

        # Process image for classification and run inference
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        class_probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()

        # Walk the model's actual output width so every logit gets a label
        # (generic name for ids missing from `labels`), and a longer
        # id2label can never cause an IndexError.
        predictions = {
            labels.get(i, f"class_{i}"): prob for i, prob in enumerate(class_probs)
        }
        sorted_predictions = dict(
            sorted(predictions.items(), key=lambda item: item[1], reverse=True)
        )

        # Debug prints
        print("=== Classification Results ===")
        print(f"Extracted Text: '{extracted_text}'")
        print("Top 3 Predictions:")
        top_three = list(sorted_predictions.items())[:3]
        for rank, (label, prob) in enumerate(top_three, start=1):
            print(f" {rank}. {label}: {prob:.4f}")

        return sorted_predictions, extracted_text
    except Exception as e:
        # UI boundary: report the failure in both outputs instead of crashing.
        error_msg = f"Error processing image: {str(e)}"
        print(f"❌ {error_msg}")
        return {"Error": 1.0}, error_msg
# Create Gradio interface: one image input, two outputs (ranked labels and
# the OCR text) matching classify_meme's (predictions, text) return value.
demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil", label="Upload Meme Image"),
    outputs=[
        gr.Label(num_top_classes=5, label="Meme Classification"),
        gr.Textbox(label="Extracted Text", lines=3),
    ],
    title="😂 Meme Classifier" + (" with TrOCR" if OCR_AVAILABLE else ""),
    # Markdown: blank lines keep each sentence in its own paragraph; lines are
    # not indented so they don't render as a code block.
    description=f"""
Upload a meme image to **classify** its content using your trained SigLIP2_77 model.

{'✅ **Text extraction** available via TrOCR (Microsoft Transformer OCR)' if OCR_AVAILABLE else '⚠️ **Text extraction** not available'}

Your model will predict the category/sentiment of the uploaded meme.
""",
    examples=None,
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode`; switch once the pinned Gradio version is confirmed.
    allow_flagging="never",
)
if __name__ == "__main__":
    print("🚀 Starting Gradio interface...")
    # Bind to all interfaces (0.0.0.0) rather than localhost so the server is
    # reachable from outside the host. 7860 is Gradio's default port, and no
    # public share link is created.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )