mknolan committed
Commit aa73c2c · verified · 1 Parent(s): 3a1e660

Upload slide analyzer application

Files changed (1)
  1. app.py +433 -0
app.py ADDED
@@ -0,0 +1,433 @@
+ import os
+ import sys
+ import math
+ import numpy as np
+ import tempfile
+ import torch
+ import torchvision.transforms as T
+ from torchvision.transforms.functional import InterpolationMode
+ from PIL import Image
+ import gradio as gr
+ from transformers import AutoModel, AutoTokenizer
+ import io
+ import pdf2image
+ from pptx import Presentation
+
+ # Constants
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
+ # Configuration
+ MODEL_NAME = "OpenGVLab/InternVL2_5-8B"
+ IMAGE_SIZE = 448
+
+ # Set up environment variables
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
+
+ # Utility functions for image processing
+ def build_transform(input_size):
+     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+     transform = T.Compose([
+         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+         T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+         T.ToTensor(),
+         T.Normalize(mean=MEAN, std=STD)
+     ])
+     return transform
+
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float('inf')
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     return best_ratio
+
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     # calculate the existing image aspect ratio
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the image
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         # split the image
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images
+
+ # Load and preprocess an image for the model, following the official documentation pattern
+ def load_image(image_pil, max_num=12):
+     # Process the image using dynamic_preprocess
+     processed_images = dynamic_preprocess(image_pil, image_size=IMAGE_SIZE, max_num=max_num)
+
+     # Convert PIL images to the tensor format expected by the model
+     transform = build_transform(IMAGE_SIZE)
+     pixel_values = [transform(img) for img in processed_images]
+     pixel_values = torch.stack(pixel_values)
+
+     # Convert to the appropriate data type
+     if torch.cuda.is_available():
+         pixel_values = pixel_values.cuda().to(torch.bfloat16)
+     else:
+         pixel_values = pixel_values.to(torch.float32)
+
+     return pixel_values
+
+ # Function to split the model across GPUs
+ def split_model(model_name):
+     device_map = {}
+     world_size = torch.cuda.device_count()
+     if world_size <= 1:
+         return "auto"
+
+     num_layers = {
+         'InternVL2_5-1B': 24,
+         'InternVL2_5-2B': 24,
+         'InternVL2_5-4B': 36,
+         'InternVL2_5-8B': 32,
+         'InternVL2_5-26B': 48,
+         'InternVL2_5-38B': 64,
+         'InternVL2_5-78B': 80
+     }[model_name]
+
+     # Since the first GPU will be used for ViT, treat it as half a GPU.
+     num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
+     num_layers_per_gpu = [num_layers_per_gpu] * world_size
+     num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
+     layer_cnt = 0
+     for i, num_layer in enumerate(num_layers_per_gpu):
+         for j in range(num_layer):
+             device_map[f'language_model.model.layers.{layer_cnt}'] = i
+             layer_cnt += 1
+     device_map['vision_model'] = 0
+     device_map['mlp1'] = 0
+     device_map['language_model.model.tok_embeddings'] = 0
+     device_map['language_model.model.embed_tokens'] = 0
+     device_map['language_model.model.rotary_emb'] = 0
+     device_map['language_model.output'] = 0
+     device_map['language_model.model.norm'] = 0
+     device_map['language_model.lm_head'] = 0
+     device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
+
+     return device_map
+
+ # Get the model dtype
+ def get_model_dtype():
+     return torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+ # Model loading function
+ def load_model():
+     print(f"\n=== Loading {MODEL_NAME} ===")
+     print(f"CUDA available: {torch.cuda.is_available()}")
+
+     model_dtype = get_model_dtype()
+     print(f"Using model dtype: {model_dtype}")
+
+     if torch.cuda.is_available():
+         print(f"GPU count: {torch.cuda.device_count()}")
+         for i in range(torch.cuda.device_count()):
+             print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
+
+         # Memory info
+         print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
+         print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
+         print(f"Reserved GPU memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
+
+     # Determine the device map
+     device_map = "auto"
+     if torch.cuda.is_available() and torch.cuda.device_count() > 1:
+         model_short_name = MODEL_NAME.split('/')[-1]
+         device_map = split_model(model_short_name)
+
+     # Load the model and tokenizer
+     try:
+         model = AutoModel.from_pretrained(
+             MODEL_NAME,
+             torch_dtype=model_dtype,
+             low_cpu_mem_usage=True,
+             trust_remote_code=True,
+             device_map=device_map
+         )
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             MODEL_NAME,
+             use_fast=False,
+             trust_remote_code=True
+         )
+
+         print("✓ Model and tokenizer loaded successfully!")
+         return model, tokenizer
+     except Exception as e:
+         print(f"❌ Error loading model: {e}")
+         import traceback
+         traceback.print_exc()
+         return None, None
+
+ # Extract slides from an uploaded PDF or PowerPoint file
+ def extract_slides(file_obj):
+     try:
+         file_bytes = file_obj.read()
+         file_extension = os.path.splitext(file_obj.name)[1].lower()
+
+         # Write the upload to a temporary file
+         with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
+             temp_file.write(file_bytes)
+             temp_path = temp_file.name
+
+         slides = []
+
+         if file_extension == '.pdf':
+             # Render PDF pages to images
+             images = pdf2image.convert_from_path(temp_path, dpi=300)
+             slides = [(f"Slide {i+1}", img) for i, img in enumerate(images)]
+
+         elif file_extension in ['.ppt', '.pptx']:
+             # python-pptx cannot rasterize slides; a dedicated rendering library would be
+             # needed, so placeholder images are used for PowerPoint files in this example.
+             prs = Presentation(temp_path)
+             for i, slide in enumerate(prs.slides):
+                 img = Image.new('RGB', (1280, 720), color=(255, 255, 255))
+                 slides.append((f"Slide {i+1}", img))
+
+         # Clean up the temporary file
+         os.unlink(temp_path)
+
+         return slides
+
+     except Exception as e:
+         import traceback
+         error_msg = f"Error extracting slides: {str(e)}\n{traceback.format_exc()}"
+         print(error_msg)
+         return []
+
+ # Image analysis function using the chat method from the documentation
+ def analyze_slide(model, tokenizer, image, prompt):
+     try:
+         # Check if the image is valid
+         if image is None:
+             return "Please upload an image first."
+
+         # Process the image following the official pattern
+         pixel_values = load_image(image)
+
+         # Debug info
+         print(f"Image processed: tensor shape {pixel_values.shape}, dtype {pixel_values.dtype}")
+
+         # Define the generation config
+         generation_config = {
+             "max_new_tokens": 512,
+             "do_sample": False
+         }
+
+         # Use the model.chat method as shown in the official documentation
+         question = f"<image>\n{prompt}"
+         response, _ = model.chat(
+             tokenizer=tokenizer,
+             pixel_values=pixel_values,
+             question=question,
+             generation_config=generation_config,
+             history=None,
+             return_history=True
+         )
+
+         return response
+     except Exception as e:
+         import traceback
+         error_msg = f"Error analyzing image: {str(e)}\n{traceback.format_exc()}"
+         return error_msg
+
+ # Analyze multiple slides from a PDF or PowerPoint
+ def analyze_multiple_slides(model, tokenizer, file_obj, prompt, num_slides=2):
+     try:
+         if file_obj is None:
+             return "Please upload a PDF or PowerPoint file."
+
+         # Extract slides from the file
+         slides = extract_slides(file_obj)
+
+         if not slides:
+             return "No slides were extracted from the file. Please check the file format."
+
+         # Limit to the requested number of slides (the slider may pass a float)
+         slides = slides[:int(num_slides)]
+
+         # Analyze each slide
+         analyses = []
+         for slide_title, slide_image in slides:
+             analysis = analyze_slide(model, tokenizer, slide_image, prompt)
+             analyses.append((slide_title, analysis))
+
+         # Format the results
+         result = ""
+         for slide_title, analysis in analyses:
+             result += f"## {slide_title}\n\n{analysis}\n\n---\n\n"
+
+         return result
+
+     except Exception as e:
+         import traceback
+         error_msg = f"Error analyzing slides: {str(e)}\n{traceback.format_exc()}"
+         return error_msg
+
+ # Main function
+ def main():
+     # Load the model
+     model, tokenizer = load_model()
+
+     if model is None:
+         # Create an error interface if model loading failed
+         demo = gr.Interface(
+             fn=lambda x: "Model loading failed. Please check the logs for details.",
+             inputs=gr.Textbox(),
+             outputs=gr.Textbox(),
+             title="InternVL2.5 Slide Analyzer - Error",
+             description="The model failed to load. Please check the logs for more information."
+         )
+         return demo
+
+     # Build the interface
+     with gr.Blocks(title="InternVL2.5 Slide Analyzer") as demo:
+         gr.Markdown("# InternVL2.5 Slide Analyzer")
+         gr.Markdown("Upload an image, PDF, or PowerPoint file and ask the model to analyze it.")
+
+         with gr.Tab("Single Image Analysis"):
+             # Predefined prompts for analysis
+             image_prompts = [
+                 "Describe this image in detail.",
+                 "What can you tell me about this image?",
+                 "Is there any text in this image? If so, can you read it?",
+                 "What is the main subject of this image?",
+                 "What emotions or feelings does this image convey?",
+                 "Describe the composition and visual elements of this image.",
+                 "Summarize what you see in this image in one paragraph."
+             ]
+
+             with gr.Row():
+                 image_input = gr.Image(type="pil", label="Upload Image")
+                 image_prompt = gr.Dropdown(
+                     choices=image_prompts,
+                     value=image_prompts[0],
+                     label="Select a prompt or write your own below",
+                     allow_custom_value=True
+                 )
+
+             image_analyze_btn = gr.Button("Analyze Image")
+             image_output = gr.Textbox(label="Analysis Results", lines=15)
+
+             # Handle the image analysis action
+             image_analyze_btn.click(
+                 fn=lambda img, prompt: analyze_slide(model, tokenizer, img, prompt),
+                 inputs=[image_input, image_prompt],
+                 outputs=image_output
+             )
+
+             # Add examples
+             gr.Examples(
+                 examples=[
+                     ["example_images/example1.jpg", "Describe this image in detail."],
+                     ["example_images/example2.jpg", "What can you tell me about this image?"]
+                 ],
+                 inputs=[image_input, image_prompt]
+             )
+
+         with gr.Tab("Multiple Slides Analysis"):
+             # Predefined prompts for slides
+             slide_prompts = [
+                 "Analyze this slide and describe its contents.",
+                 "What is the main message of this slide?",
+                 "Extract all the text visible in this slide.",
+                 "What are the key points presented in this slide?",
+                 "Describe the visual elements and layout of this slide.",
+                 "Is there any data visualization in this slide? If so, explain it.",
+                 "How does this slide fit into a typical presentation?"
+             ]
+
+             with gr.Row():
+                 file_input = gr.File(label="Upload PDF or PowerPoint")
+                 slide_prompt = gr.Dropdown(
+                     choices=slide_prompts,
+                     value=slide_prompts[0],
+                     label="Select a prompt or write your own below",
+                     allow_custom_value=True
+                 )
+
+             num_slides = gr.Slider(
+                 minimum=1,
+                 maximum=10,
+                 value=2,
+                 step=1,
+                 label="Number of Slides to Analyze"
+             )
+
+             slides_analyze_btn = gr.Button("Analyze Slides")
+             slides_output = gr.Markdown(label="Analysis Results")
+
+             # Handle the slides analysis action
+             slides_analyze_btn.click(
+                 fn=lambda file, prompt, num: analyze_multiple_slides(model, tokenizer, file, prompt, num),
+                 inputs=[file_input, slide_prompt, num_slides],
+                 outputs=slides_output
+             )
+
+             # Add an example
+             gr.Examples(
+                 examples=[
+                     ["example_slides/test_slides.pdf", "Extract all the text visible in this slide.", 2]
+                 ],
+                 inputs=[file_input, slide_prompt, num_slides]
+             )
+
+     return demo
+
+ # Run the application
+ if __name__ == "__main__":
+     try:
+         # Check for a GPU
+         if not torch.cuda.is_available():
+             print("WARNING: CUDA is not available. The model requires a GPU to function properly.")
+
+         # Create and launch the interface
+         demo = main()
+         demo.launch(server_name="0.0.0.0")
+     except Exception as e:
+         print(f"Error starting the application: {e}")
+         import traceback
+         traceback.print_exc()
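
For reference, a minimal smoke test of the tiling helpers added above, run outside the app. This is a sketch under the assumption that app.py is importable and its dependencies (torch, torchvision, Pillow, gradio, transformers, pdf2image, python-pptx) are installed; the synthetic 1920x1080 image and the expected output shapes are illustrative, and no GPU or model download is needed.

import torch
from PIL import Image

from app import IMAGE_SIZE, build_transform, dynamic_preprocess

# Synthetic 16:9 "slide"; any PIL image works here.
slide = Image.new("RGB", (1920, 1080), color=(240, 240, 240))

# Tile the slide into 448x448 crops, plus a global thumbnail.
tiles = dynamic_preprocess(slide, image_size=IMAGE_SIZE, max_num=12, use_thumbnail=True)
print(len(tiles))  # a 16:9 input should yield a 4x2 grid of tiles plus 1 thumbnail -> 9

# Normalize and stack, mirroring what load_image() does before model.chat().
transform = build_transform(IMAGE_SIZE)
pixel_values = torch.stack([transform(t) for t in tiles])
print(pixel_values.shape)  # expected: torch.Size([9, 3, 448, 448])

Because load_model() is only called from main() under the __main__ guard, importing app like this exercises the preprocessing path without loading InternVL2_5-8B.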