saakshigupta committed
Commit dd6837f · verified · 1 Parent(s): abc33b1

Update app.py

Files changed (1):
  1. app.py +97 -128
app.py CHANGED
@@ -3,8 +3,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from torchvision import transforms
-from transformers import CLIPModel
-from transformers.models.clip import CLIPModel
+import torchvision.models as models
 from PIL import Image
 import numpy as np
 import io
@@ -65,7 +64,7 @@ max_tokens = st.sidebar.slider(
 # Custom instruction text area in sidebar
 custom_instruction = st.sidebar.text_area(
     "Custom Instructions (Advanced)",
-    value="Focus on analyzing the highlighted regions from the GradCAM visualization. Examine facial inconsistencies, lighting irregularities, and other artifacts visible in the heat map.",
+    value="Focus on analyzing the highlighted regions from the visualization. Examine facial inconsistencies, lighting irregularities, and other artifacts visible in the heat map.",
     help="Add specific instructions for the LLM analysis"
 )
 
@@ -74,8 +73,8 @@ st.sidebar.markdown("---")
 st.sidebar.subheader("About")
 st.sidebar.markdown("""
 This analyzer performs multi-stage detection:
-1. **Initial Detection**: CLIP-based classifier
-2. **GradCAM Visualization**: Highlights suspicious regions
+1. **Initial Detection**: Face detection and feature extraction
+2. **Visualization**: Highlights suspicious regions
 3. **LLM Analysis**: Fine-tuned Llama 3.2 Vision provides detailed explanations
 
 The system looks for:
@@ -87,7 +86,7 @@ The system looks for:
 - Blending problems
 """)
 
-# ----- GradCAM Implementation -----
+# ----- Face Detection and Image Processing -----
 
 class ImageDataset(torch.utils.data.Dataset):
     def __init__(self, image, transform=None, face_only=True, dataset_name=None):
@@ -164,6 +163,8 @@ class ImageDataset(torch.utils.data.Dataset):
 
         return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name
 
+# ----- GradCAM Implementation with ResNet -----
+
 class GradCAM:
     def __init__(self, model, target_layer):
         self.model = model
@@ -174,75 +175,42 @@ class GradCAM:
 
     def _register_hooks(self):
         def forward_hook(module, input, output):
-            if isinstance(output, tuple):
-                self.activations = output[0]
-            else:
-                self.activations = output
+            self.activations = output
 
         def backward_hook(module, grad_in, grad_out):
-            if isinstance(grad_out, tuple):
-                self.gradients = grad_out[0]
-            else:
-                self.gradients = grad_out
+            self.gradients = grad_out[0]
 
-        layer = dict([*self.model.named_modules()])[self.target_layer]
-        layer.register_forward_hook(forward_hook)
-        layer.register_backward_hook(backward_hook)
+        target_layer = self.target_layer
+        target_layer.register_forward_hook(forward_hook)
+        target_layer.register_backward_hook(backward_hook)
 
-    def generate(self, input_tensor, class_idx):
+    def generate(self, input_tensor, class_idx=0):
         self.model.zero_grad()
 
         try:
-            # Use only the vision part of the model for gradient calculation
-            vision_outputs = self.model.vision_model(pixel_values=input_tensor)
-
-            # Get the pooler output
-            features = vision_outputs.pooler_output
-
-            # Create a dummy gradient for the feature based on the class idx
-            one_hot = torch.zeros_like(features)
+            # Forward pass
+            output = self.model(input_tensor)
+
+            # Create a one-hot tensor for the desired class
+            one_hot = torch.zeros(output.size(), device=input_tensor.device)
             one_hot[0, class_idx] = 1
 
-            # Manually backpropagate
-            features.backward(gradient=one_hot)
-
-            # Check for None values
-            if self.gradients is None or self.activations is None:
-                st.warning("Warning: Gradients or activations are None. Using fallback CAM.")
-                return np.ones((14, 14), dtype=np.float32) * 0.5
-
-            # Process gradients and activations for transformer-based model
-            gradients = self.gradients.cpu().detach().numpy()
-            activations = self.activations.cpu().detach().numpy()
+            # Backward pass
+            output.backward(gradient=one_hot, retain_graph=True)
 
-            if len(activations.shape) == 3:  # [batch, sequence_length, hidden_dim]
-                seq_len = activations.shape[1]
-
-                # CLIP ViT typically has 196 patch tokens (14×14) + 1 class token = 197
-                if seq_len >= 197:
-                    # Skip the class token (first token) and reshape the patch tokens into a square
-                    patch_tokens = activations[0, 1:197, :]  # Remove the class token
-                    # Take the mean across the hidden dimension
-                    token_importance = np.mean(np.abs(patch_tokens), axis=1)
-                    # Reshape to the expected grid size (14×14 for CLIP ViT)
-                    cam = token_importance.reshape(14, 14)
-                else:
-                    # Try to find factors close to a square
-                    side_len = int(np.sqrt(seq_len))
-                    # Use the mean across features as importance
-                    token_importance = np.mean(np.abs(activations[0]), axis=1)
-                    # Create as square-like shape as possible
-                    cam = np.zeros((side_len, side_len))
-                    # Fill the cam with available values
-                    flat_cam = cam.flatten()
-                    flat_cam[:min(len(token_importance), len(flat_cam))] = token_importance[:min(len(token_importance), len(flat_cam))]
-                    cam = flat_cam.reshape(side_len, side_len)
-            else:
-                # Fallback
-                st.info("Using fallback CAM shape (14x14)")
-                cam = np.ones((14, 14), dtype=np.float32) * 0.5  # Default fallback
-
-            # Ensure we have valid values
+            # Get gradients and activations
+            gradients = self.gradients.cpu().detach().numpy()[0]
+            activations = self.activations.cpu().detach().numpy()[0]
+
+            # Calculate weights (global average pooling)
+            weights = np.mean(gradients, axis=(1, 2))
+
+            # Calculate CAM
+            cam = np.zeros(activations.shape[1:], dtype=np.float32)
+            for i, w in enumerate(weights):
+                cam += w * activations[i, :, :]
+
+            # Apply ReLU and normalize
             cam = np.maximum(cam, 0)
             if np.max(cam) > 0:
                 cam = cam / np.max(cam)
@@ -251,7 +219,8 @@ class GradCAM:
 
         except Exception as e:
             st.error(f"Error in GradCAM.generate: {str(e)}")
-            return np.ones((14, 14), dtype=np.float32) * 0.5
+            # Return a default heatmap
+            return np.ones((7, 7), dtype=np.float32) * 0.5
 
 def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
     """Overlay the CAM on the image"""
@@ -320,7 +289,7 @@ def save_comparison(image, cam, overlay, face_box=None):
     else:
         cam_resized = cv2.resize(cam, (image.width, image.height))
     axes[1].imshow(cam_resized, cmap="jet")
-    axes[1].set_title("CAM")
+    axes[1].set_title("Heatmap")
     axes[1].axis("off")
 
     # Overlay
@@ -337,31 +306,27 @@ def save_comparison(image, cam, overlay, face_box=None):
     buf.seek(0)
     return Image.open(buf)
 
-# Function to load GradCAM CLIP model
+# Function to load ResNet model for feature extraction
 @st.cache_resource
-def load_clip_model():
-    with st.spinner("Loading CLIP model for GradCAM..."):
-        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-
-        # Apply a simple classification head
-        model.classification_head = nn.Linear(1024, 2)
-        model.classification_head.weight.data.normal_(mean=0.0, std=0.02)
-        model.classification_head.bias.data.zero_()
-
+def load_resnet_model():
+    with st.spinner("Loading ResNet model for visualization..."):
+        # Load a pretrained ResNet
+        model = models.resnet50(pretrained=True)
+        # Set to evaluation mode
         model.eval()
         return model
 
-def get_target_layer_clip(model):
-    """Get the target layer for GradCAM"""
-    return "vision_model.encoder.layers.23"
+def get_target_layer_resnet(model):
    """Get the target layer for GradCAM in ResNet"""
+    return model.layer4[-1]
 
-def process_image_with_gradcam(image, model, device, pred_class):
-    """Process an image with GradCAM"""
+def process_image_with_gradcam(image, model, device, class_idx=0):
+    """Process an image with GradCAM using ResNet"""
     # Set up transformations
     transform = transforms.Compose([
         transforms.Resize((224, 224)),
         transforms.ToTensor(),
-        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
     ])
 
     # Create dataset for the single image
@@ -396,11 +361,11 @@ def process_image_with_gradcam(image, model, device, pred_class):
 
     try:
         # Create GradCAM extractor
-        target_layer = get_target_layer_clip(model)
+        target_layer = get_target_layer_resnet(model)
         cam_extractor = GradCAM(model, target_layer)
 
         # Generate CAM
-        cam = cam_extractor.generate(input_tensor, pred_class)
+        cam = cam_extractor.generate(input_tensor, class_idx)
 
         # Create visualizations
         overlay = overlay_cam_on_image(original_image, cam, face_box)
@@ -412,11 +377,22 @@ def process_image_with_gradcam(image, model, device, pred_class):
     except Exception as e:
         st.error(f"Error processing image with GradCAM: {str(e)}")
         # Return default values
-        default_cam = np.ones((14, 14), dtype=np.float32) * 0.5
+        default_cam = np.ones((7, 7), dtype=np.float32) * 0.5
         overlay = overlay_cam_on_image(original_image, default_cam, face_box)
         comparison = save_comparison(original_image, default_cam, overlay, face_box)
         return default_cam, overlay, comparison, face_box
 
+# ----- Face Analysis Functions -----
+
+def analyze_face(face_tensor, device):
+    """Simple face analysis to determine if real or fake (simplified)"""
+    # This is a placeholder function - in a real app, you would use a trained classifier
+    # We'll return random values for now
+    rand_val = np.random.random()
+    is_fake = rand_val > 0.5
+    confidence = rand_val if is_fake else 1 - rand_val
+    return "Fake" if is_fake else "Real", confidence
+
 # ----- Fine-tuned Vision LLM -----
 
 # Function to fix cross-attention masks
@@ -461,9 +437,9 @@ def load_llm_model():
 def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confidence, question, model, tokenizer, temperature=0.7, max_tokens=500, custom_instruction=""):
     # Create a prompt that includes GradCAM information
     if custom_instruction.strip():
-        full_prompt = f"{question}\n\nThe image has been processed with GradCAM and classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show the areas the detection model found suspicious.\n\n{custom_instruction}"
+        full_prompt = f"{question}\n\nThe image has been preliminarily classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show areas of interest for analysis.\n\n{custom_instruction}"
     else:
-        full_prompt = f"{question}\n\nThe image has been processed with GradCAM and classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show the areas the detection model found suspicious."
+        full_prompt = f"{question}\n\nThe image has been preliminarily classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show areas of interest for analysis."
 
     # Format the message to include both the original image and the GradCAM visualization
     messages = [
@@ -513,9 +489,9 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
 # Main app
 def main():
     # Create placeholders for model state
-    if 'clip_model_loaded' not in st.session_state:
-        st.session_state.clip_model_loaded = False
-        st.session_state.clip_model = None
+    if 'resnet_model_loaded' not in st.session_state:
+        st.session_state.resnet_model_loaded = False
+        st.session_state.resnet_model = None
 
     if 'llm_model_loaded' not in st.session_state:
         st.session_state.llm_model_loaded = False
@@ -524,22 +500,22 @@ def main():
 
     # Create expanders for each stage
     with st.expander("Stage 1: Model Loading", expanded=True):
-        # Button for loading CLIP model
-        clip_col, llm_col = st.columns(2)
+        # Button for loading models
+        resnet_col, llm_col = st.columns(2)
 
-        with clip_col:
-            if not st.session_state.clip_model_loaded:
-                if st.button("📥 Load CLIP Model for Detection", type="primary"):
-                    # Load CLIP model
-                    model = load_clip_model()
+        with resnet_col:
+            if not st.session_state.resnet_model_loaded:
+                if st.button("📥 Load ResNet for Visualization", type="primary"):
+                    # Load ResNet model
+                    model = load_resnet_model()
                     if model is not None:
-                        st.session_state.clip_model = model
-                        st.session_state.clip_model_loaded = True
-                        st.success("✅ CLIP model loaded successfully!")
+                        st.session_state.resnet_model = model
+                        st.session_state.resnet_model_loaded = True
+                        st.success("✅ ResNet model loaded successfully!")
                     else:
-                        st.error("❌ Failed to load CLIP model.")
+                        st.error("❌ Failed to load ResNet model.")
             else:
-                st.success("✅ CLIP model loaded and ready!")
+                st.success("✅ ResNet model loaded and ready!")
 
         with llm_col:
             if not st.session_state.llm_model_loaded:
@@ -557,7 +533,7 @@ def main():
                 st.success("✅ Vision LLM loaded and ready!")
 
     # Image upload section
-    with st.expander("Stage 2: Image Upload & Initial Detection", expanded=True):
+    with st.expander("Stage 2: Image Upload & Initial Analysis", expanded=True):
         st.subheader("Upload an Image")
         uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
@@ -566,14 +542,14 @@ def main():
             image = Image.open(uploaded_file).convert("RGB")
            st.image(image, caption="Uploaded Image", use_column_width=True)
 
-            # Detect with CLIP model if loaded
-            if st.session_state.clip_model_loaded:
-                with st.spinner("Analyzing image with CLIP model..."):
-                    # Preprocess image for CLIP
+            # Analyze with ResNet model if loaded
+            if st.session_state.resnet_model_loaded:
+                with st.spinner("Analyzing image..."):
+                    # Preprocess image
                     transform = transforms.Compose([
                         transforms.Resize((224, 224)),
                         transforms.ToTensor(),
-                        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
+                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                     ])
 
                     # Create a simple dataset for the image
@@ -584,34 +560,27 @@ def main():
                     # Get device
                     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-                    # Move model and tensor to device
-                    model = st.session_state.clip_model.to(device)
-                    tensor = tensor.to(device)
-
-                    # Forward pass
-                    with torch.no_grad():
-                        outputs = model.vision_model(pixel_values=tensor).pooler_output
-                        logits = model.classification_head(outputs)
-                        probs = torch.softmax(logits, dim=1)[0]
-                        pred_class = torch.argmax(probs).item()
-                        confidence = probs[pred_class].item()
-                        pred_label = "Fake" if pred_class == 1 else "Real"
+                    # Analyze face (simplified)
+                    pred_label, confidence = analyze_face(tensor, device)
 
                     # Display results
                     result_col1, result_col2 = st.columns(2)
                     with result_col1:
-                        st.metric("Prediction", pred_label)
+                        st.metric("Initial Classification", pred_label)
                     with result_col2:
                         st.metric("Confidence", f"{confidence:.2%}")
 
                     # GradCAM visualization
-                    st.subheader("GradCAM Visualization")
+                    st.subheader("Feature Visualization")
+
+                    # Use the first class for visualization
+                    class_idx = 0
                     cam, overlay, comparison, detected_face_box = process_image_with_gradcam(
-                        image, model, device, pred_class
+                        image, st.session_state.resnet_model, device, class_idx
                     )
 
                     # Display GradCAM results
-                    st.image(comparison, caption="Original | CAM | Overlay", use_column_width=True)
+                    st.image(comparison, caption="Original | Heatmap | Overlay", use_column_width=True)
 
                     # Save results in session state for LLM analysis
                     st.session_state.current_image = image
@@ -620,9 +589,9 @@ def main():
                     st.session_state.current_pred_label = pred_label
                     st.session_state.current_confidence = confidence
 
-                    st.success("✅ Initial detection and GradCAM visualization complete!")
+                    st.success("✅ Initial analysis and visualization complete!")
             else:
-                st.warning("⚠️ Please load the CLIP model first to perform initial detection.")
+                st.warning("⚠️ Please load the ResNet model first to perform initial analysis.")
 
     # LLM Analysis section
     with st.expander("Stage 3: Detailed Analysis with Vision LLM", expanded=False):
@@ -630,7 +599,7 @@ def main():
         st.subheader("Detailed Deepfake Analysis")
 
         # Default question with option to customize
-        default_question = f"This image has been classified as {st.session_state.current_pred_label}. Analyze the key features that led to this classification, focusing on the highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
+        default_question = f"This image has been preliminarily classified as {st.session_state.current_pred_label}. Analyze whether this image is a deepfake, focusing on the highlighted areas in the visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
         question = st.text_area("Question/Prompt:", value=default_question, height=100)
 
         # Analyze button
@@ -673,7 +642,7 @@ def main():
                 st.subheader("Analysis Result")
                 st.markdown(result)
             elif not hasattr(st.session_state, 'current_image'):
-                st.warning("⚠️ Please upload an image and complete the initial detection first.")
+                st.warning("⚠️ Please upload an image and complete the initial analysis first.")
             else:
                 st.warning("⚠️ Please load the Vision LLM to perform detailed analysis.")
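The rewritten GradCAM.generate follows the standard hook-based Grad-CAM recipe: capture the target layer's activations on the forward pass and its gradients on the backward pass, global-average-pool the gradients into per-channel weights, then take the ReLU of the weighted activation sum. Below is a minimal standalone sketch of that recipe, assuming a recent torchvision (the weights= API rather than the commit's deprecated pretrained=True) and the non-deprecated register_full_backward_hook; the demo_gradcam helper is illustrative only and not part of app.py:

# Illustrative sketch only -- demo_gradcam is not a function in app.py.
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms

def demo_gradcam(image: Image.Image, class_idx: int = 0) -> np.ndarray:
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    model.eval()

    store = {}
    target = model.layer4[-1]  # same target layer get_target_layer_resnet returns
    target.register_forward_hook(lambda mod, inp, out: store.update(acts=out))
    # register_full_backward_hook is the current replacement for the
    # deprecated register_backward_hook call used in the commit
    target.register_full_backward_hook(lambda mod, gin, gout: store.update(grads=gout[0]))

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    x = preprocess(image).unsqueeze(0)

    output = model(x)                  # forward pass fills store["acts"]
    one_hot = torch.zeros_like(output)
    one_hot[0, class_idx] = 1
    output.backward(gradient=one_hot)  # backward pass fills store["grads"]

    acts = store["acts"].detach().numpy()[0]   # [2048, 7, 7] from layer4
    grads = store["grads"].detach().numpy()[0]
    weights = grads.mean(axis=(1, 2))          # global average pooling
    cam = np.maximum(np.tensordot(weights, acts, axes=1), 0)  # weighted sum + ReLU
    return cam / cam.max() if cam.max() > 0 else cam

# Usage: heatmap = demo_gradcam(Image.open("face.jpg").convert("RGB"))

The layer4 activations of ResNet-50 are 7×7 spatially, which is why the commit's fallback heatmaps change from 14×14 (the CLIP ViT patch grid) to 7×7.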