saakshigupta committed · Commit 4048570 · verified · 1 Parent(s): cd7498a

Update app.py

Files changed (1):
  1. app.py +29 -33
app.py CHANGED
@@ -424,9 +424,9 @@ def process_image_with_gradcam(image, model, device, pred_class):
 
 # ----- BLIP Image Captioning -----
 
-# Define simple prompts for BLIP
-ORIGINAL_IMAGE_PROMPT = ""  # Empty prompt for original images - BLIP works better with no prompt
-GRADCAM_IMAGE_PROMPT = "Describe what you see in this heatmap visualization"
+# Define conditional prompts for BLIP
+ORIGINAL_IMAGE_PROMPT = "an image of"  # For the original image
+GRADCAM_IMAGE_PROMPT = "a heatmap showing"  # For the GradCAM visualization
 
 # Function to load BLIP captioning model
 @st.cache_resource
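
Note on this change: the new prompts use BLIP's conditional captioning mode, where the text handed to the processor becomes the opening words of the sentence and the model completes it, so short prefixes like "an image of" steer the caption without constraining it. A minimal sketch of that behavior with the transformers library, assuming the Salesforce/blip-image-captioning-base checkpoint and a hypothetical input file (the checkpoint actually loaded by load_blip_model is not shown in this diff):

    import torch
    from PIL import Image
    from transformers import BlipProcessor, BlipForConditionalGeneration

    # Assumed checkpoint; load_blip_model() may use a different one
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

    image = Image.open("face.jpg").convert("RGB")  # hypothetical input file

    # The prompt is echoed back as the start of the generated caption
    inputs = processor(image, "an image of", return_tensors="pt")
    with torch.no_grad():
        output = model.generate(**inputs, max_length=75, num_beams=5)
    print(processor.decode(output[0], skip_special_tokens=True))
    # e.g. "an image of a woman smiling at the camera"

That prompt echo is why the rewritten generate_image_caption below strips the prompt from the decoded text before formatting it.
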
@@ -440,71 +440,65 @@ def load_blip_model():
         st.error(f"Error loading BLIP model: {str(e)}")
         return None, None
 
-# Function to generate image caption with manual structured formatting
-def generate_image_caption(image, processor, model, is_gradcam=False, max_length=150, num_beams=5):
+# Function to generate image caption using BLIP
+def generate_image_caption(image, processor, model, is_gradcam=False, max_length=75, num_beams=5):
     """
-    Generate a caption for the input image using BLIP model and format it with structured headings
+    Generate a caption for the input image using BLIP model's conditional captioning
     """
     try:
         # Select the appropriate prompt based on image type
-        prompt = GRADCAM_IMAGE_PROMPT if is_gradcam else ORIGINAL_IMAGE_PROMPT
-
-        # Preprocess the image
-        inputs = processor(image, text=prompt, return_tensors="pt")
+        conditional_prompt = GRADCAM_IMAGE_PROMPT if is_gradcam else ORIGINAL_IMAGE_PROMPT
 
         # Check for available GPU
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model = model.to(device)
-        inputs = {k: v.to(device) for k, v in inputs.items()}
 
-        # Generate caption
+        # Get conditional caption
+        conditional_inputs = processor(image, conditional_prompt, return_tensors="pt").to(device)
         with torch.no_grad():
-            output = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
-
-        # Decode the caption
-        raw_caption = processor.decode(output[0], skip_special_tokens=True)
+            conditional_output = model.generate(**conditional_inputs, max_length=max_length, num_beams=num_beams)
+        conditional_caption = processor.decode(conditional_output[0], skip_special_tokens=True)
 
-        # Format the caption into a structured format based on type
+        # Remove the prompt from the beginning if it appears
+        if conditional_prompt in conditional_caption:
+            conditional_caption = conditional_caption.replace(conditional_prompt, "").strip()
+
+        # Format the caption based on image type
         if is_gradcam:
-            formatted_caption = format_gradcam_caption(raw_caption)
+            full_info = format_gradcam_caption(conditional_caption)
         else:
-            formatted_caption = format_image_caption(raw_caption)
+            full_info = format_image_caption(conditional_caption)
 
-        return formatted_caption
+        return full_info
     except Exception as e:
         st.error(f"Error generating caption: {str(e)}")
         return "Error generating caption"
 
-def format_image_caption(raw_caption):
-    """Format a raw caption into a structured description with headings"""
-
-    # Try to extract some basic information from the raw caption
-    appearance_info = raw_caption  # Use the full caption by default
+def format_image_caption(caption):
+    """Format caption into a structured description with headings"""
 
-    # Basic structure for image caption with extracted information
     structured_caption = f"""
-    **Subject**: The image shows a person in a portrait-style photograph.
+    **Subject**: The image shows a person in a photograph.
 
-    **Appearance**: {appearance_info}
+    **Appearance**: {caption}
 
     **Background**: The background appears to be a controlled environment.
 
     **Lighting**: The lighting appears to be professional with even illumination.
 
-    **Colors**: The image contains natural skin tones and colors typical of portrait photography.
+    **Colors**: The image contains natural skin tones and colors typical of photography.
 
     **Notable Elements**: The facial features and expression are the central focus of the image.
     """
     return structured_caption.strip()
 
-def format_gradcam_caption(raw_caption):
-    """Format a raw GradCAM description with proper structure"""
+def format_gradcam_caption(caption):
+    """Format GradCAM caption with proper structure"""
 
-    # Basic structure for GradCAM analysis
     structured_caption = f"""
     **Main Focus Area**: The heatmap is primarily focused on the facial region of the person.
 
-    **High Activation Regions**: The red/yellow areas highlight important features that the model is focusing on. {raw_caption}
+    **High Activation Regions**: The red/yellow areas highlight important features that the model is focusing on. {caption}
 
     **Medium Activation Regions**: The green/cyan areas correspond to regions of medium importance in the detection process, typically including parts of the face and surrounding areas.
 
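Note on the prompt-stripping step in this hunk: str.replace removes every occurrence of the prompt, not only the leading one, so a caption that happened to repeat the phrase mid-sentence would lose it there as well. A prefix-only variant (a sketch, not part of this commit) would be:

    # Strip the conditional prompt only when it is the leading prefix
    if conditional_caption.startswith(conditional_prompt):
        conditional_caption = conditional_caption[len(conditional_prompt):].strip()

Either way, format_image_caption and format_gradcam_caption receive only the model's continuation, which is what their **Appearance** and **High Activation Regions** fields interpolate.
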
@@ -778,6 +772,8 @@ def main():
                 st.warning("⚠️ Please load the CLIP model first to perform initial detection.")
     except Exception as e:
         st.error(f"Error processing image: {str(e)}")
+        import traceback
+        st.error(traceback.format_exc())  # This will show the full error traceback
 
     # LLM Analysis section
     with st.expander("Stage 3: Detailed Analysis with Vision LLM", expanded=False):
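
Note: the two added lines surface the full stack trace in the app rather than only the exception message; traceback.format_exc() returns the formatted traceback of the exception currently being handled. Importing inside the except block works, though a module-level import is more conventional, and Streamlit's built-in st.exception renders the same information. A minimal self-contained sketch (fail is a hypothetical stand-in for the image-processing code):

    import traceback
    import streamlit as st

    def fail():
        raise ValueError("boom")  # hypothetical failure

    try:
        fail()
    except Exception as e:
        st.error(f"Error processing image: {str(e)}")
        st.error(traceback.format_exc())  # full stack trace as text
        # Built-in alternative: st.exception(e) renders the traceback directly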
 