saakshigupta committed · Commit 7aef21e · verified · 1 Parent(s): 5b8fcd0

Update app.py

Files changed (1): app.py (+380 −159)
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
  import torch.nn as nn
  from torch.utils.data import DataLoader
  from torchvision import transforms
- from transformers import CLIPModel
+ from transformers import CLIPModel, BlipProcessor, BlipForConditionalGeneration
  from transformers.models.clip import CLIPModel
  from PIL import Image
  import numpy as np
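
Note that both import paths kept by this hunk resolve to the same class, so the second line only re-binds the name. A quick check (not part of the commit) illustrates this:

```python
# Sanity check: the two import paths expose the identical CLIPModel class.
from transformers import CLIPModel as clip_a
from transformers.models.clip import CLIPModel as clip_b

assert clip_a is clip_b  # the later import in app.py simply re-binds the same object
```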
@@ -26,8 +26,8 @@ st.set_page_config(
  )

  # Main title and description
- st.title("Advanced Deepfake Image Analyzer")
- st.markdown("Analyze images for deepfake manipulation with multi-stage analysis")
+ st.title("Deepfake Image Analyser")
+ st.markdown("Analyse images for deepfake manipulation")

  # Check for GPU availability
  def check_gpu():
@@ -40,52 +40,37 @@ def check_gpu():
  return False

  # Sidebar components
- st.sidebar.title("Options")
-
- # Temperature slider
- temperature = st.sidebar.slider(
- "Temperature",
- min_value=0.1,
- max_value=1.0,
- value=0.7,
- step=0.1,
- help="Higher values make output more random, lower values more deterministic"
- )
+ st.sidebar.title("About")
+ st.sidebar.markdown("""
+ This tool detects deepfakes using four AI models:
+ - **CLIP**: Initial Real/Fake classification
+ - **GradCAM**: Highlights suspicious regions
+ - **BLIP**: Describes image content
+ - **Llama 3.2**: Explains potential manipulations
+
+ ### Quick Start
+ 1. **Load Models** - Start with CLIP, add others as needed
+ 2. **Upload Image** - View classification and heat map
+ 3. **Analyze** - Get explanations and ask questions
+
+ *GPU recommended for better performance*
+ """)

- # Max response length slider
- max_tokens = st.sidebar.slider(
- "Maximum Response Length",
- min_value=100,
- max_value=1000,
- value=500,
- step=50,
- help="The maximum number of tokens in the response"
- )
+ # Fixed values for temperature and max tokens
+ temperature = 0.7
+ max_tokens = 500

  # Custom instruction text area in sidebar
- custom_instruction = st.sidebar.text_area(
- "Custom Instructions (Advanced)",
- value="Focus on analyzing the highlighted regions from the GradCAM visualization. Examine facial inconsistencies, lighting irregularities, and other artifacts visible in the heat map.",
- help="Add specific instructions for the LLM analysis"
- )
+ use_custom_instructions = st.sidebar.toggle("Enable Custom Instructions", value=False, help="Toggle to enable/disable custom instructions")

- # About section in sidebar
- st.sidebar.markdown("---")
- st.sidebar.subheader("About")
- st.sidebar.markdown("""
- This analyzer performs multi-stage detection:
- 1. **Initial Detection**: CLIP-based classifier
- 2. **GradCAM Visualization**: Highlights suspicious regions
- 3. **LLM Analysis**: Fine-tuned Llama 3.2 Vision provides detailed explanations
-
- The system looks for:
- - Facial inconsistencies
- - Unnatural movements
- - Lighting issues
- - Texture anomalies
- - Edge artifacts
- - Blending problems
- """)
+ if use_custom_instructions:
+ custom_instruction = st.sidebar.text_area(
+ "Custom Instructions (Advanced)",
+ value="Specify your preferred style of explanation (e.g., 'Provide technical, detailed explanations' or 'Use simple, non-technical language'). You can also specify what aspects of the image to focus on.",
+ help="Add specific instructions for the analysis"
+ )
+ else:
+ custom_instruction = ""

  # ----- GradCAM Implementation -----

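The sidebar text added here describes the four-stage pipeline (CLIP → GradCAM → BLIP → Llama 3.2). A rough sketch of how those stages chain together, using helper functions defined elsewhere in app.py (the Llama loader itself is not shown in this diff, so the model and tokenizer are taken as parameters, and the file path, label and confidence values are placeholders):

```python
# Minimal pipeline sketch; assumes the app.py helpers named below are in scope.
import torch
from PIL import Image


def run_pipeline(image_path: str, llm_model, tokenizer, question: str):
    image = Image.open(image_path).convert("RGB")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stage 1: CLIP-based Real/Fake classifier (load_clip_model is defined in the next hunk)
    clip_model = load_clip_model().to(device)

    # Stage 2: GradCAM heat map for the predicted class (1 == "Fake" in this app)
    pred_class = 1  # placeholder; the app derives this from the CLIP forward pass
    cam, overlay, comparison, face_box = process_image_with_gradcam(image, clip_model, device, pred_class)

    # Stage 3: BLIP captions for the original image and the GradCAM overlay
    orig_proc, orig_model, ft_proc, ft_model = load_blip_models()
    image_caption = generate_image_caption(image, orig_proc, orig_model)
    gradcam_caption = generate_gradcam_caption(overlay, ft_proc, ft_model)

    # Stage 4: the fine-tuned Llama 3.2 Vision model explains the detection
    return analyze_image_with_llm(
        image, overlay, face_box, "Fake", 0.87,  # placeholder label/confidence
        question + image_caption + gradcam_caption,
        llm_model, tokenizer,
        temperature=0.7, max_tokens=500, custom_instruction="",
    )
```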
@@ -341,15 +326,19 @@ def save_comparison(image, cam, overlay, face_box=None):
  @st.cache_resource
  def load_clip_model():
  with st.spinner("Loading CLIP model for GradCAM..."):
- model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-
- # Apply a simple classification head
- model.classification_head = nn.Linear(1024, 2)
- model.classification_head.weight.data.normal_(mean=0.0, std=0.02)
- model.classification_head.bias.data.zero_()
-
- model.eval()
- return model
+ try:
+ model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ # Apply a simple classification head
+ model.classification_head = nn.Linear(1024, 2)
+ model.classification_head.weight.data.normal_(mean=0.0, std=0.02)
+ model.classification_head.bias.data.zero_()
+
+ model.eval()
+ return model
+ except Exception as e:
+ st.error(f"Error loading CLIP model: {str(e)}")
+ return None

  def get_target_layer_clip(model):
  """Get the target layer for GradCAM"""
@@ -417,6 +406,105 @@ def process_image_with_gradcam(image, model, device, pred_class):
  comparison = save_comparison(original_image, default_cam, overlay, face_box)
  return default_cam, overlay, comparison, face_box

+ # ----- BLIP Image Captioning -----
+
+ # Function to load BLIP captioning models
+ @st.cache_resource
+ def load_blip_models():
+ with st.spinner("Loading BLIP captioning models..."):
+ try:
+ # Load original BLIP model for general image captioning
+ original_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+ original_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+
+ # Load fine-tuned BLIP model for GradCAM analysis
+ finetuned_processor = BlipProcessor.from_pretrained("saakshigupta/deepfake-blip-large")
+ finetuned_model = BlipForConditionalGeneration.from_pretrained("saakshigupta/deepfake-blip-large")
+
+ return original_processor, original_model, finetuned_processor, finetuned_model
+ except Exception as e:
+ st.error(f"Error loading BLIP models: {str(e)}")
+ return None, None, None, None
+
+ # Function to generate image caption using BLIP's VQA approach for GradCAM
+ def generate_gradcam_caption(image, processor, model, max_length=60):
+ """
+ Generate a detailed analysis of GradCAM visualization using the fine-tuned BLIP model
+ """
+ try:
+ # Process image first
+ inputs = processor(image, return_tensors="pt")
+
+ # Check for available GPU and move model and inputs
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = model.to(device)
+ inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
+
+ # Generate caption
+ with torch.no_grad():
+ output = model.generate(**inputs, max_length=max_length, num_beams=5)
+
+ # Decode the output
+ caption = processor.decode(output[0], skip_special_tokens=True)
+
+ # Extract descriptions using the full text
+ high_match = caption.split("high activation :")[1].split("moderate")[0] if "high activation :" in caption else ""
+ moderate_match = caption.split("moderate activation :")[1].split("low")[0] if "moderate activation :" in caption else ""
+ low_match = caption.split("low activation :")[1] if "low activation :" in caption else ""
+
+ # Format the output
+ formatted_text = ""
+ if high_match:
+ formatted_text += f"**High activation**:\n{high_match.strip()}\n\n"
+ if moderate_match:
+ formatted_text += f"**Moderate activation**:\n{moderate_match.strip()}\n\n"
+ if low_match:
+ formatted_text += f"**Low activation**:\n{low_match.strip()}"
+
+ return formatted_text.strip()
+
+ except Exception as e:
+ st.error(f"Error analyzing GradCAM: {str(e)}")
+ return "Error analyzing GradCAM visualization"
+
+ # Function to generate caption for original image
+ def generate_image_caption(image, processor, model, max_length=75, num_beams=5):
+ """Generate a caption for the original image using the original BLIP model"""
+ try:
+ # Check for available GPU
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = model.to(device)
+
+ # For original image, use unconditional captioning
+ inputs = processor(image, return_tensors="pt").to(device)
+
+ # Generate caption
+ with torch.no_grad():
+ output = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
+
+ # Decode the output
+ caption = processor.decode(output[0], skip_special_tokens=True)
+
+ # Format into structured description
+ structured_caption = f"""
+ **Subject**: The image shows a person in a photograph.
+
+ **Appearance**: {caption}
+
+ **Background**: The background appears to be a controlled environment.
+
+ **Lighting**: The lighting appears to be professional with even illumination.
+
+ **Colors**: The image contains natural skin tones and colors typical of photography.
+
+ **Notable Elements**: The facial features and expression are the central focus of the image.
+ """
+ return structured_caption.strip()
+
+ except Exception as e:
+ st.error(f"Error generating caption: {str(e)}")
+ return "Error generating caption"
+
  # ----- Fine-tuned Vision LLM -----

  # Function to fix cross-attention masks
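A hedged usage sketch of the two captioning helpers added above, outside the Streamlit flow. The image path is a stand-in, and the raw image is reused in place of a real GradCAM overlay purely to show the call signature; in the app the overlay comes from process_image_with_gradcam:

```python
# Standalone use of the BLIP helpers defined above; assumes they are in scope.
from PIL import Image

orig_proc, orig_model, ft_proc, ft_model = load_blip_models()

image = Image.open("sample_face.jpg").convert("RGB")       # hypothetical input
print(generate_image_caption(image, orig_proc, orig_model))

# In app.py the second argument set is the GradCAM heat-map composite; reusing the
# raw image here only demonstrates how the fine-tuned captioner is invoked.
print(generate_gradcam_caption(image, ft_proc, ft_model))
```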
@@ -427,7 +515,6 @@ def fix_cross_attention_mask(inputs):
  new_mask = torch.ones((batch_size, seq_len, visual_features, num_tiles),
  device=inputs['cross_attention_mask'].device)
  inputs['cross_attention_mask'] = new_mask
- st.success("Fixed cross-attention mask dimensions")
  return inputs

  # Load model function
@@ -512,7 +599,7 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide

  # Main app
  def main():
- # Create placeholders for model state
+ # Initialize session state variables
  if 'clip_model_loaded' not in st.session_state:
  st.session_state.clip_model_loaded = False
  st.session_state.clip_model = None
@@ -522,10 +609,23 @@ def main():
  st.session_state.llm_model = None
  st.session_state.tokenizer = None

+ if 'blip_model_loaded' not in st.session_state:
+ st.session_state.blip_model_loaded = False
+ st.session_state.original_processor = None
+ st.session_state.original_model = None
+ st.session_state.finetuned_processor = None
+ st.session_state.finetuned_model = None
+
+ # Initialize chat history
+ if 'chat_history' not in st.session_state:
+ st.session_state.chat_history = []
+
  # Create expanders for each stage
  with st.expander("Stage 1: Model Loading", expanded=True):
- # Button for loading CLIP model
- clip_col, llm_col = st.columns(2)
+ st.write("Please load the models using the buttons below:")
+
+ # Button for loading models
+ clip_col, blip_col, llm_col = st.columns(3)

  with clip_col:
  if not st.session_state.clip_model_loaded:
@@ -541,6 +641,23 @@ def main():
  else:
  st.success("✅ CLIP model loaded and ready!")

+ with blip_col:
+ if not st.session_state.blip_model_loaded:
+ if st.button("📥 Load BLIP for Captioning", type="primary"):
+ # Load BLIP models
+ original_processor, original_model, finetuned_processor, finetuned_model = load_blip_models()
+ if all([original_processor, original_model, finetuned_processor, finetuned_model]):
+ st.session_state.original_processor = original_processor
+ st.session_state.original_model = original_model
+ st.session_state.finetuned_processor = finetuned_processor
+ st.session_state.finetuned_model = finetuned_model
+ st.session_state.blip_model_loaded = True
+ st.success("✅ BLIP captioning models loaded successfully!")
+ else:
+ st.error("❌ Failed to load BLIP models.")
+ else:
+ st.success("✅ BLIP captioning models loaded and ready!")
+
  with llm_col:
  if not st.session_state.llm_model_loaded:
  if st.button("📥 Load Vision LLM for Analysis", type="primary"):
@@ -562,116 +679,221 @@ def main():
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

  if uploaded_file is not None:
- # Display the uploaded image
- image = Image.open(uploaded_file).convert("RGB")
- st.image(image, caption="Uploaded Image", use_column_width=True)
-
- # Detect with CLIP model if loaded
- if st.session_state.clip_model_loaded:
- with st.spinner("Analyzing image with CLIP model..."):
- # Preprocess image for CLIP
- transform = transforms.Compose([
- transforms.Resize((224, 224)),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
- ])
-
- # Create a simple dataset for the image
- dataset = ImageDataset(image, transform=transform, face_only=True)
- tensor, _, _, _, face_box, _ = dataset[0]
- tensor = tensor.unsqueeze(0)
-
- # Get device
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # Move model and tensor to device
- model = st.session_state.clip_model.to(device)
- tensor = tensor.to(device)
-
- # Forward pass
- with torch.no_grad():
- outputs = model.vision_model(pixel_values=tensor).pooler_output
- logits = model.classification_head(outputs)
- probs = torch.softmax(logits, dim=1)[0]
- pred_class = torch.argmax(probs).item()
- confidence = probs[pred_class].item()
- pred_label = "Fake" if pred_class == 1 else "Real"
-
- # Display results
- result_col1, result_col2 = st.columns(2)
- with result_col1:
- st.metric("Prediction", pred_label)
- with result_col2:
- st.metric("Confidence", f"{confidence:.2%}")
-
- # GradCAM visualization
- st.subheader("GradCAM Visualization")
- cam, overlay, comparison, detected_face_box = process_image_with_gradcam(
- image, model, device, pred_class
- )
-
- # Display GradCAM results
- st.image(comparison, caption="Original | CAM | Overlay", use_column_width=True)
-
- # Save results in session state for LLM analysis
- st.session_state.current_image = image
- st.session_state.current_overlay = overlay
- st.session_state.current_face_box = detected_face_box
- st.session_state.current_pred_label = pred_label
- st.session_state.current_confidence = confidence
-
- st.success("✅ Initial detection and GradCAM visualization complete!")
- else:
- st.warning("⚠️ Please load the CLIP model first to perform initial detection.")
+ try:
+ # Load and display the image (with controlled size)
+ image = Image.open(uploaded_file).convert("RGB")
+
+ # Display the image with a controlled width
+ col1, col2 = st.columns([1, 2])
+ with col1:
+ st.image(image, caption="Uploaded Image", width=300)
+
+ # Generate detailed caption for original image if BLIP model is loaded
+ if st.session_state.blip_model_loaded:
+ with st.spinner("Generating image description..."):
+ caption = generate_image_caption(
+ image,
+ st.session_state.original_processor,
+ st.session_state.original_model
+ )
+ st.session_state.image_caption = caption
+
+ # Store caption but don't display it yet
+
+ # Detect with CLIP model if loaded
+ if st.session_state.clip_model_loaded:
+ with st.spinner("Analyzing image with CLIP model..."):
+ # Preprocess image for CLIP
+ transform = transforms.Compose([
+ transforms.Resize((224, 224)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
+ ])
+
+ # Create a simple dataset for the image
+ dataset = ImageDataset(image, transform=transform, face_only=True)
+ tensor, _, _, _, face_box, _ = dataset[0]
+ tensor = tensor.unsqueeze(0)
+
+ # Get device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Move model and tensor to device
+ model = st.session_state.clip_model.to(device)
+ tensor = tensor.to(device)
+
+ # Forward pass
+ with torch.no_grad():
+ outputs = model.vision_model(pixel_values=tensor).pooler_output
+ logits = model.classification_head(outputs)
+ probs = torch.softmax(logits, dim=1)[0]
+ pred_class = torch.argmax(probs).item()
+ confidence = probs[pred_class].item()
+ pred_label = "Fake" if pred_class == 1 else "Real"
+
+ # Display results
+ with col2:
+ st.markdown("### Detection Result")
+ st.markdown(f"**Classification:** {pred_label} (Confidence: {confidence:.2%})")
+
+ # GradCAM visualization
+ st.subheader("GradCAM Visualization")
+ cam, overlay, comparison, detected_face_box = process_image_with_gradcam(
+ image, model, device, pred_class
+ )
+
+ # Display GradCAM results (controlled size)
+ st.image(comparison, caption="Original | CAM | Overlay", width=700)
+
+ # Generate caption for GradCAM overlay image if BLIP model is loaded
+ if st.session_state.blip_model_loaded:
+ with st.spinner("Analyzing GradCAM visualization..."):
+ gradcam_caption = generate_gradcam_caption(
+ overlay,
+ st.session_state.finetuned_processor,
+ st.session_state.finetuned_model
+ )
+ st.session_state.gradcam_caption = gradcam_caption
+
+ # Store caption but don't display it yet
+
+ # Save results in session state for LLM analysis
+ st.session_state.current_image = image
+ st.session_state.current_overlay = overlay
+ st.session_state.current_face_box = detected_face_box
+ st.session_state.current_pred_label = pred_label
+ st.session_state.current_confidence = confidence
+
+ st.success("✅ Initial detection and GradCAM visualization complete!")
+ else:
+ st.warning("⚠️ Please load the CLIP model first to perform initial detection.")
+ except Exception as e:
+ st.error(f"Error processing image: {str(e)}")
+ import traceback
+ st.error(traceback.format_exc()) # This will show the full error traceback
+
+ # Image Analysis Summary section - AFTER Stage 2
+ if hasattr(st.session_state, 'current_image') and (hasattr(st.session_state, 'image_caption') or hasattr(st.session_state, 'gradcam_caption')):
+ with st.expander("Image Analysis Summary", expanded=True):
+ # Display images and analysis in organized layout
+ col1, col2 = st.columns([1, 2])
+
+ with col1:
+ # Display original image
+ st.image(st.session_state.current_image, caption="Original Image", width=300)
+ # Display GradCAM overlay
+ if hasattr(st.session_state, 'current_overlay'):
+ st.image(st.session_state.current_overlay, caption="GradCAM Visualization", width=300)
+
+ with col2:
+ # Image description
+ if hasattr(st.session_state, 'image_caption'):
+ st.markdown("### Image Description")
+ st.markdown(st.session_state.image_caption)
+ st.markdown("---")
+
+ # GradCAM analysis
+ if hasattr(st.session_state, 'gradcam_caption'):
+ st.markdown("### GradCAM Analysis")
+ st.markdown(st.session_state.gradcam_caption)
+ st.markdown("---")

- # LLM Analysis section
+ # LLM Analysis section - AFTER Image Analysis Summary
  with st.expander("Stage 3: Detailed Analysis with Vision LLM", expanded=False):
  if hasattr(st.session_state, 'current_image') and st.session_state.llm_model_loaded:
  st.subheader("Detailed Deepfake Analysis")

+ # Display chat history
+ for i, (question, answer) in enumerate(st.session_state.chat_history):
+ st.markdown(f"**Question {i+1}:** {question}")
+ st.markdown(f"**Answer:** {answer}")
+ st.markdown("---")
+
+ # Include both captions in the prompt if available
+ caption_text = ""
+ if hasattr(st.session_state, 'image_caption'):
+ caption_text += f"\n\nImage Description:\n{st.session_state.image_caption}"
+
+ if hasattr(st.session_state, 'gradcam_caption'):
+ caption_text += f"\n\nGradCAM Analysis:\n{st.session_state.gradcam_caption}"
+
  # Default question with option to customize
  default_question = f"This image has been classified as {st.session_state.current_pred_label}. Analyze the key features that led to this classification, focusing on the highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
- question = st.text_area("Question/Prompt:", value=default_question, height=100)
-
- # Analyze button
- if st.button("🔍 Perform Detailed Analysis", type="primary"):
- result = analyze_image_with_llm(
- st.session_state.current_image,
- st.session_state.current_overlay,
- st.session_state.current_face_box,
- st.session_state.current_pred_label,
- st.session_state.current_confidence,
- question,
- st.session_state.llm_model,
- st.session_state.tokenizer,
- temperature=temperature,
- max_tokens=max_tokens,
- custom_instruction=custom_instruction
- )
-
- # Display results
- st.success("✅ Analysis complete!")
-
- # Check if the result contains both technical and non-technical explanations
- if "Technical" in result and "Non-Technical" in result:
- # Split the result into technical and non-technical sections
- parts = result.split("Non-Technical")
- technical = parts[0]
- non_technical = "Non-Technical" + parts[1]
-
- # Display in two columns
- col1, col2 = st.columns(2)
- with col1:
- st.subheader("Technical Analysis")
- st.markdown(technical)
-
- with col2:
- st.subheader("Simple Explanation")
- st.markdown(non_technical)
- else:
- # Just display the whole result
- st.subheader("Analysis Result")
- st.markdown(result)
+
+ # User input for new question
+ new_question = st.text_area("Ask a question about the image:", value=default_question if not st.session_state.chat_history else "", height=100)
+
+ # Analyze button and Clear Chat button in the same row
+ col1, col2 = st.columns([3, 1])
+ with col1:
+ analyze_button = st.button("🔍 Send Question", type="primary")
+ with col2:
+ clear_button = st.button("🗑️ Clear Chat History")
+
+ if clear_button:
+ st.session_state.chat_history = []
+ st.experimental_rerun()
+
+ if analyze_button and new_question:
+ try:
+ # Add caption info if it's the first question
+ if not st.session_state.chat_history:
+ full_question = new_question + caption_text
+ else:
+ full_question = new_question
+
+ result = analyze_image_with_llm(
+ st.session_state.current_image,
+ st.session_state.current_overlay,
+ st.session_state.current_face_box,
+ st.session_state.current_pred_label,
+ st.session_state.current_confidence,
+ full_question,
+ st.session_state.llm_model,
+ st.session_state.tokenizer,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ custom_instruction=custom_instruction
+ )
+
+ # Add to chat history
+ st.session_state.chat_history.append((new_question, result))
+
+ # Display the latest result too
+ st.success("✅ Analysis complete!")
+
+ # Check if the result contains both technical and non-technical explanations
+ if "Technical" in result and "Non-Technical" in result:
+ try:
+ # Split the result into technical and non-technical sections
+ parts = result.split("Non-Technical")
+ technical = parts[0]
+ non_technical = "Non-Technical" + parts[1]
+
+ # Display in two columns
+ tech_col, simple_col = st.columns(2)
+ with tech_col:
+ st.subheader("Technical Analysis")
+ st.markdown(technical)
+
+ with simple_col:
+ st.subheader("Simple Explanation")
+ st.markdown(non_technical)
+ except Exception as e:
+ # Fallback if splitting fails
+ st.subheader("Analysis Result")
+ st.markdown(result)
+ else:
+ # Just display the whole result
+ st.subheader("Analysis Result")
+ st.markdown(result)
+
+ # Rerun to update the chat history display
+ st.experimental_rerun()
+
+ except Exception as e:
+ st.error(f"Error during LLM analysis: {str(e)}")

  elif not hasattr(st.session_state, 'current_image'):
  st.warning("⚠️ Please upload an image and complete the initial detection first.")
  else:
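
One practical note on the chat-history flow added above: it calls st.experimental_rerun(), which newer Streamlit releases have replaced with st.rerun(). A small compatibility shim (an assumption about the deployment's Streamlit version, not part of this commit) would keep both working:

```python
import streamlit as st

def _rerun():
    """Trigger a script rerun on both older and newer Streamlit versions."""
    if hasattr(st, "rerun"):
        st.rerun()                  # newer Streamlit releases
    else:
        st.experimental_rerun()     # older releases
```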
@@ -679,7 +901,6 @@ def main():

  # Footer
  st.markdown("---")
- st.caption("Advanced Deepfake Image Analyzer")

  if __name__ == "__main__":
  main()