sahalhes committed
Commit 5292153 · 1 Parent(s): 348528a
Files changed (1)
  1. app.py +315 -0
app.py ADDED
@@ -0,0 +1,315 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from transformers import BlipProcessor, BlipForQuestionAnswering
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
+ import requests
+ from io import BytesIO
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class VQAApp:
+     def __init__(self):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logger.info(f"Using device: {self.device}")
+
+         # Initialize models
+         self.models = {}
+         self.processors = {}
+         self.current_model = "blip2"
+
+         # Load models
+         self.load_models()
+
+     def load_models(self):
+         """Load all available VQA models"""
+         try:
+             # BLIP-2 (recommended for best performance)
+             logger.info("Loading BLIP-2 model...")
+             self.processors["blip2"] = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+             self.models["blip2"] = Blip2ForConditionalGeneration.from_pretrained(
+                 "Salesforce/blip2-opt-2.7b",
+                 torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
+             ).to(self.device)
+
+             # Original BLIP (faster but less accurate)
+             logger.info("Loading BLIP model...")
+             self.processors["blip"] = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+             self.models["blip"] = BlipForQuestionAnswering.from_pretrained(
+                 "Salesforce/blip-vqa-base"
+             ).to(self.device)
+
+             logger.info("All models loaded successfully!")
+
+         except Exception as e:
+             logger.error(f"Error loading models: {str(e)}")
+             raise
+
+     def answer_question(self, image, question, model_choice="blip2", max_length=50):
+         """
+         Answer a question about an image using the selected model
+
+         Args:
+             image: PIL Image, numpy array, URL, or path to an image file
+             question: Question about the image
+             model_choice: Model to use ("blip2" or "blip")
+             max_length: Maximum length of the generated answer
+
+         Returns:
+             String answer to the question
+         """
+         try:
+             if image is None:
+                 return "Please upload an image first."
+
+             if not question.strip():
+                 return "Please ask a question about the image."
+
+             # Gradio sliders pass floats; generate() expects an int
+             max_length = int(max_length)
+
+             # Ensure image is a PIL Image
+             if isinstance(image, str):
+                 if image.startswith('http'):
+                     response = requests.get(image, timeout=10)
+                     image = Image.open(BytesIO(response.content)).convert('RGB')
+                 else:
+                     image = Image.open(image).convert('RGB')
+             elif not isinstance(image, Image.Image):
+                 image = Image.fromarray(image).convert('RGB')
+
+             # Get model and processor
+             model = self.models[model_choice]
+             processor = self.processors[model_choice]
+
+             if model_choice == "blip2":
+                 # BLIP-2 processing: cast inputs to float16 on GPU to match the model weights
+                 dtype = torch.float16 if self.device.type == "cuda" else torch.float32
+                 inputs = processor(image, question, return_tensors="pt").to(self.device, dtype)
+
+                 with torch.no_grad():
+                     generated_ids = model.generate(
+                         **inputs,
+                         max_length=max_length,
+                         num_beams=5,
+                         temperature=0.7,
+                         do_sample=True,
+                         top_p=0.9
+                     )
+
+                 answer = processor.decode(generated_ids[0], skip_special_tokens=True)
+
+             else:  # blip
+                 # Original BLIP processing
+                 inputs = processor(image, question, return_tensors="pt").to(self.device)
+
+                 with torch.no_grad():
+                     outputs = model.generate(**inputs, max_length=max_length, num_beams=5)
+
+                 answer = processor.decode(outputs[0], skip_special_tokens=True)
+
+             return answer.strip()
+
+         except Exception as e:
+             logger.error(f"Error in answer_question: {str(e)}")
+             return f"Error processing question: {str(e)}"
+
+     def batch_qa(self, image, questions_text):
+         """
+         Answer multiple questions about the same image
+
+         Args:
+             image: PIL Image
+             questions_text: String with questions separated by newlines
+
+         Returns:
+             String with questions and answers
+         """
+         if not questions_text.strip():
+             return "Please enter questions (one per line)."
+
+         questions = [q.strip() for q in questions_text.split('\n') if q.strip()]
+         results = []
+
+         for i, question in enumerate(questions, 1):
+             answer = self.answer_question(image, question, self.current_model)
+             results.append(f"Q{i}: {question}")
+             results.append(f"A{i}: {answer}")
+             results.append("")
+
+         return "\n".join(results)
+
+ def create_gradio_interface():
+     """Create the Gradio interface for the VQA app"""
+
+     # Initialize VQA app
+     vqa_app = VQAApp()
+
+     # Sample images for demo
+     sample_images = [
+         "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
+         "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/lena.png"
+     ]
+
+     with gr.Blocks(title="Visual Question Answering App", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("""
+         # 🔍 Visual Question Answering App
+
+         Upload an image and ask questions about its content! This app uses state-of-the-art multimodal models
+         from Hugging Face to understand and answer questions about images.
+
+         **Models available:**
+         - **BLIP-2**: Advanced model with better understanding (recommended)
+         - **BLIP**: Faster model for quick answers
+         """)
+
+         with gr.Tab("Single Question"):
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     image_input = gr.Image(
+                         label="Upload Image",
+                         type="pil",
+                         height=300
+                     )
+
+                     model_choice = gr.Dropdown(
+                         choices=["blip2", "blip"],
+                         value="blip2",
+                         label="Choose Model",
+                         info="BLIP-2 is more accurate but slower"
+                     )
+
+                     max_length_slider = gr.Slider(
+                         minimum=10,
+                         maximum=100,
+                         value=50,
+                         step=5,
+                         label="Max Answer Length"
+                     )
+
+                 with gr.Column(scale=1):
+                     question_input = gr.Textbox(
+                         label="Ask a question about the image",
+                         placeholder="What do you see in this image?",
+                         lines=3
+                     )
+
+                     answer_button = gr.Button("Get Answer", variant="primary", size="lg")
+
+                     answer_output = gr.Textbox(
+                         label="Answer",
+                         lines=5,
+                         interactive=False
+                     )
+
+             # Example questions
+             gr.Markdown("### Example Questions:")
+             example_questions = [
+                 "What objects are in this image?",
+                 "What color is the main subject?",
+                 "How many people are in the image?",
+                 "What is the setting or location?",
+                 "What activity is taking place?",
+                 "What's the weather like in this image?"
+             ]
+
+             with gr.Row():
+                 for eq in example_questions[:3]:
+                     gr.Button(eq, size="sm").click(
+                         lambda q=eq: q, outputs=question_input
+                     )
+
+             with gr.Row():
+                 for eq in example_questions[3:]:
+                     gr.Button(eq, size="sm").click(
+                         lambda q=eq: q, outputs=question_input
+                     )
+
+         with gr.Tab("Multiple Questions"):
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     batch_image_input = gr.Image(
+                         label="Upload Image",
+                         type="pil",
+                         height=300
+                     )
+
+                     batch_model_choice = gr.Dropdown(
+                         choices=["blip2", "blip"],
+                         value="blip2",
+                         label="Choose Model"
+                     )
+
+                 with gr.Column(scale=1):
+                     batch_questions_input = gr.Textbox(
+                         label="Questions (one per line)",
+                         placeholder="What do you see?\nHow many objects are there?\nWhat color is dominant?",
+                         lines=6
+                     )
+
+                     batch_button = gr.Button("Answer All Questions", variant="primary")
+
+                     batch_output = gr.Textbox(
+                         label="Questions & Answers",
+                         lines=10,
+                         interactive=False
+                     )
+
+         with gr.Tab("Sample Images"):
+             gr.Markdown("### Try these sample images:")
+
+             with gr.Row():
+                 for img_url in sample_images:
+                     with gr.Column():
+                         sample_img = gr.Image(value=img_url, label="Sample Image")
+                         gr.Button("Use This Image").click(
+                             lambda x=img_url: x,
+                             outputs=image_input
+                         )
+
+         # Event handlers
+         def update_model_choice(choice):
+             # Track the selection so batch_qa uses the same model
+             vqa_app.current_model = choice
+
+         model_choice.change(update_model_choice, inputs=model_choice)
+         batch_model_choice.change(update_model_choice, inputs=batch_model_choice)
+
+         answer_button.click(
+             vqa_app.answer_question,
+             inputs=[image_input, question_input, model_choice, max_length_slider],
+             outputs=answer_output
+         )
+
+         batch_button.click(
+             vqa_app.batch_qa,
+             inputs=[batch_image_input, batch_questions_input],
+             outputs=batch_output
+         )
+
+         gr.Markdown("""
+         ### Tips for better results:
+         - Use clear, specific questions
+         - BLIP-2 works better for complex reasoning
+         - Try different phrasings if you don't get good results
+         - Upload high-quality images for best performance
+         """)
+
+     return demo
+
+ # Alternative standalone function for direct usage
+ def simple_vqa(image_path, question, model_name="blip2"):
+     """Answer a single question about an image without the Gradio UI.
+
+     Note: this builds a fresh VQAApp (and loads both models) on every call;
+     reuse one VQAApp instance for repeated queries.
+     """
+     vqa = VQAApp()
+
+     if isinstance(image_path, str):
+         image = Image.open(image_path).convert('RGB')
+     else:
+         image = image_path
+
+     return vqa.answer_question(image, question, model_name)
+
+ if __name__ == "__main__":
+     demo = create_gradio_interface()
+     demo.launch()
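
For reference, a minimal usage sketch of the simple_vqa helper added in this commit. The calling script and the file name example.jpg are hypothetical; it assumes app.py is importable from the working directory.

    from app import simple_vqa

    # simple_vqa constructs a fresh VQAApp, so both BLIP checkpoints are
    # downloaded and loaded on the first call; "blip" is the lighter option.
    answer = simple_vqa("example.jpg", "What objects are in this image?", model_name="blip")
    print(answer)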