Chanlefe committed
Commit 7949bfb · verified · 1 Parent(s): fceb20b

Update app.py

Files changed (1)
  1. app.py (+55, -19)
app.py CHANGED
@@ -3,7 +3,25 @@ import os
 from PIL import Image
 from transformers import AutoModelForImageClassification, SiglipImageProcessor
 import gradio as gr
-import pytesseract
+
+# Alternative OCR using transformers
+def setup_alternative_ocr():
+    """Setup alternative OCR using transformers models"""
+    try:
+        from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+        print("Setting up TrOCR for text extraction...")
+
+        ocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
+        ocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
+
+        print("✅ TrOCR model loaded successfully!")
+        return ocr_processor, ocr_model, True
+    except Exception as e:
+        print(f"⚠️ Could not load TrOCR: {e}")
+        return None, None, False
+
+# Try to setup OCR
+OCR_PROCESSOR, OCR_MODEL, OCR_AVAILABLE = setup_alternative_ocr()
 
 # Model path
 MODEL_PATH = "./model"
@@ -12,30 +30,27 @@ try:
     print(f"=== Loading model from: {MODEL_PATH} ===")
     print(f"Available files: {os.listdir(MODEL_PATH)}")
 
-    # Load the model (this should work with your files)
+    # Load the model
     print("Loading model...")
     model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
     print("✅ Model loaded successfully!")
 
-    # Load just the image processor (not the full AutoProcessor)
+    # Load image processor
     print("Loading image processor...")
     try:
-        # Try to load the image processor from your local files
        processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
        print("✅ Image processor loaded from local files!")
     except Exception as e:
        print(f"⚠️ Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
-        # Fallback: load processor from base model online
        processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        print("✅ Image processor loaded from base model!")
 
-    # Get labels from your model config
+    # Get labels
     if hasattr(model.config, 'id2label') and model.config.id2label:
        labels = model.config.id2label
        print(f"✅ Found {len(labels)} labels in model config")
     else:
-        # Create generic labels if none exist
        num_labels = model.config.num_labels if hasattr(model.config, 'num_labels') else 2
        labels = {i: f"class_{i}" for i in range(num_labels)}
        print(f"✅ Created {len(labels)} generic labels")
@@ -44,19 +59,40 @@ try:
 
 except Exception as e:
     print(f"❌ Error loading model: {e}")
-    print("\n=== Debug Information ===")
     print(f"Files in model directory: {os.listdir(MODEL_PATH)}")
     raise
 
+def extract_text_alternative(image):
+    """Extract text using TrOCR model"""
+    if not OCR_AVAILABLE:
+        return "OCR not available"
+
+    try:
+        # Convert to RGB if needed
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        # Process with TrOCR
+        pixel_values = OCR_PROCESSOR(image, return_tensors="pt").pixel_values
+        generated_ids = OCR_MODEL.generate(pixel_values)
+        generated_text = OCR_PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        return generated_text
+    except Exception as e:
+        return f"OCR error: {str(e)}"
+
 def classify_meme(image: Image.Image):
     """
-    Classify meme and extract text using OCR
+    Classify meme and extract text
     """
     try:
-        # OCR: extract text from image
-        extracted_text = pytesseract.image_to_string(image)
+        # Extract text using alternative OCR
+        if OCR_AVAILABLE:
+            extracted_text = extract_text_alternative(image)
+        else:
+            extracted_text = "OCR not available in this environment"
 
-        # Process image for the model
+        # Process image for classification
         inputs = processor(images=image, return_tensors="pt")
 
         # Run inference
@@ -95,13 +131,13 @@ demo = gr.Interface(
         gr.Label(num_top_classes=5, label="Meme Classification"),
         gr.Textbox(label="Extracted Text", lines=3)
     ],
-    title="🎭 Meme Classifier with OCR",
-    description="""
-    Upload a meme image to:
-    1. **Classify** its content using your trained SigLIP2_77 model
-    2. **Extract text** using OCR (Optical Character Recognition)
+    title="🎭 Meme Classifier" + (" with TrOCR" if OCR_AVAILABLE else ""),
+    description=f"""
+    Upload a meme image to **classify** its content using your trained SigLIP2_77 model.
+
+    {'✅ **Text extraction** available via TrOCR (Microsoft Transformer OCR)' if OCR_AVAILABLE else '⚠️ **Text extraction** not available'}
 
-    Your model was trained on meme data and will predict the category/sentiment of the uploaded meme.
+    Your model will predict the category/sentiment of the uploaded meme.
     """,
     examples=None,
     allow_flagging="never"
@@ -113,4 +149,4 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         server_port=7860,
         share=False
-    )
+    )
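Note on the change: the commit swaps pytesseract for a TrOCR pipeline loaded through transformers, presumably because pytesseract depends on a system-level Tesseract binary that the Space may not provide. Below is a minimal standalone sketch (not part of the commit) for smoke-testing the new OCR path; "meme.png" is a placeholder path for a local test image, and microsoft/trocr-base-printed is a single-text-line recognition model, so multi-line or heavily stylized meme captions may come back merged or incomplete.

# Standalone smoke test for the TrOCR path introduced above (a sketch, not part
# of the commit). "meme.png" is a placeholder path for a local test image.
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")

image = Image.open("meme.png").convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])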
 
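The diff collapses the body of the inference step inside classify_meme (original lines 63-94 are unchanged), so only the processor call and the "# Run inference" comment are visible. For orientation, here is a hedged sketch of the shape that block typically takes for a classification head like this one, producing the {label: confidence} dict that gr.Label(num_top_classes=5) expects; run_classifier is a hypothetical helper, and the torch details are assumptions rather than code from the commit.

import torch
from PIL import Image

def run_classifier(image: Image.Image, model, processor, labels) -> dict:
    """Hypothetical sketch of the collapsed inference block in classify_meme:
    score one image and build the {label: confidence} mapping for gr.Label."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits          # shape: (1, num_labels)
    probs = logits.softmax(dim=-1)[0]
    # labels is model.config.id2label or the generated class_{i} fallback
    return {labels[i]: float(p) for i, p in enumerate(probs)}

Under this assumed shape, classify_meme would return the dict together with extracted_text, matching the gr.Label and gr.Textbox outputs declared in the interface above.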