saakshigupta committed
Commit a9f2f98 · verified · Parent: 25d0259

Update app.py

Files changed (1)
  1. app.py +127 -97
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from torchvision import transforms
-import torchvision.models as models
+from transformers.models.clip import CLIPModel  # Updated import path
 from PIL import Image
 import numpy as np
 import io
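
For reference, `CLIPModel` is also re-exported at the top level of `transformers`, so the deep submodule path above is optional. A quick, illustrative check (not part of the commit):

```python
# Illustrative only: both import paths are expected to resolve to the same class
# in recent transformers releases.
from transformers import CLIPModel as TopLevelCLIPModel
from transformers.models.clip import CLIPModel

print(TopLevelCLIPModel is CLIPModel)  # expected: True
```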
@@ -64,7 +64,7 @@ max_tokens = st.sidebar.slider(
 # Custom instruction text area in sidebar
 custom_instruction = st.sidebar.text_area(
     "Custom Instructions (Advanced)",
-    value="Focus on analyzing the highlighted regions from the visualization. Examine facial inconsistencies, lighting irregularities, and other artifacts visible in the heat map.",
+    value="Focus on analyzing the highlighted regions from the GradCAM visualization. Examine facial inconsistencies, lighting irregularities, and other artifacts visible in the heat map.",
     help="Add specific instructions for the LLM analysis"
 )
 
@@ -73,8 +73,8 @@ st.sidebar.markdown("---")
 st.sidebar.subheader("About")
 st.sidebar.markdown("""
 This analyzer performs multi-stage detection:
-1. **Initial Detection**: Face detection and feature extraction
-2. **Visualization**: Highlights suspicious regions
+1. **Initial Detection**: CLIP-based classifier
+2. **GradCAM Visualization**: Highlights suspicious regions
 3. **LLM Analysis**: Fine-tuned Llama 3.2 Vision provides detailed explanations
 
 The system looks for:
@@ -86,7 +86,7 @@ The system looks for:
 - Blending problems
 """)
 
-# ----- Face Detection and Image Processing -----
+# ----- GradCAM Implementation -----
 
 class ImageDataset(torch.utils.data.Dataset):
     def __init__(self, image, transform=None, face_only=True, dataset_name=None):
@@ -163,8 +163,6 @@ class ImageDataset(torch.utils.data.Dataset):
 
         return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name
 
-# ----- GradCAM Implementation with ResNet -----
-
 class GradCAM:
     def __init__(self, model, target_layer):
         self.model = model
@@ -175,42 +173,75 @@ class GradCAM:
 
     def _register_hooks(self):
         def forward_hook(module, input, output):
-            self.activations = output
+            if isinstance(output, tuple):
+                self.activations = output[0]
+            else:
+                self.activations = output
 
         def backward_hook(module, grad_in, grad_out):
-            self.gradients = grad_out[0]
+            if isinstance(grad_out, tuple):
+                self.gradients = grad_out[0]
+            else:
+                self.gradients = grad_out
 
-        target_layer = self.target_layer
-        target_layer.register_forward_hook(forward_hook)
-        target_layer.register_backward_hook(backward_hook)
+        layer = dict([*self.model.named_modules()])[self.target_layer]
+        layer.register_forward_hook(forward_hook)
+        layer.register_backward_hook(backward_hook)
 
-    def generate(self, input_tensor, class_idx=0):
+    def generate(self, input_tensor, class_idx):
         self.model.zero_grad()
 
         try:
-            # Forward pass
-            output = self.model(input_tensor)
-
-            # Create a one-hot tensor for the desired class
-            one_hot = torch.zeros(output.size(), device=input_tensor.device)
-            one_hot[0, class_idx] = 1
+            # Use only the vision part of the model for gradient calculation
+            vision_outputs = self.model.vision_model(pixel_values=input_tensor)
+
+            # Get the pooler output
+            features = vision_outputs.pooler_output
 
-            # Backward pass
-            output.backward(gradient=one_hot, retain_graph=True)
+            # Create a dummy gradient for the feature based on the class idx
+            one_hot = torch.zeros_like(features)
+            one_hot[0, class_idx] = 1
 
-            # Get gradients and activations
-            gradients = self.gradients.cpu().detach().numpy()[0]
-            activations = self.activations.cpu().detach().numpy()[0]
-
-            # Calculate weights (global average pooling)
-            weights = np.mean(gradients, axis=(1, 2))
+            # Manually backpropagate
+            features.backward(gradient=one_hot)
+
+            # Check for None values
+            if self.gradients is None or self.activations is None:
+                st.warning("Warning: Gradients or activations are None. Using fallback CAM.")
+                return np.ones((14, 14), dtype=np.float32) * 0.5
+
+            # Process gradients and activations for transformer-based model
+            gradients = self.gradients.cpu().detach().numpy()
+            activations = self.activations.cpu().detach().numpy()
 
-            # Calculate CAM
-            cam = np.zeros(activations.shape[1:], dtype=np.float32)
-            for i, w in enumerate(weights):
-                cam += w * activations[i, :, :]
-
-            # Apply ReLU and normalize
+            if len(activations.shape) == 3:  # [batch, sequence_length, hidden_dim]
+                seq_len = activations.shape[1]
+
+                # CLIP ViT typically has 196 patch tokens (14×14) + 1 class token = 197
+                if seq_len >= 197:
+                    # Skip the class token (first token) and reshape the patch tokens into a square
+                    patch_tokens = activations[0, 1:197, :]  # Remove the class token
+                    # Take the mean across the hidden dimension
+                    token_importance = np.mean(np.abs(patch_tokens), axis=1)
+                    # Reshape to the expected grid size (14×14 for CLIP ViT)
+                    cam = token_importance.reshape(14, 14)
+                else:
+                    # Try to find factors close to a square
+                    side_len = int(np.sqrt(seq_len))
+                    # Use the mean across features as importance
+                    token_importance = np.mean(np.abs(activations[0]), axis=1)
+                    # Create as square-like a shape as possible
+                    cam = np.zeros((side_len, side_len))
+                    # Fill the CAM with the available values
+                    flat_cam = cam.flatten()
+                    flat_cam[:min(len(token_importance), len(flat_cam))] = token_importance[:min(len(token_importance), len(flat_cam))]
+                    cam = flat_cam.reshape(side_len, side_len)
+            else:
+                # Fallback
+                st.info("Using fallback CAM shape (14x14)")
+                cam = np.ones((14, 14), dtype=np.float32) * 0.5  # Default fallback
+
+            # Ensure we have valid values
             cam = np.maximum(cam, 0)
             if np.max(cam) > 0:
                 cam = cam / np.max(cam)
@@ -219,8 +250,7 @@ class GradCAM:
 
         except Exception as e:
             st.error(f"Error in GradCAM.generate: {str(e)}")
-            # Return a default heatmap
-            return np.ones((7, 7), dtype=np.float32) * 0.5
+            return np.ones((14, 14), dtype=np.float32) * 0.5
 
 def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
     """Overlay the CAM on the image"""
@@ -289,7 +319,7 @@ def save_comparison(image, cam, overlay, face_box=None):
     else:
         cam_resized = cv2.resize(cam, (image.width, image.height))
     axes[1].imshow(cam_resized, cmap="jet")
-    axes[1].set_title("Heatmap")
+    axes[1].set_title("CAM")
     axes[1].axis("off")
 
     # Overlay
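
Two caveats about the rewritten `generate` above. First, `Module.register_backward_hook` is deprecated in current PyTorch in favor of `register_full_backward_hook`. Second, the 197-token / 14×14 assumption matches ViT-B/16-style CLIP encoders (196 patches + 1 class token); the checkpoint this commit loads later, `openai/clip-vit-large-patch14`, emits 257 tokens at 224×224 ((224/14)² + 1), so the `seq_len >= 197` branch keeps only the first 196 of 256 patch tokens. A hedged sketch of deriving the grid from the actual sequence length instead (`token_grid` is a hypothetical helper, not in the commit):

```python
import numpy as np

def token_grid(activations: np.ndarray) -> np.ndarray:
    """Map ViT activations [batch, seq_len, hidden_dim] to a square importance grid."""
    patch_tokens = activations[0, 1:, :]        # drop the class token
    side = int(np.sqrt(patch_tokens.shape[0]))  # 256 patches -> 16, 196 -> 14
    importance = np.mean(np.abs(patch_tokens[:side * side]), axis=1)
    return importance.reshape(side, side)
```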
@@ -306,27 +336,31 @@ def save_comparison(image, cam, overlay, face_box=None):
     buf.seek(0)
     return Image.open(buf)
 
-# Function to load ResNet model for feature extraction
+# Function to load GradCAM CLIP model
 @st.cache_resource
-def load_resnet_model():
-    with st.spinner("Loading ResNet model for visualization..."):
-        # Load a pretrained ResNet
-        model = models.resnet50(pretrained=True)
-        # Set to evaluation mode
+def load_clip_model():
+    with st.spinner("Loading CLIP model for GradCAM..."):
+        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+
+        # Apply a simple classification head
+        model.classification_head = nn.Linear(1024, 2)
+        model.classification_head.weight.data.normal_(mean=0.0, std=0.02)
+        model.classification_head.bias.data.zero_()
+
         model.eval()
         return model
 
-def get_target_layer_resnet(model):
-    """Get the target layer for GradCAM in ResNet"""
-    return model.layer4[-1]
+def get_target_layer_clip(model):
+    """Get the target layer for GradCAM"""
+    return "vision_model.encoder.layers.23"
 
-def process_image_with_gradcam(image, model, device, class_idx=0):
-    """Process an image with GradCAM using ResNet"""
+def process_image_with_gradcam(image, model, device, pred_class):
+    """Process an image with GradCAM"""
     # Set up transformations
     transform = transforms.Compose([
         transforms.Resize((224, 224)),
         transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
     ])
 
     # Create dataset for the single image
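
The new normalization constants are CLIP's standard preprocessing statistics rather than the ImageNet values used for ResNet. If hard-coding them feels brittle, they can be read off the checkpoint's image processor; a small sketch (assumes the `transformers` processor API and network access):

```python
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
print(processor.image_mean)  # [0.48145466, 0.4578275, 0.40821073]
print(processor.image_std)   # [0.26862954, 0.26130258, 0.27577711]
```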
@@ -361,11 +395,11 @@ def process_image_with_gradcam(image, model, device, class_idx=0):
 
     try:
         # Create GradCAM extractor
-        target_layer = get_target_layer_resnet(model)
+        target_layer = get_target_layer_clip(model)
         cam_extractor = GradCAM(model, target_layer)
 
         # Generate CAM
-        cam = cam_extractor.generate(input_tensor, class_idx)
+        cam = cam_extractor.generate(input_tensor, pred_class)
 
         # Create visualizations
         overlay = overlay_cam_on_image(original_image, cam, face_box)
@@ -377,22 +411,11 @@ def process_image_with_gradcam(image, model, device, class_idx=0):
     except Exception as e:
         st.error(f"Error processing image with GradCAM: {str(e)}")
         # Return default values
-        default_cam = np.ones((7, 7), dtype=np.float32) * 0.5
+        default_cam = np.ones((14, 14), dtype=np.float32) * 0.5
         overlay = overlay_cam_on_image(original_image, default_cam, face_box)
         comparison = save_comparison(original_image, default_cam, overlay, face_box)
         return default_cam, overlay, comparison, face_box
 
-# ----- Face Analysis Functions -----
-
-def analyze_face(face_tensor, device):
-    """Simple face analysis to determine if real or fake (simplified)"""
-    # This is a placeholder function - in a real app, you would use a trained classifier
-    # We'll return random values for now
-    rand_val = np.random.random()
-    is_fake = rand_val > 0.5
-    confidence = rand_val if is_fake else 1 - rand_val
-    return "Fake" if is_fake else "Real", confidence
-
 # ----- Fine-tuned Vision LLM -----
 
 # Function to fix cross-attention masks
@@ -437,9 +460,9 @@ def load_llm_model():
 def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confidence, question, model, tokenizer, temperature=0.7, max_tokens=500, custom_instruction=""):
     # Create a prompt that includes GradCAM information
     if custom_instruction.strip():
-        full_prompt = f"{question}\n\nThe image has been preliminarily classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show areas of interest for analysis.\n\n{custom_instruction}"
+        full_prompt = f"{question}\n\nThe image has been processed with GradCAM and classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show the areas the detection model found suspicious.\n\n{custom_instruction}"
     else:
-        full_prompt = f"{question}\n\nThe image has been preliminarily classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show areas of interest for analysis."
+        full_prompt = f"{question}\n\nThe image has been processed with GradCAM and classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show the areas the detection model found suspicious."
 
     # Format the message to include both the original image and the GradCAM visualization
     messages = [
@@ -489,9 +512,9 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confidence, question, model, tokenizer, temperature=0.7, max_tokens=500, custom_instruction=""):
 # Main app
 def main():
     # Create placeholders for model state
-    if 'resnet_model_loaded' not in st.session_state:
-        st.session_state.resnet_model_loaded = False
-        st.session_state.resnet_model = None
+    if 'clip_model_loaded' not in st.session_state:
+        st.session_state.clip_model_loaded = False
+        st.session_state.clip_model = None
 
     if 'llm_model_loaded' not in st.session_state:
         st.session_state.llm_model_loaded = False
@@ -500,22 +523,22 @@ def main():
 
     # Create expanders for each stage
     with st.expander("Stage 1: Model Loading", expanded=True):
-        # Button for loading models
-        resnet_col, llm_col = st.columns(2)
+        # Button for loading CLIP model
+        clip_col, llm_col = st.columns(2)
 
-        with resnet_col:
-            if not st.session_state.resnet_model_loaded:
-                if st.button("📥 Load ResNet for Visualization", type="primary"):
-                    # Load ResNet model
-                    model = load_resnet_model()
+        with clip_col:
+            if not st.session_state.clip_model_loaded:
+                if st.button("📥 Load CLIP Model for Detection", type="primary"):
+                    # Load CLIP model
+                    model = load_clip_model()
                     if model is not None:
-                        st.session_state.resnet_model = model
-                        st.session_state.resnet_model_loaded = True
-                        st.success("✅ ResNet model loaded successfully!")
+                        st.session_state.clip_model = model
+                        st.session_state.clip_model_loaded = True
+                        st.success("✅ CLIP model loaded successfully!")
                     else:
-                        st.error("❌ Failed to load ResNet model.")
+                        st.error("❌ Failed to load CLIP model.")
             else:
-                st.success("✅ ResNet model loaded and ready!")
+                st.success("✅ CLIP model loaded and ready!")
 
         with llm_col:
             if not st.session_state.llm_model_loaded:
@@ -533,7 +556,7 @@ def main():
             st.success("✅ Vision LLM loaded and ready!")
 
     # Image upload section
-    with st.expander("Stage 2: Image Upload & Initial Analysis", expanded=True):
+    with st.expander("Stage 2: Image Upload & Initial Detection", expanded=True):
         st.subheader("Upload an Image")
         uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
@@ -542,14 +565,14 @@ def main():
             image = Image.open(uploaded_file).convert("RGB")
             st.image(image, caption="Uploaded Image", use_column_width=True)
 
-            # Analyze with ResNet model if loaded
-            if st.session_state.resnet_model_loaded:
-                with st.spinner("Analyzing image..."):
-                    # Preprocess image
+            # Detect with CLIP model if loaded
+            if st.session_state.clip_model_loaded:
+                with st.spinner("Analyzing image with CLIP model..."):
+                    # Preprocess image for CLIP
                     transform = transforms.Compose([
                         transforms.Resize((224, 224)),
                         transforms.ToTensor(),
-                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+                        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
                     ])
 
                     # Create a simple dataset for the image
@@ -560,27 +583,34 @@ def main():
 
                     # Get device
                     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-                    # Analyze face (simplified)
-                    pred_label, confidence = analyze_face(tensor, device)
+                    # Move model and tensor to device
+                    model = st.session_state.clip_model.to(device)
+                    tensor = tensor.to(device)
+
+                    # Forward pass
+                    with torch.no_grad():
+                        outputs = model.vision_model(pixel_values=tensor).pooler_output
+                        logits = model.classification_head(outputs)
+                        probs = torch.softmax(logits, dim=1)[0]
+                        pred_class = torch.argmax(probs).item()
+                        confidence = probs[pred_class].item()
+                        pred_label = "Fake" if pred_class == 1 else "Real"
 
                     # Display results
                     result_col1, result_col2 = st.columns(2)
                     with result_col1:
-                        st.metric("Initial Classification", pred_label)
+                        st.metric("Prediction", pred_label)
                     with result_col2:
                         st.metric("Confidence", f"{confidence:.2%}")
 
                     # GradCAM visualization
-                    st.subheader("Feature Visualization")
-
-                    # Use the first class for visualization
-                    class_idx = 0
+                    st.subheader("GradCAM Visualization")
                     cam, overlay, comparison, detected_face_box = process_image_with_gradcam(
-                        image, st.session_state.resnet_model, device, class_idx
+                        image, model, device, pred_class
                     )
 
                     # Display GradCAM results
-                    st.image(comparison, caption="Original | Heatmap | Overlay", use_column_width=True)
+                    st.image(comparison, caption="Original | CAM | Overlay", use_column_width=True)
 
                     # Save results in session state for LLM analysis
                     st.session_state.current_image = image
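
The new forward pass replaces the random `analyze_face` placeholder with a real classifier head over CLIP's pooled features, though note that `load_clip_model` initializes that head randomly and this diff never attaches fine-tuned head weights, so the reported confidence is only meaningful once trained weights are loaded. To make the confidence computation concrete, a worked example with made-up logits:

```python
import torch

logits = torch.tensor([[0.3, 1.2]])      # [batch=1, classes=(Real, Fake)], illustrative values
probs = torch.softmax(logits, dim=1)[0]  # tensor([0.2891, 0.7109])
pred_class = torch.argmax(probs).item()  # 1 -> "Fake"
confidence = probs[pred_class].item()    # ~0.711, rendered by st.metric as "71.09%"
```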
@@ -589,9 +619,9 @@ def main():
                     st.session_state.current_pred_label = pred_label
                     st.session_state.current_confidence = confidence
 
-                    st.success("✅ Initial analysis and visualization complete!")
+                    st.success("✅ Initial detection and GradCAM visualization complete!")
             else:
-                st.warning("⚠️ Please load the ResNet model first to perform initial analysis.")
+                st.warning("⚠️ Please load the CLIP model first to perform initial detection.")
 
     # LLM Analysis section
     with st.expander("Stage 3: Detailed Analysis with Vision LLM", expanded=False):
@@ -599,7 +629,7 @@ def main():
         st.subheader("Detailed Deepfake Analysis")
 
         # Default question with option to customize
-        default_question = f"This image has been preliminarily classified as {st.session_state.current_pred_label}. Analyze whether this image is a deepfake, focusing on the highlighted areas in the visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
+        default_question = f"This image has been classified as {st.session_state.current_pred_label}. Analyze the key features that led to this classification, focusing on the highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
         question = st.text_area("Question/Prompt:", value=default_question, height=100)
 
         # Analyze button
@@ -642,7 +672,7 @@ def main():
                     st.subheader("Analysis Result")
                     st.markdown(result)
         elif not hasattr(st.session_state, 'current_image'):
-            st.warning("⚠️ Please upload an image and complete the initial analysis first.")
+            st.warning("⚠️ Please upload an image and complete the initial detection first.")
         else:
             st.warning("⚠️ Please load the Vision LLM to perform detailed analysis.")
 
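Taken together, the commit swaps the pretrained-ResNet visualizer for a CLIP-based detector-plus-explainer. A condensed, non-Streamlit sketch of how the new pieces fit together (the `GradCAM` class is the one defined in the diff above; `test.jpg` is a placeholder path):

```python
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers.models.clip import CLIPModel

# Build the model as load_clip_model() does (head weights are untrained here)
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
model.classification_head = nn.Linear(1024, 2)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# CLIP preprocessing, matching the transform in the diff
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
])
tensor = transform(Image.open("test.jpg").convert("RGB")).unsqueeze(0).to(device)

# Classify from the pooled vision features
with torch.no_grad():
    pooled = model.vision_model(pixel_values=tensor).pooler_output
    probs = torch.softmax(model.classification_head(pooled), dim=1)[0]
pred_class = torch.argmax(probs).item()

# Explain the prediction with the commit's GradCAM class
cam_extractor = GradCAM(model, "vision_model.encoder.layers.23")
cam = cam_extractor.generate(tensor, pred_class)  # normalized heat map in [0, 1]
```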