Chanlefe commited on
Commit
6a6e076
Β·
verified Β·
1 Parent(s): cee92ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -25
app.py CHANGED
@@ -1,46 +1,194 @@
1
  import torch
 
 
2
  from PIL import Image
3
  from transformers import AutoProcessor, AutoModelForImageClassification
4
  import gradio as gr
5
  import pytesseract
6
 
7
- # Load model and processor
8
- model = AutoModelForImageClassification.from_pretrained("Chanlefe/SigLIP2_77")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- processor = AutoProcessor.from_pretrained("Chanlefe/SigLIP2_77")
 
 
 
 
11
 
12
-
13
- labels = model.config.id2label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Classify meme and extract text
16
  def classify_meme(image: Image.Image):
17
- # OCR: extract text from image
18
- extracted_text = pytesseract.image_to_string(image)
19
-
20
- # Process image with SigLIP2 model
21
- inputs = processor(images=image, return_tensors="pt").to(model.device)
22
- with torch.no_grad():
23
- outputs = model(**inputs)
24
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
25
- predictions = {labels[i]: float(probs[0][i]) for i in range(len(labels))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- return {
28
- "Predictions": predictions,
29
- "Extracted Text": extracted_text.strip()
30
- }
31
- print("Extracted Text:", extracted_text)
32
- print("Predictions:", predictions)
33
  # Gradio interface
34
  demo = gr.Interface(
35
  fn=classify_meme,
36
- inputs=gr.Image(type="pil"),
37
  outputs=[
38
- gr.Label(num_top_classes=2, label="Predictions"),
39
- gr.Textbox(label="Extracted Text")
40
  ],
41
  title="Meme Classifier with OCR",
42
- description="Upload a meme to classify its sentiment and extract text using OCR."
 
 
 
 
 
 
 
 
43
  )
44
 
45
  if __name__ == "__main__":
46
- demo.launch()
 
 
 
 
 
 
1
  import torch
2
+ import os
3
+ import glob
4
  from PIL import Image
5
  from transformers import AutoProcessor, AutoModelForImageClassification
6
  import gradio as gr
7
  import pytesseract
8
 
9
+ def find_model_files():
10
+ """Find model files in the current directory structure"""
11
+ print("=== Searching for model files ===")
12
+
13
+ # Look for key model files
14
+ config_files = glob.glob("**/config.json", recursive=True)
15
+ model_files = glob.glob("**/pytorch_model.bin", recursive=True) + glob.glob("**/model.safetensors", recursive=True)
16
+ preprocessor_files = glob.glob("**/preprocessor_config.json", recursive=True)
17
+
18
+ print(f"Found config.json files: {config_files}")
19
+ print(f"Found model weight files: {model_files}")
20
+ print(f"Found preprocessor_config.json files: {preprocessor_files}")
21
+
22
+ # Find the directory that contains all necessary files
23
+ for config_file in config_files:
24
+ model_dir = os.path.dirname(config_file)
25
+ if not model_dir: # If config.json is in root
26
+ model_dir = "."
27
+
28
+ # Check if this directory has all required files
29
+ has_model = any(os.path.dirname(f) == model_dir or (not os.path.dirname(f) and model_dir == ".") for f in model_files)
30
+ has_preprocessor = any(os.path.dirname(f) == model_dir or (not os.path.dirname(f) and model_dir == ".") for f in preprocessor_files)
31
+
32
+ if has_model and has_preprocessor:
33
+ print(f"Found complete model in directory: {model_dir}")
34
+ return model_dir
35
+ elif has_model:
36
+ print(f"Found model with config but missing preprocessor in: {model_dir}")
37
+ return model_dir # Try anyway, might work
38
+
39
+ print("No complete model directory found")
40
+ return None
41
 
42
+ # Search for model files
43
+ MODEL_PATH = find_model_files()
44
+ if MODEL_PATH is None:
45
+ MODEL_PATH = "." # Fallback to current directory
46
+ print("Falling back to current directory")
47
 
48
+ try:
49
+ # Load model and processor from detected path
50
+ print(f"=== Attempting to load model from: {MODEL_PATH} ===")
51
+ print(f"Current working directory: {os.getcwd()}")
52
+
53
+ # List all files in the detected model directory
54
+ if MODEL_PATH == ".":
55
+ print("Files in root directory:")
56
+ for item in os.listdir("."):
57
+ if os.path.isfile(item):
58
+ print(f" File: {item}")
59
+ else:
60
+ print(f" Directory: {item}/")
61
+ try:
62
+ sub_files = os.listdir(item)[:5] # Show first 5 files
63
+ print(f" Contains: {sub_files}{'...' if len(os.listdir(item)) > 5 else ''}")
64
+ except:
65
+ pass
66
+ else:
67
+ print(f"Files in {MODEL_PATH}:")
68
+ print(f" {os.listdir(MODEL_PATH)}")
69
+
70
+ # Try to load the model
71
+ print("Loading model...")
72
+ model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
73
+ print("Model loaded successfully!")
74
+
75
+ print("Loading processor...")
76
+ processor = AutoProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
77
+ print("Processor loaded successfully!")
78
+
79
+ # Get labels - handle case where id2label might not exist
80
+ if hasattr(model.config, 'id2label') and model.config.id2label:
81
+ labels = model.config.id2label
82
+ else:
83
+ # Create generic labels if none exist
84
+ num_labels = model.config.num_labels if hasattr(model.config, 'num_labels') else 1000
85
+ labels = {i: f"class_{i}" for i in range(num_labels)}
86
+
87
+ print(f"Model loaded successfully. Number of classes: {len(labels)}")
88
+
89
+ except Exception as e:
90
+ print(f"=== ERROR loading model from {MODEL_PATH} ===")
91
+ print(f"Error: {e}")
92
+ print("\n=== Debugging Information ===")
93
+ print("All files in Space:")
94
+
95
+ def list_all_files(directory=".", prefix=""):
96
+ """Recursively list all files"""
97
+ try:
98
+ items = sorted(os.listdir(directory))
99
+ for item in items:
100
+ item_path = os.path.join(directory, item)
101
+ if os.path.isfile(item_path):
102
+ size = os.path.getsize(item_path)
103
+ print(f"{prefix}πŸ“„ {item} ({size} bytes)")
104
+ elif os.path.isdir(item_path) and not item.startswith('.'):
105
+ print(f"{prefix}πŸ“ {item}/")
106
+ if len(prefix) < 6: # Limit recursion depth
107
+ list_all_files(item_path, prefix + " ")
108
+ except PermissionError:
109
+ print(f"{prefix}❌ Permission denied")
110
+ except Exception as ex:
111
+ print(f"{prefix}❌ Error: {ex}")
112
+
113
+ list_all_files()
114
+
115
+ print("\n=== Required Files for Model ===")
116
+ print("βœ… config.json - Model configuration")
117
+ print("βœ… pytorch_model.bin OR model.safetensors - Model weights")
118
+ print("βœ… preprocessor_config.json - Image processor config")
119
+ print("βœ… tokenizer.json (if applicable) - Tokenizer")
120
+
121
+ print("\n=== Solutions ===")
122
+ print("1. Make sure all model files are uploaded to your Space")
123
+ print("2. Check that files aren't corrupted during upload")
124
+ print("3. Try uploading to a 'model' subfolder")
125
+ print("4. Verify the model was saved correctly during training")
126
+
127
+ raise
128
 
129
  # Classify meme and extract text
130
  def classify_meme(image: Image.Image):
131
+ """
132
+ Classify meme and extract text using OCR
133
+ """
134
+ try:
135
+ # OCR: extract text from image
136
+ extracted_text = pytesseract.image_to_string(image)
137
+
138
+ # Process image with the model
139
+ inputs = processor(images=image, return_tensors="pt")
140
+
141
+ # Move inputs to same device as model if needed
142
+ if torch.cuda.is_available() and next(model.parameters()).is_cuda:
143
+ inputs = {k: v.to('cuda') for k, v in inputs.items()}
144
+
145
+ with torch.no_grad():
146
+ outputs = model(**inputs)
147
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
148
+
149
+ # Get top predictions
150
+ top_k = min(10, len(labels)) # Show top 10 or all if fewer
151
+ top_probs, top_indices = torch.topk(probs[0], top_k)
152
+
153
+ predictions = {}
154
+ for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
155
+ label = labels.get(idx.item(), f"class_{idx.item()}")
156
+ predictions[label] = float(prob)
157
+
158
+ # Debug prints (these will show in the console/logs)
159
+ print("Extracted Text:", extracted_text.strip())
160
+ print("Top Predictions:", predictions)
161
+
162
+ return predictions, extracted_text.strip()
163
+
164
+ except Exception as e:
165
+ print(f"Error in classification: {e}")
166
+ return {"Error": 1.0}, f"Error processing image: {str(e)}"
167
 
 
 
 
 
 
 
168
  # Gradio interface
169
  demo = gr.Interface(
170
  fn=classify_meme,
171
+ inputs=gr.Image(type="pil", label="Upload Meme Image"),
172
  outputs=[
173
+ gr.Label(num_top_classes=5, label="Meme Classification"),
174
+ gr.Textbox(label="Extracted Text from OCR", lines=3)
175
  ],
176
  title="Meme Classifier with OCR",
177
+ description="""
178
+ Upload a meme image to:
179
+ 1. Classify its content using your trained SigLIP2_77 model
180
+ 2. Extract text using OCR (Optical Character Recognition)
181
+
182
+ Note: Make sure all model files are properly uploaded to your Space.
183
+ """,
184
+ examples=None,
185
+ allow_flagging="never"
186
  )
187
 
188
  if __name__ == "__main__":
189
+ print("Starting Gradio interface...")
190
+ demo.launch(
191
+ server_name="0.0.0.0", # Allow external connections in HF Spaces
192
+ server_port=7860, # Standard port for HF Spaces
193
+ share=False # HF Spaces handles sharing
194
+ )