saakshigupta committed
Commit 4ae113a · verified · 1 Parent(s): be65f5f

Update app.py

Files changed (1):
  1. app.py +155 -117
app.py CHANGED
@@ -436,24 +436,60 @@ def load_blip_model():
         st.error(f"Error loading BLIP model: {str(e)}")
         return None, None

-# Function to generate image caption using BLIP
-def generate_image_caption(image, processor, model, is_gradcam=False, max_length=75, num_beams=5):
     """
-    Generate a caption for the input image using BLIP model
     """
     try:
         # Check for available GPU
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model = model.to(device)

-        # Choose the right prompting method based on image type
-        if is_gradcam:
-            # For GradCAM, use conditional captioning with a specific prompt
-            text = "a heatmap showing"
-            inputs = processor(image, text, return_tensors="pt").to(device)
-        else:
-            # For original image, use unconditional captioning (works better for portraits)
-            inputs = processor(image, return_tensors="pt").to(device)

         # Generate caption
         with torch.no_grad():
@@ -462,24 +498,8 @@ def generate_image_caption(image, processor, model, is_gradcam=False, max_length
         # Decode the output
         caption = processor.decode(output[0], skip_special_tokens=True)

-        # Remove the prompt from the beginning if it appears (for conditional captioning)
-        if is_gradcam and "a heatmap showing" in caption:
-            caption = caption.replace("a heatmap showing", "").strip()
-
-        # Format based on image type
-        if is_gradcam:
-            return format_gradcam_caption(caption)
-        else:
-            return format_image_caption(caption)
-
-    except Exception as e:
-        st.error(f"Error generating caption: {str(e)}")
-        return "Error generating caption"
-
-def format_image_caption(caption):
-    """Format caption into a structured description with headings"""
-
-    structured_caption = f"""
 **Subject**: The image shows a person in a photograph.

 **Appearance**: {caption}
@@ -492,23 +512,11 @@ def format_image_caption(caption):

 **Notable Elements**: The facial features and expression are the central focus of the image.
     """
-    return structured_caption.strip()
-
-def format_gradcam_caption(caption):
-    """Format GradCAM caption with proper structure"""
-
-    structured_caption = f"""
-**Main Focus Area**: The heatmap is primarily focused on the facial region of the person.
-
-**High Activation Regions**: The red/yellow areas highlight important features that the model is focusing on. {caption}
-
-**Medium Activation Regions**: The green/cyan areas correspond to regions of medium importance in the detection process, typically including parts of the face and surrounding areas.
-
-**Low Activation Regions**: The blue/dark blue areas represent features that have less impact on the model's decision, usually the background and peripheral elements.
-
-**Activation Pattern**: The overall pattern suggests the model is primarily analyzing facial features to make its determination of authenticity.
-    """
-    return structured_caption.strip()

 # ----- Fine-tuned Vision LLM -----
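The branch removed above drove BLIP in two modes: unconditional captioning for ordinary photos, and conditional captioning for GradCAM overlays, where the prompt "a heatmap showing" becomes the start of the generated caption (which is why the old code later stripped that prefix back out of the decoded text). The sketch below illustrates the two calls; the checkpoint name and file paths are assumptions for illustration only, since load_blip_model() is not shown in this diff.

# Minimal sketch of the two BLIP modes used by the removed branch.
# Assumption: load_blip_model() wraps a captioning checkpoint roughly like the
# one below; the actual model name is not visible in this diff.
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

image = Image.open("face.jpg").convert("RGB")  # hypothetical input file

# Unconditional captioning (the else-branch above): the model describes the image freely.
inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_length=75, num_beams=5)
print(processor.decode(out[0], skip_special_tokens=True))

# Conditional captioning (the GradCAM branch above): the prompt is used as the
# beginning of the caption, so the old code removed "a heatmap showing" from
# the decoded string afterwards.
inputs = processor(image, "a heatmap showing", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_length=75, num_beams=5)
print(processor.decode(out[0], skip_special_tokens=True))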

@@ -520,7 +528,6 @@ def fix_cross_attention_mask(inputs):
         new_mask = torch.ones((batch_size, seq_len, visual_features, num_tiles),
                               device=inputs['cross_attention_mask'].device)
         inputs['cross_attention_mask'] = new_mask
-        st.success("Fixed cross-attention mask dimensions")
     return inputs

 # Load model function
@@ -605,7 +612,7 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide

 # Main app
 def main():
-    # Create placeholders for model state
     if 'clip_model_loaded' not in st.session_state:
         st.session_state.clip_model_loaded = False
         st.session_state.clip_model = None
@@ -620,12 +627,16 @@ def main():
         st.session_state.blip_processor = None
         st.session_state.blip_model = None

     # Create expanders for each stage
     with st.expander("Stage 1: Model Loading", expanded=True):
         st.write("Please load the models using the buttons below:")

         # Button for loading models
-        clip_col, llm_col, blip_col = st.columns(3)

         with clip_col:
             if not st.session_state.clip_model_loaded:
@@ -641,21 +652,6 @@ def main():
             else:
                 st.success("✅ CLIP model loaded and ready!")

-        with llm_col:
-            if not st.session_state.llm_model_loaded:
-                if st.button("📥 Load Vision LLM for Analysis", type="primary"):
-                    # Load LLM model
-                    model, tokenizer = load_llm_model()
-                    if model is not None and tokenizer is not None:
-                        st.session_state.llm_model = model
-                        st.session_state.tokenizer = tokenizer
-                        st.session_state.llm_model_loaded = True
-                        st.success("✅ Vision LLM loaded successfully!")
-                    else:
-                        st.error("❌ Failed to load Vision LLM.")
-            else:
-                st.success("✅ Vision LLM loaded and ready!")
-
         with blip_col:
             if not st.session_state.blip_model_loaded:
                 if st.button("📥 Load BLIP for Captioning", type="primary"):
@@ -670,6 +666,21 @@ def main():
                         st.error("❌ Failed to load BLIP model.")
             else:
                 st.success("✅ BLIP captioning model loaded and ready!")

     # Image upload section
     with st.expander("Stage 2: Image Upload & Initial Detection", expanded=True):
@@ -692,12 +703,11 @@ def main():
                     caption = generate_image_caption(
                         image,
                         st.session_state.blip_processor,
-                        st.session_state.blip_model,
-                        is_gradcam=False
                     )
                     st.session_state.image_caption = caption

-                    # Store caption but don't display it here - it will be shown in the summary section

             # Detect with CLIP model if loaded
             if st.session_state.clip_model_loaded:
@@ -732,11 +742,8 @@ def main():

                 # Display results
                 with col2:
-                    result_col1, result_col2 = st.columns(2)
-                    with result_col1:
-                        st.metric("Prediction", pred_label)
-                    with result_col2:
-                        st.metric("Confidence", f"{confidence:.2%}")

                 # GradCAM visualization
                 st.subheader("GradCAM Visualization")
@@ -750,16 +757,14 @@ def main():
                 # Generate caption for GradCAM overlay image if BLIP model is loaded
                 if st.session_state.blip_model_loaded:
                     with st.spinner("Analyzing GradCAM visualization..."):
-                        gradcam_caption = generate_image_caption(
                             overlay,
                             st.session_state.blip_processor,
-                            st.session_state.blip_model,
-                            is_gradcam=True,
-                            max_length=150 # Longer for detailed analysis
                         )
                         st.session_state.gradcam_caption = gradcam_caption

-                        # Store caption but don't display it here - it will be shown in the summary section

                 # Save results in session state for LLM analysis
                 st.session_state.current_image = image
@@ -776,11 +781,49 @@ def main():
                 import traceback
                 st.error(traceback.format_exc()) # This will show the full error traceback

-    # LLM Analysis section
     with st.expander("Stage 3: Detailed Analysis with Vision LLM", expanded=False):
         if hasattr(st.session_state, 'current_image') and st.session_state.llm_model_loaded:
             st.subheader("Detailed Deepfake Analysis")

             # Include both captions in the prompt if available
             caption_text = ""
             if hasattr(st.session_state, 'image_caption'):
@@ -790,19 +833,37 @@ def main():
                 caption_text += f"\n\nGradCAM Analysis:\n{st.session_state.gradcam_caption}"

             # Default question with option to customize
-            default_question = f"This image has been classified as {st.session_state.current_pred_label}.{caption_text} Analyze the key features that led to this classification, focusing on the highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
-            question = st.text_area("Question/Prompt:", value=default_question, height=100)

-            # Analyze button
-            if st.button("🔍 Perform Detailed Analysis", type="primary"):
                 try:
                     result = analyze_image_with_llm(
                         st.session_state.current_image,
                         st.session_state.current_overlay,
                         st.session_state.current_face_box,
                         st.session_state.current_pred_label,
                         st.session_state.current_confidence,
-                        question,
                         st.session_state.llm_model,
                         st.session_state.tokenizer,
                         temperature=temperature,
@@ -810,7 +871,10 @@ def main():
                         custom_instruction=custom_instruction
                     )

-                    # Display results
                     st.success("✅ Analysis complete!")

                     # Check if the result contains both technical and non-technical explanations
@@ -822,12 +886,12 @@ def main():
                             non_technical = "Non-Technical" + parts[1]

                             # Display in two columns
-                            col1, col2 = st.columns(2)
-                            with col1:
                                 st.subheader("Technical Analysis")
                                 st.markdown(technical)

-                            with col2:
                                 st.subheader("Simple Explanation")
                                 st.markdown(non_technical)
                     except Exception as e:
@@ -838,6 +902,10 @@ def main():
                         # Just display the whole result
                         st.subheader("Analysis Result")
                         st.markdown(result)
                 except Exception as e:
                     st.error(f"Error during LLM analysis: {str(e)}")

@@ -846,36 +914,6 @@ def main():
         else:
             st.warning("⚠️ Please load the Vision LLM to perform detailed analysis.")

-    # Summary section with caption
-    if hasattr(st.session_state, 'current_image') and (hasattr(st.session_state, 'image_caption') or hasattr(st.session_state, 'gradcam_caption')):
-        with st.expander("Image Analysis Summary", expanded=True):
-            st.subheader("Generated Descriptions and Analysis")
-
-            # Display image, captions, and results in organized layout
-            col1, col2 = st.columns([1, 2])
-
-            with col1:
-                # Display original image and overlay side by side with controlled size
-                st.image(st.session_state.current_image, caption="Original Image", width=300)
-                if hasattr(st.session_state, 'current_overlay'):
-                    st.image(st.session_state.current_overlay, caption="GradCAM Overlay", width=300)
-
-            with col2:
-                # Detection result
-                if hasattr(st.session_state, 'current_pred_label'):
-                    st.markdown(f"### Detection Result:")
-                    st.markdown(f"Classification: **{st.session_state.current_pred_label}** (Confidence: {st.session_state.current_confidence:.2%})")
-
-                # Image description
-                if hasattr(st.session_state, 'image_caption'):
-                    st.markdown("### Image Description:")
-                    st.markdown(st.session_state.image_caption)
-
-                # GradCAM analysis
-                if hasattr(st.session_state, 'gradcam_caption'):
-                    st.markdown("### GradCAM Analysis:")
-                    st.markdown(st.session_state.gradcam_caption)
-
     # Footer
     st.markdown("---")
     st.caption("Advanced Deepfake Image Analyzer with Structured BLIP Captioning")
 
         st.error(f"Error loading BLIP model: {str(e)}")
         return None, None

+# Function to generate image caption using BLIP's VQA approach for GradCAM
+def generate_gradcam_caption(image, processor, model, max_length=60):
     """
+    Generate a detailed analysis of GradCAM visualization using multiple questions
     """
     try:
         # Check for available GPU
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model = model.to(device)

+        # Multiple specific questions about the GradCAM visualization
+        questions = [
+            "What facial features are highlighted by the red and yellow areas in this heatmap?",
+            "What does this facial heat map visualization show?",
+            "What patterns do you see in this facial heatmap visualization?"
+        ]
+
+        # Get answers to each question
+        answers = []
+        for question in questions:
+            inputs = processor(image, text=question, return_tensors="pt").to(device)
+            with torch.no_grad():
+                output = model.generate(**inputs, max_length=max_length, num_beams=5)
+            answer = processor.decode(output[0], skip_special_tokens=True)
+            answers.append(answer)
+
+        # Format answers into a structured analysis
+        structured_output = f"""
+**Main Focus Area**: The heatmap is primarily focused on the facial region of the person.
+
+**High Activation Regions**: The red/yellow areas highlight {answers[0]}
+
+**Medium Activation Regions**: The green/cyan areas correspond to regions of medium importance in the detection process, typically including parts of the face and surrounding areas.
+
+**Low Activation Regions**: The blue/dark blue areas represent features that have less impact on the model's decision, usually the background and peripheral elements.
+
+**Activation Pattern**: {answers[2]}
+    """
+        return structured_output.strip()
+
+    except Exception as e:
+        st.error(f"Error analyzing GradCAM: {str(e)}")
+        return "Error analyzing GradCAM visualization"
+
+# Function to generate caption for original image
+def generate_image_caption(image, processor, model, max_length=75, num_beams=5):
+    """Generate a caption for the original image using BLIP model"""
+    try:
+        # Check for available GPU
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model = model.to(device)
+
+        # For original image, use unconditional captioning
+        inputs = processor(image, return_tensors="pt").to(device)

         # Generate caption
         with torch.no_grad():

         # Decode the output
         caption = processor.decode(output[0], skip_special_tokens=True)

+        # Format into structured description
+        structured_caption = f"""

 **Subject**: The image shows a person in a photograph.

 **Appearance**: {caption}

 **Notable Elements**: The facial features and expression are the central focus of the image.
     """
+        return structured_caption.strip()
+
+    except Exception as e:
+        st.error(f"Error generating caption: {str(e)}")
+        return "Error generating caption"
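The new generate_gradcam_caption() replaces the single conditional prompt with a small set of targeted questions about the heatmap and slots the decoded answers into a fixed template. How processor(image, text=question, ...) is interpreted depends on which BLIP head load_blip_model() returns, which this hunk does not show: a captioning model treats the text as a caption prefix, while a VQA model answers it as a question. A rough, self-contained sketch of the question loop with a VQA checkpoint (an assumption, not confirmed by the diff) looks like this:

# Hedged sketch of the question-answer loop with a BLIP VQA checkpoint.
# Assumption: load_blip_model() returns a VQA-style head; the checkpoint name
# and the overlay file path below are illustrative only.
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForQuestionAnswering

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

overlay = Image.open("gradcam_overlay.png").convert("RGB")  # hypothetical overlay image

questions = [
    "What facial features are highlighted by the red and yellow areas in this heatmap?",
    "What patterns do you see in this facial heatmap visualization?",
]

answers = []
for question in questions:
    inputs = processor(overlay, text=question, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(**inputs, max_length=60, num_beams=5)
    answers.append(processor.decode(output[0], skip_special_tokens=True))

print(answers)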

 # ----- Fine-tuned Vision LLM -----

         new_mask = torch.ones((batch_size, seq_len, visual_features, num_tiles),
                               device=inputs['cross_attention_mask'].device)
         inputs['cross_attention_mask'] = new_mask
     return inputs

 # Load model function

 # Main app
 def main():
+    # Initialize session state variables
     if 'clip_model_loaded' not in st.session_state:
         st.session_state.clip_model_loaded = False
         st.session_state.clip_model = None

         st.session_state.blip_processor = None
         st.session_state.blip_model = None

+    # Initialize chat history
+    if 'chat_history' not in st.session_state:
+        st.session_state.chat_history = []
+
     # Create expanders for each stage
     with st.expander("Stage 1: Model Loading", expanded=True):
         st.write("Please load the models using the buttons below:")

         # Button for loading models
+        clip_col, blip_col, llm_col = st.columns(3)

         with clip_col:
             if not st.session_state.clip_model_loaded:

             else:
                 st.success("✅ CLIP model loaded and ready!")

         with blip_col:
             if not st.session_state.blip_model_loaded:
                 if st.button("📥 Load BLIP for Captioning", type="primary"):

                         st.error("❌ Failed to load BLIP model.")
             else:
                 st.success("✅ BLIP captioning model loaded and ready!")
+
+        with llm_col:
+            if not st.session_state.llm_model_loaded:
+                if st.button("📥 Load Vision LLM for Analysis", type="primary"):
+                    # Load LLM model
+                    model, tokenizer = load_llm_model()
+                    if model is not None and tokenizer is not None:
+                        st.session_state.llm_model = model
+                        st.session_state.tokenizer = tokenizer
+                        st.session_state.llm_model_loaded = True
+                        st.success("✅ Vision LLM loaded successfully!")
+                    else:
+                        st.error("❌ Failed to load Vision LLM.")
+            else:
+                st.success("✅ Vision LLM loaded and ready!")

     # Image upload section
     with st.expander("Stage 2: Image Upload & Initial Detection", expanded=True):

                     caption = generate_image_caption(
                         image,
                         st.session_state.blip_processor,
+                        st.session_state.blip_model
                     )
                     st.session_state.image_caption = caption

+                    # Store caption but don't display it yet

             # Detect with CLIP model if loaded
             if st.session_state.clip_model_loaded:

                 # Display results
                 with col2:
+                    st.markdown("### Detection Result")
+                    st.markdown(f"**Classification:** {pred_label} (Confidence: {confidence:.2%})")

                 # GradCAM visualization
                 st.subheader("GradCAM Visualization")

                 # Generate caption for GradCAM overlay image if BLIP model is loaded
                 if st.session_state.blip_model_loaded:
                     with st.spinner("Analyzing GradCAM visualization..."):
+                        gradcam_caption = generate_gradcam_caption(
                             overlay,
                             st.session_state.blip_processor,
+                            st.session_state.blip_model
                         )
                         st.session_state.gradcam_caption = gradcam_caption

+                        # Store caption but don't display it yet

                 # Save results in session state for LLM analysis
                 st.session_state.current_image = image

                 import traceback
                 st.error(traceback.format_exc()) # This will show the full error traceback

+    # Image Analysis Summary section - AFTER Stage 2
+    if hasattr(st.session_state, 'current_image') and (hasattr(st.session_state, 'image_caption') or hasattr(st.session_state, 'gradcam_caption')):
+        with st.expander("Image Analysis Summary", expanded=True):
+            st.subheader("Generated Descriptions and Analysis")
+
+            # Display image, captions, and results in organized layout with proper formatting
+            col1, col2 = st.columns([1, 2])
+
+            with col1:
+                # Display original image and overlay side by side with controlled size
+                st.image(st.session_state.current_image, caption="Original Image", width=300)
+                if hasattr(st.session_state, 'current_overlay'):
+                    st.image(st.session_state.current_overlay, caption="GradCAM Overlay", width=300)
+
+            with col2:
+                # Detection result
+                if hasattr(st.session_state, 'current_pred_label'):
+                    st.markdown("### Detection Result")
+                    st.markdown(f"**Classification:** {st.session_state.current_pred_label} (Confidence: {st.session_state.current_confidence:.2%})")
+                    st.markdown("---")
+
+                # Image description
+                if hasattr(st.session_state, 'image_caption'):
+                    st.markdown("### Image Description")
+                    st.markdown(st.session_state.image_caption)
+                    st.markdown("---")
+
+                # GradCAM analysis
+                if hasattr(st.session_state, 'gradcam_caption'):
+                    st.markdown("### GradCAM Analysis")
+                    st.markdown(st.session_state.gradcam_caption)
+
+    # LLM Analysis section - AFTER Image Analysis Summary
     with st.expander("Stage 3: Detailed Analysis with Vision LLM", expanded=False):
         if hasattr(st.session_state, 'current_image') and st.session_state.llm_model_loaded:
             st.subheader("Detailed Deepfake Analysis")

+            # Display chat history
+            for i, (question, answer) in enumerate(st.session_state.chat_history):
+                st.markdown(f"**Question {i+1}:** {question}")
+                st.markdown(f"**Answer:** {answer}")
+                st.markdown("---")
+
             # Include both captions in the prompt if available
             caption_text = ""
             if hasattr(st.session_state, 'image_caption'):

                 caption_text += f"\n\nGradCAM Analysis:\n{st.session_state.gradcam_caption}"

             # Default question with option to customize
+            default_question = f"This image has been classified as {st.session_state.current_pred_label}. Analyze the key features that led to this classification, focusing on the highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."

+            # User input for new question
+            new_question = st.text_area("Ask a question about the image:", value=default_question if not st.session_state.chat_history else "", height=100)
+
+            # Analyze button and Clear Chat button in the same row
+            col1, col2 = st.columns([3, 1])
+            with col1:
+                analyze_button = st.button("🔍 Send Question", type="primary")
+            with col2:
+                clear_button = st.button("🗑️ Clear Chat History")
+
+            if clear_button:
+                st.session_state.chat_history = []
+                st.experimental_rerun()
+
+            if analyze_button and new_question:
                 try:
+                    # Add caption info if it's the first question
+                    if not st.session_state.chat_history:
+                        full_question = new_question + caption_text
+                    else:
+                        full_question = new_question
+
                     result = analyze_image_with_llm(
                         st.session_state.current_image,
                         st.session_state.current_overlay,
                         st.session_state.current_face_box,
                         st.session_state.current_pred_label,
                         st.session_state.current_confidence,
+                        full_question,
                         st.session_state.llm_model,
                         st.session_state.tokenizer,
                         temperature=temperature,

                         custom_instruction=custom_instruction
                     )

+                    # Add to chat history
+                    st.session_state.chat_history.append((new_question, result))
+
+                    # Display the latest result too
                     st.success("✅ Analysis complete!")

                     # Check if the result contains both technical and non-technical explanations

                             non_technical = "Non-Technical" + parts[1]

                             # Display in two columns
+                            tech_col, simple_col = st.columns(2)
+                            with tech_col:
                                 st.subheader("Technical Analysis")
                                 st.markdown(technical)

+                            with simple_col:
                                 st.subheader("Simple Explanation")
                                 st.markdown(non_technical)
                     except Exception as e:

                         # Just display the whole result
                         st.subheader("Analysis Result")
                         st.markdown(result)
+
+                        # Rerun to update the chat history display
+                        st.experimental_rerun()
+
                 except Exception as e:
                     st.error(f"Error during LLM analysis: {str(e)}")

         else:
             st.warning("⚠️ Please load the Vision LLM to perform detailed analysis.")
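The chat flow above appends each (question, result) pair to st.session_state.chat_history and then calls st.experimental_rerun() so the refreshed history renders at the top of the expander on the next run. Newer Streamlit releases replace st.experimental_rerun() with st.rerun(), so a small wrapper (purely illustrative, not part of this commit) can keep the same pattern working across versions:

# Hypothetical compatibility helper, not part of the committed app: recent
# Streamlit versions expose st.rerun() and drop st.experimental_rerun().
import streamlit as st

def safe_rerun():
    # Prefer the stable API when available, fall back to the experimental one.
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()

# Usage mirroring the chat flow above:
# st.session_state.chat_history.append((new_question, result))
# safe_rerun()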

     # Footer
     st.markdown("---")
     st.caption("Advanced Deepfake Image Analyzer with Structured BLIP Captioning")