saakshigupta committed
Commit f0a1db6 · verified · 1 Parent(s): 3bb619d

Update app.py

Files changed (1)
  1. app.py +110 -6
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
  import torch.nn as nn
  from torch.utils.data import DataLoader
  from torchvision import transforms
- from transformers import CLIPModel
+ from transformers import CLIPModel, BlipProcessor, BlipForConditionalGeneration
  from transformers.models.clip import CLIPModel
  from PIL import Image
  import numpy as np
@@ -76,7 +76,8 @@ st.sidebar.markdown("""
  This analyzer performs multi-stage detection:
  1. **Initial Detection**: CLIP-based classifier
  2. **GradCAM Visualization**: Highlights suspicious regions
- 3. **LLM Analysis**: Fine-tuned Llama 3.2 Vision provides detailed explanations
+ 3. **Image Captioning**: BLIP model describes the image content
+ 4. **LLM Analysis**: Fine-tuned Llama 3.2 Vision provides detailed explanations
  
  The system looks for:
  - Facial inconsistencies
@@ -417,6 +418,55 @@ def process_image_with_gradcam(image, model, device, pred_class):
      comparison = save_comparison(original_image, default_cam, overlay, face_box)
      return default_cam, overlay, comparison, face_box
  
+ # ----- BLIP Image Captioning -----
+
+ # Function to load BLIP captioning model
+ @st.cache_resource
+ def load_blip_model():
+     with st.spinner("Loading BLIP captioning model..."):
+         try:
+             processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+             model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+             return processor, model
+         except Exception as e:
+             st.error(f"Error loading BLIP model: {str(e)}")
+             return None, None
+
+ # Function to generate image caption
+ def generate_image_caption(image, processor, model, max_length=50, num_beams=5):
+     """
+     Generate a caption for the input image using BLIP model
+
+     Args:
+         image (PIL.Image): Input image
+         processor: BLIP processor
+         model: BLIP model
+         max_length (int): Maximum length of the caption
+         num_beams (int): Number of beams for beam search
+
+     Returns:
+         str: Generated caption
+     """
+     try:
+         # Preprocess the image
+         inputs = processor(image, return_tensors="pt")
+
+         # Check for available GPU
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model = model.to(device)
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Generate caption
+         with torch.no_grad():
+             output = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
+
+         # Decode the caption
+         caption = processor.decode(output[0], skip_special_tokens=True)
+         return caption
+     except Exception as e:
+         st.error(f"Error generating caption: {str(e)}")
+         return "Error generating caption"
+
  # ----- Fine-tuned Vision LLM -----
  
  # Function to fix cross-attention masks
@@ -522,10 +572,15 @@ def main():
          st.session_state.llm_model = None
          st.session_state.tokenizer = None
  
+     if 'blip_model_loaded' not in st.session_state:
+         st.session_state.blip_model_loaded = False
+         st.session_state.blip_processor = None
+         st.session_state.blip_model = None
+
      # Create expanders for each stage
      with st.expander("Stage 1: Model Loading", expanded=True):
-         # Button for loading CLIP model
-         clip_col, llm_col = st.columns(2)
+         # Button for loading models
+         clip_col, llm_col, blip_col = st.columns(3)
  
          with clip_col:
              if not st.session_state.clip_model_loaded:
@@ -555,6 +610,21 @@ def main():
                          st.error("❌ Failed to load Vision LLM.")
              else:
                  st.success("✅ Vision LLM loaded and ready!")
+
+         with blip_col:
+             if not st.session_state.blip_model_loaded:
+                 if st.button("📥 Load BLIP for Captioning", type="primary"):
+                     # Load BLIP model
+                     processor, model = load_blip_model()
+                     if model is not None and processor is not None:
+                         st.session_state.blip_processor = processor
+                         st.session_state.blip_model = model
+                         st.session_state.blip_model_loaded = True
+                         st.success("✅ BLIP captioning model loaded successfully!")
+                     else:
+                         st.error("❌ Failed to load BLIP model.")
+             else:
+                 st.success("✅ BLIP captioning model loaded and ready!")
  
      # Image upload section
      with st.expander("Stage 2: Image Upload & Initial Detection", expanded=True):
@@ -566,6 +636,17 @@ def main():
              image = Image.open(uploaded_file).convert("RGB")
              st.image(image, caption="Uploaded Image", use_column_width=True)
  
+             # Generate image caption if BLIP model is loaded
+             if st.session_state.blip_model_loaded:
+                 with st.spinner("Generating image caption..."):
+                     caption = generate_image_caption(
+                         image,
+                         st.session_state.blip_processor,
+                         st.session_state.blip_model
+                     )
+                     st.session_state.image_caption = caption
+                     st.success(f"📝 Image Caption: **{caption}**")
+
              # Detect with CLIP model if loaded
              if st.session_state.clip_model_loaded:
                  with st.spinner("Analyzing image with CLIP model..."):
@@ -629,8 +710,13 @@ def main():
          if hasattr(st.session_state, 'current_image') and st.session_state.llm_model_loaded:
              st.subheader("Detailed Deepfake Analysis")
  
+             # Include caption in the prompt if available
+             caption_text = ""
+             if hasattr(st.session_state, 'image_caption'):
+                 caption_text = f"\n\nImage caption: {st.session_state.image_caption}"
+
              # Default question with option to customize
-             default_question = f"This image has been classified as {st.session_state.current_pred_label}. Analyze the key features that led to this classification, focusing on the highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
+             default_question = f"This image has been classified as {st.session_state.current_pred_label}.{caption_text} Analyze the key features that led to this classification, focusing on the highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
              question = st.text_area("Question/Prompt:", value=default_question, height=100)
  
              # Analyze button
@@ -676,10 +762,28 @@ def main():
              st.warning("⚠️ Please upload an image and complete the initial detection first.")
          else:
              st.warning("⚠️ Please load the Vision LLM to perform detailed analysis.")
+
+     # Summary section with caption
+     if hasattr(st.session_state, 'current_image') and hasattr(st.session_state, 'image_caption'):
+         with st.expander("Image Caption Summary", expanded=True):
+             st.subheader("Generated Image Description")
+
+             # Display image and caption
+             col1, col2 = st.columns([1, 2])
+             with col1:
+                 st.image(st.session_state.current_image, use_column_width=True)
+             with col2:
+                 st.markdown("### BLIP Caption:")
+                 st.markdown(f"**{st.session_state.image_caption}**")
+
+             # Display detection result if available
+             if hasattr(st.session_state, 'current_pred_label'):
+                 st.markdown("### Detection Result:")
+                 st.markdown(f"Classification: **{st.session_state.current_pred_label}** (Confidence: {st.session_state.current_confidence:.2%})")
  
      # Footer
      st.markdown("---")
-     st.caption("Advanced Deepfake Image Analyzer")
+     st.caption("Advanced Deepfake Image Analyzer with BLIP Captioning")
  
  if __name__ == "__main__":
      main()
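
For reference, the captioning path this commit wires into the app can be exercised on its own with plain transformers code. The sketch below mirrors what load_blip_model and generate_image_caption do (same Salesforce/blip-image-captioning-large checkpoint, beam search with num_beams=5 and max_length=50) without the Streamlit session-state plumbing; "photo.jpg" is a placeholder path, not something taken from the commit.

    # Minimal standalone sketch of the BLIP captioning flow added above.
    # Assumes transformers and Pillow are installed; "photo.jpg" is a placeholder.
    import torch
    from PIL import Image
    from transformers import BlipProcessor, BlipForConditionalGeneration

    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

    # Move the model to GPU when available, as generate_image_caption does
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    image = Image.open("photo.jpg").convert("RGB")
    inputs = {k: v.to(device) for k, v in processor(image, return_tensors="pt").items()}

    # Beam-search decoding with the same settings as the app's defaults
    with torch.no_grad():
        output = model.generate(**inputs, max_length=50, num_beams=5)

    print(processor.decode(output[0], skip_special_tokens=True))

Running this once is also a quick way to confirm the checkpoint downloads and produces a caption before testing the new "Load BLIP for Captioning" button in the app.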