sahalhes committed
Commit a37e4d7 · 1 Parent(s): 5292153
Files changed (1)
  1. app.py +27 -303
app.py CHANGED
@@ -1,315 +1,39 @@
 import gradio as gr
-import torch
 from PIL import Image
 from transformers import BlipProcessor, BlipForQuestionAnswering
-from transformers import Blip2Processor, Blip2ForConditionalGeneration
-import requests
-from io import BytesIO
-import logging

-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)

-class VQAApp:
-    def __init__(self):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        logger.info(f"Using device: {self.device}")
-
-        # Initialize models
-        self.models = {}
-        self.processors = {}
-        self.current_model = "blip2"
-
-        # Load models
-        self.load_models()
-
-    def load_models(self):
-        """Load all available VQA models"""
-        try:
-            # BLIP-2 (Recommended for best performance)
-            logger.info("Loading BLIP-2 model...")
-            self.processors["blip2"] = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-            self.models["blip2"] = Blip2ForConditionalGeneration.from_pretrained(
-                "Salesforce/blip2-opt-2.7b",
-                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
-            ).to(self.device)
-
-            # Original BLIP (Faster but less accurate)
-            logger.info("Loading BLIP model...")
-            self.processors["blip"] = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
-            self.models["blip"] = BlipForQuestionAnswering.from_pretrained(
-                "Salesforce/blip-vqa-base"
-            ).to(self.device)
-
-            logger.info("All models loaded successfully!")
-
-        except Exception as e:
-            logger.error(f"Error loading models: {str(e)}")
-            raise e
-
-    def answer_question(self, image, question, model_choice="blip2", max_length=50):
-        """
-        Answer a question about an image using the selected model
-
-        Args:
-            image: PIL Image or path to image
-            question: String question about the image
-            model_choice: Model to use ("blip2" or "blip")
-            max_length: Maximum length of generated answer
-
-        Returns:
-            String answer to the question
-        """
-        try:
-            if image is None:
-                return "Please upload an image first."
-
-            if not question.strip():
-                return "Please ask a question about the image."
-
-            # Ensure image is PIL Image
-            if isinstance(image, str):
-                if image.startswith('http'):
-                    response = requests.get(image)
-                    image = Image.open(BytesIO(response.content)).convert('RGB')
-                else:
-                    image = Image.open(image).convert('RGB')
-            elif not isinstance(image, Image.Image):
-                image = Image.fromarray(image).convert('RGB')
-
-            # Get model and processor
-            model = self.models[model_choice]
-            processor = self.processors[model_choice]
-
-            if model_choice == "blip2":
-                # BLIP-2 processing
-                inputs = processor(image, question, return_tensors="pt").to(self.device)
-
-                with torch.no_grad():
-                    generated_ids = model.generate(
-                        **inputs,
-                        max_length=max_length,
-                        num_beams=5,
-                        temperature=0.7,
-                        do_sample=True,
-                        top_p=0.9
-                    )
-
-                answer = processor.decode(generated_ids[0], skip_special_tokens=True)
-
-            else:  # blip
-                # Original BLIP processing
-                inputs = processor(image, question, return_tensors="pt").to(self.device)
-
-                with torch.no_grad():
-                    outputs = model.generate(**inputs, max_length=max_length, num_beams=5)
-
-                answer = processor.decode(outputs[0], skip_special_tokens=True)
-
-            return answer.strip()
-
-        except Exception as e:
-            logger.error(f"Error in answer_question: {str(e)}")
-            return f"Error processing question: {str(e)}"
-
-    def batch_qa(self, image, questions_text):
-        """
-        Answer multiple questions about the same image
-
-        Args:
-            image: PIL Image
-            questions_text: String with questions separated by newlines
-
-        Returns:
-            String with questions and answers
-        """
-        if not questions_text.strip():
-            return "Please enter questions (one per line)."
-
-        questions = [q.strip() for q in questions_text.split('\n') if q.strip()]
-        results = []
-
-        for i, question in enumerate(questions, 1):
-            answer = self.answer_question(image, question, self.current_model)
-            results.append(f"Q{i}: {question}")
-            results.append(f"A{i}: {answer}")
-            results.append("")
-
-        return "\n".join(results)

-def create_gradio_interface():
-    """Create the Gradio interface for the VQA app"""
-
-    # Initialize VQA app
-    vqa_app = VQAApp()

-    # Sample images for demo
-    sample_images = [
-        "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
-        "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/lena.png"
-    ]

-    with gr.Blocks(title="Visual Question Answering App", theme=gr.themes.Soft()) as demo:
-        gr.Markdown("""
-        # 🔍 Visual Question Answering App
-
-        Upload an image and ask questions about its content! This app uses state-of-the-art multimodal models
-        from Hugging Face to understand and answer questions about images.
-
-        **Models available:**
-        - **BLIP-2**: Advanced model with better understanding (recommended)
-        - **BLIP**: Faster model for quick answers
-        """)
-
-        with gr.Tab("Single Question"):
-            with gr.Row():
-                with gr.Column(scale=1):
-                    image_input = gr.Image(
-                        label="Upload Image",
-                        type="pil",
-                        height=300
-                    )
-
-                    model_choice = gr.Dropdown(
-                        choices=["blip2", "blip"],
-                        value="blip2",
-                        label="Choose Model",
-                        info="BLIP-2 is more accurate but slower"
-                    )
-
-                    max_length_slider = gr.Slider(
-                        minimum=10,
-                        maximum=100,
-                        value=50,
-                        step=5,
-                        label="Max Answer Length"
-                    )
-
-                with gr.Column(scale=1):
-                    question_input = gr.Textbox(
-                        label="Ask a question about the image",
-                        placeholder="What do you see in this image?",
-                        lines=3
-                    )
-
-                    answer_button = gr.Button("Get Answer", variant="primary", size="lg")
-
-                    answer_output = gr.Textbox(
-                        label="Answer",
-                        lines=5,
-                        interactive=False
-                    )
-
-            # Example questions
-            gr.Markdown("### Example Questions:")
-            example_questions = [
-                "What objects are in this image?",
-                "What color is the main subject?",
-                "How many people are in the image?",
-                "What is the setting or location?",
-                "What activity is taking place?",
-                "What's the weather like in this image?"
-            ]
-
-            with gr.Row():
-                for i, eq in enumerate(example_questions[:3]):
-                    gr.Button(eq, size="sm").click(
-                        lambda q=eq: q, outputs=question_input
-                    )
-
-            with gr.Row():
-                for i, eq in enumerate(example_questions[3:]):
-                    gr.Button(eq, size="sm").click(
-                        lambda q=eq: q, outputs=question_input
-                    )
-
-        with gr.Tab("Multiple Questions"):
-            with gr.Row():
-                with gr.Column(scale=1):
-                    batch_image_input = gr.Image(
-                        label="Upload Image",
-                        type="pil",
-                        height=300
-                    )
-
-                    batch_model_choice = gr.Dropdown(
-                        choices=["blip2", "blip"],
-                        value="blip2",
-                        label="Choose Model"
-                    )
-
-                with gr.Column(scale=1):
-                    batch_questions_input = gr.Textbox(
-                        label="Questions (one per line)",
-                        placeholder="What do you see?\nHow many objects are there?\nWhat color is dominant?",
-                        lines=6
-                    )
-
-                    batch_button = gr.Button("Answer All Questions", variant="primary")
-
-                    batch_output = gr.Textbox(
-                        label="Questions & Answers",
-                        lines=10,
-                        interactive=False
-                    )
-
-        with gr.Tab("Sample Images"):
-            gr.Markdown("### Try these sample images:")
-
-            with gr.Row():
-                for img_url in sample_images:
-                    with gr.Column():
-                        sample_img = gr.Image(value=img_url, label="Sample Image")
-                        gr.Button("Use This Image").click(
-                            lambda x=img_url: x,
-                            outputs=image_input
-                        )
-
-        # Event handlers
-        def update_model_choice(choice):
-            vqa_app.current_model = choice
-            return choice
-
-        model_choice.change(update_model_choice, inputs=model_choice)
-        batch_model_choice.change(update_model_choice, inputs=batch_model_choice)
-
-        answer_button.click(
-            vqa_app.answer_question,
-            inputs=[image_input, question_input, model_choice, max_length_slider],
-            outputs=answer_output
-        )
-
-        batch_button.click(
-            vqa_app.batch_qa,
-            inputs=[batch_image_input, batch_questions_input],
-            outputs=batch_output
-        )
-
-        gr.Markdown("""
-        ### Tips for better results:
-        - Use clear, specific questions
-        - BLIP-2 works better for complex reasoning
-        - Try different phrasings if you don't get good results
-        - Upload high-quality images for best performance
-        """)
-
-    return demo

-# Alternative standalone functions for direct usage
-def simple_vqa(image_path, question, model_name="blip2"):
-    vqa = VQAApp()
-
-    if isinstance(image_path, str):
-        image = Image.open(image_path).convert('RGB')
-    else:
-        image = image_path
-
-    return vqa.answer_question(image, question, model_name)

 if __name__ == "__main__":
-    demo = create_gradio_interface()
-
     demo.launch()
-
-
 import gradio as gr
 from PIL import Image
+import torch
 from transformers import BlipProcessor, BlipForQuestionAnswering

+# Load processor and small BLIP VQA model
+processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

+# Use CPU explicitly
+device = torch.device("cpu")
+model.to(device)

+# VQA function
+def answer_question(image: Image.Image, question: str) -> str:
+    # Prepare input
+    inputs = processor(image.convert("RGB"), question, return_tensors="pt").to(device)

+    # Generate answer
+    with torch.no_grad():
+        output = model.generate(**inputs)

+    # Decode answer
+    return processor.decode(output[0], skip_special_tokens=True).strip()

+# Gradio interface
+demo = gr.Interface(
+    fn=answer_question,
+    inputs=[
+        gr.Image(type="pil", label="Upload an Image"),
+        gr.Textbox(label="Ask a Question About the Image")
+    ],
+    outputs=gr.Textbox(label="Answer"),
+    title="BLIP Visual Question Answering (CPU Friendly)",
+    description="Ask a question about an image using Salesforce's BLIP VQA Base model."
+)

 if __name__ == "__main__":
     demo.launch()
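
A quick way to sanity-check the simplified app is to call answer_question directly, without launching the Gradio UI. The following is a minimal sketch, assuming the new app.py above sits in the working directory and a local image file is available (the path example.jpg is hypothetical); importing app downloads the Salesforce/blip-vqa-base weights on first use but does not start the interface, since demo.launch() is guarded by the __main__ check.

# Minimal usage sketch (assumes the new app.py above; "example.jpg" is a hypothetical path)
from PIL import Image
from app import answer_question  # processor and model load at import time

image = Image.open("example.jpg")  # any image Pillow can open; converted to RGB inside answer_question
print(answer_question(image, "What is in the picture?"))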