Update app.py
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from torchvision import transforms
-
+from transformers.models.clip import CLIPModel  # Updated import path
 from PIL import Image
 import numpy as np
 import io
@@ -64,7 +64,7 @@ max_tokens = st.sidebar.slider(
 # Custom instruction text area in sidebar
 custom_instruction = st.sidebar.text_area(
     "Custom Instructions (Advanced)",
-    value="Focus on analyzing the highlighted regions from the visualization. Examine facial inconsistencies, lighting irregularities, and other artifacts visible in the heat map.",
+    value="Focus on analyzing the highlighted regions from the GradCAM visualization. Examine facial inconsistencies, lighting irregularities, and other artifacts visible in the heat map.",
     help="Add specific instructions for the LLM analysis"
 )
 
@@ -73,8 +73,8 @@ st.sidebar.markdown("---")
 st.sidebar.subheader("About")
 st.sidebar.markdown("""
 This analyzer performs multi-stage detection:
-1. **Initial Detection**:
-2. **Visualization**: Highlights suspicious regions
+1. **Initial Detection**: CLIP-based classifier
+2. **GradCAM Visualization**: Highlights suspicious regions
 3. **LLM Analysis**: Fine-tuned Llama 3.2 Vision provides detailed explanations
 
 The system looks for:
@@ -86,7 +86,7 @@ The system looks for:
 - Blending problems
 """)
 
-# -----
+# ----- GradCAM Implementation -----
 
 class ImageDataset(torch.utils.data.Dataset):
     def __init__(self, image, transform=None, face_only=True, dataset_name=None):
@@ -163,8 +163,6 @@ class ImageDataset(torch.utils.data.Dataset):
 
         return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name
 
-# ----- GradCAM Implementation with ResNet -----
-
 class GradCAM:
     def __init__(self, model, target_layer):
         self.model = model
@@ -175,42 +173,75 @@ class GradCAM:
 
     def _register_hooks(self):
         def forward_hook(module, input, output):
-
+            if isinstance(output, tuple):
+                self.activations = output[0]
+            else:
+                self.activations = output
 
         def backward_hook(module, grad_in, grad_out):
-
+            if isinstance(grad_out, tuple):
+                self.gradients = grad_out[0]
+            else:
+                self.gradients = grad_out
 
-
-
-
+        layer = dict([*self.model.named_modules()])[self.target_layer]
+        layer.register_forward_hook(forward_hook)
+        layer.register_backward_hook(backward_hook)
 
-    def generate(self, input_tensor, class_idx
+    def generate(self, input_tensor, class_idx):
         self.model.zero_grad()
 
         try:
-            #
-
-
-            # Create a one-hot tensor for the desired class
-            one_hot = torch.zeros(output.size(), device=input_tensor.device)
-            one_hot[0, class_idx] = 1
+            # Use only the vision part of the model for gradient calculation
+            vision_outputs = self.model.vision_model(pixel_values=input_tensor)
 
-            #
-
+            # Get the pooler output
+            features = vision_outputs.pooler_output
 
-            #
-
-
-
-            # Calculate weights (global average pooling)
-            weights = np.mean(gradients, axis=(1, 2))
+            # Create a dummy gradient for the feature based on the class idx
+            one_hot = torch.zeros_like(features)
+            one_hot[0, class_idx] = 1
 
-            #
-
-
-
+            # Manually backpropagate
+            features.backward(gradient=one_hot)
+
+            # Check for None values
+            if self.gradients is None or self.activations is None:
+                st.warning("Warning: Gradients or activations are None. Using fallback CAM.")
+                return np.ones((14, 14), dtype=np.float32) * 0.5
+
+            # Process gradients and activations for transformer-based model
+            gradients = self.gradients.cpu().detach().numpy()
+            activations = self.activations.cpu().detach().numpy()
 
-            #
+            if len(activations.shape) == 3:  # [batch, sequence_length, hidden_dim]
+                seq_len = activations.shape[1]
+
+                # CLIP ViT typically has 196 patch tokens (14×14) + 1 class token = 197
+                if seq_len >= 197:
+                    # Skip the class token (first token) and reshape the patch tokens into a square
+                    patch_tokens = activations[0, 1:197, :]  # Remove the class token
+                    # Take the mean across the hidden dimension
+                    token_importance = np.mean(np.abs(patch_tokens), axis=1)
+                    # Reshape to the expected grid size (14×14 for CLIP ViT)
+                    cam = token_importance.reshape(14, 14)
+                else:
+                    # Try to find factors close to a square
+                    side_len = int(np.sqrt(seq_len))
+                    # Use the mean across features as importance
+                    token_importance = np.mean(np.abs(activations[0]), axis=1)
+                    # Create as square-like shape as possible
+                    cam = np.zeros((side_len, side_len))
+                    # Fill the cam with available values
+                    flat_cam = cam.flatten()
+                    flat_cam[:min(len(token_importance), len(flat_cam))] = token_importance[:min(len(token_importance), len(flat_cam))]
+                    cam = flat_cam.reshape(side_len, side_len)
+            else:
+                # Fallback
+                st.info("Using fallback CAM shape (14x14)")
+                cam = np.ones((14, 14), dtype=np.float32) * 0.5  # Default fallback
+
+            # Ensure we have valid values
             cam = np.maximum(cam, 0)
             if np.max(cam) > 0:
                 cam = cam / np.max(cam)
@@ -219,8 +250,7 @@ class GradCAM:
 
         except Exception as e:
             st.error(f"Error in GradCAM.generate: {str(e)}")
-
-            return np.ones((7, 7), dtype=np.float32) * 0.5
+            return np.ones((14, 14), dtype=np.float32) * 0.5
 
 def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
     """Overlay the CAM on the image"""
@@ -289,7 +319,7 @@ def save_comparison(image, cam, overlay, face_box=None):
     else:
         cam_resized = cv2.resize(cam, (image.width, image.height))
         axes[1].imshow(cam_resized, cmap="jet")
-        axes[1].set_title("
+        axes[1].set_title("CAM")
     axes[1].axis("off")
 
     # Overlay
@@ -306,27 +336,31 @@ def save_comparison(image, cam, overlay, face_box=None):
     buf.seek(0)
     return Image.open(buf)
 
-# Function to load
+# Function to load GradCAM CLIP model
 @st.cache_resource
-def
-    with st.spinner("Loading
-
-
-    #
+def load_clip_model():
+    with st.spinner("Loading CLIP model for GradCAM..."):
+        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+
+        # Apply a simple classification head
+        model.classification_head = nn.Linear(1024, 2)
+        model.classification_head.weight.data.normal_(mean=0.0, std=0.02)
+        model.classification_head.bias.data.zero_()
+
         model.eval()
         return model
 
-def
-    """Get the target layer for GradCAM
-    return
+def get_target_layer_clip(model):
+    """Get the target layer for GradCAM"""
+    return "vision_model.encoder.layers.23"
 
-def process_image_with_gradcam(image, model, device,
-    """Process an image with GradCAM
+def process_image_with_gradcam(image, model, device, pred_class):
+    """Process an image with GradCAM"""
     # Set up transformations
     transform = transforms.Compose([
         transforms.Resize((224, 224)),
         transforms.ToTensor(),
-        transforms.Normalize(mean=[0.
+        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
     ])
 
     # Create dataset for the single image
@@ -361,11 +395,11 @@ def process_image_with_gradcam(image, model, device, class_idx=0):
 
     try:
         # Create GradCAM extractor
-        target_layer =
+        target_layer = get_target_layer_clip(model)
        cam_extractor = GradCAM(model, target_layer)
 
         # Generate CAM
-        cam = cam_extractor.generate(input_tensor,
+        cam = cam_extractor.generate(input_tensor, pred_class)
 
         # Create visualizations
         overlay = overlay_cam_on_image(original_image, cam, face_box)
@@ -377,22 +411,11 @@ def process_image_with_gradcam(image, model, device, class_idx=0):
     except Exception as e:
         st.error(f"Error processing image with GradCAM: {str(e)}")
         # Return default values
-        default_cam = np.ones((
+        default_cam = np.ones((14, 14), dtype=np.float32) * 0.5
         overlay = overlay_cam_on_image(original_image, default_cam, face_box)
         comparison = save_comparison(original_image, default_cam, overlay, face_box)
         return default_cam, overlay, comparison, face_box
 
-# ----- Face Analysis Functions -----
-
-def analyze_face(face_tensor, device):
-    """Simple face analysis to determine if real or fake (simplified)"""
-    # This is a placeholder function - in a real app, you would use a trained classifier
-    # We'll return random values for now
-    rand_val = np.random.random()
-    is_fake = rand_val > 0.5
-    confidence = rand_val if is_fake else 1 - rand_val
-    return "Fake" if is_fake else "Real", confidence
-
 # ----- Fine-tuned Vision LLM -----
 
 # Function to fix cross-attention masks
@@ -437,9 +460,9 @@ def load_llm_model():
 def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confidence, question, model, tokenizer, temperature=0.7, max_tokens=500, custom_instruction=""):
     # Create a prompt that includes GradCAM information
     if custom_instruction.strip():
-        full_prompt = f"{question}\n\nThe image has been
+        full_prompt = f"{question}\n\nThe image has been processed with GradCAM and classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show the areas the detection model found suspicious.\n\n{custom_instruction}"
     else:
-        full_prompt = f"{question}\n\nThe image has been
+        full_prompt = f"{question}\n\nThe image has been processed with GradCAM and classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show the areas the detection model found suspicious."
 
     # Format the message to include both the original image and the GradCAM visualization
     messages = [
@@ -489,9 +512,9 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confidence, question, model, tokenizer, temperature=0.7, max_tokens=500, custom_instruction=""):
 # Main app
 def main():
     # Create placeholders for model state
-    if '
-    st.session_state.
-    st.session_state.
+    if 'clip_model_loaded' not in st.session_state:
+        st.session_state.clip_model_loaded = False
+        st.session_state.clip_model = None
 
     if 'llm_model_loaded' not in st.session_state:
         st.session_state.llm_model_loaded = False
@@ -500,22 +523,22 @@ def main():
 
     # Create expanders for each stage
     with st.expander("Stage 1: Model Loading", expanded=True):
-        # Button for loading
-
+        # Button for loading CLIP model
+        clip_col, llm_col = st.columns(2)
 
-        with
-            if not st.session_state.
-                if st.button("📥 Load
-                    # Load
-                    model =
+        with clip_col:
+            if not st.session_state.clip_model_loaded:
+                if st.button("📥 Load CLIP Model for Detection", type="primary"):
+                    # Load CLIP model
+                    model = load_clip_model()
                     if model is not None:
-                        st.session_state.
-                        st.session_state.
-                        st.success("✅
+                        st.session_state.clip_model = model
+                        st.session_state.clip_model_loaded = True
+                        st.success("✅ CLIP model loaded successfully!")
                     else:
-                        st.error("❌ Failed to load
+                        st.error("❌ Failed to load CLIP model.")
             else:
-                st.success("✅
+                st.success("✅ CLIP model loaded and ready!")
 
         with llm_col:
             if not st.session_state.llm_model_loaded:
@@ -533,7 +556,7 @@ def main():
                 st.success("✅ Vision LLM loaded and ready!")
 
     # Image upload section
-    with st.expander("Stage 2: Image Upload & Initial
+    with st.expander("Stage 2: Image Upload & Initial Detection", expanded=True):
         st.subheader("Upload an Image")
         uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
@@ -542,14 +565,14 @@ def main():
             image = Image.open(uploaded_file).convert("RGB")
             st.image(image, caption="Uploaded Image", use_column_width=True)
 
-            #
-            if st.session_state.
-                with st.spinner("Analyzing image..."):
-                    # Preprocess image
+            # Detect with CLIP model if loaded
+            if st.session_state.clip_model_loaded:
+                with st.spinner("Analyzing image with CLIP model..."):
+                    # Preprocess image for CLIP
                     transform = transforms.Compose([
                         transforms.Resize((224, 224)),
                         transforms.ToTensor(),
-                        transforms.Normalize(mean=[0.
+                        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
                     ])
 
                     # Create a simple dataset for the image
@@ -560,27 +583,34 @@ def main():
                     # Get device
                     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-                    #
-
+                    # Move model and tensor to device
+                    model = st.session_state.clip_model.to(device)
+                    tensor = tensor.to(device)
+
+                    # Forward pass
+                    with torch.no_grad():
+                        outputs = model.vision_model(pixel_values=tensor).pooler_output
+                        logits = model.classification_head(outputs)
+                        probs = torch.softmax(logits, dim=1)[0]
+                        pred_class = torch.argmax(probs).item()
+                        confidence = probs[pred_class].item()
+                        pred_label = "Fake" if pred_class == 1 else "Real"
 
                     # Display results
                     result_col1, result_col2 = st.columns(2)
                     with result_col1:
-                        st.metric("
+                        st.metric("Prediction", pred_label)
                     with result_col2:
                         st.metric("Confidence", f"{confidence:.2%}")
 
                     # GradCAM visualization
-                    st.subheader("
-
-                    # Use the first class for visualization
-                    class_idx = 0
+                    st.subheader("GradCAM Visualization")
                     cam, overlay, comparison, detected_face_box = process_image_with_gradcam(
-                        image,
+                        image, model, device, pred_class
                     )
 
                     # Display GradCAM results
-                    st.image(comparison, caption="Original |
+                    st.image(comparison, caption="Original | CAM | Overlay", use_column_width=True)
 
                     # Save results in session state for LLM analysis
                     st.session_state.current_image = image
@@ -589,9 +619,9 @@ def main():
                     st.session_state.current_pred_label = pred_label
                     st.session_state.current_confidence = confidence
 
-                    st.success("✅ Initial
+                    st.success("✅ Initial detection and GradCAM visualization complete!")
             else:
-                st.warning("⚠️ Please load the
+                st.warning("⚠️ Please load the CLIP model first to perform initial detection.")
 
     # LLM Analysis section
     with st.expander("Stage 3: Detailed Analysis with Vision LLM", expanded=False):
@@ -599,7 +629,7 @@ def main():
             st.subheader("Detailed Deepfake Analysis")
 
             # Default question with option to customize
-            default_question = f"This image has been
+            default_question = f"This image has been classified as {st.session_state.current_pred_label}. Analyze the key features that led to this classification, focusing on the highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
             question = st.text_area("Question/Prompt:", value=default_question, height=100)
 
             # Analyze button
@@ -642,7 +672,7 @@ def main():
                 st.subheader("Analysis Result")
                 st.markdown(result)
         elif not hasattr(st.session_state, 'current_image'):
-            st.warning("⚠️ Please upload an image and complete the initial
+            st.warning("⚠️ Please upload an image and complete the initial detection first.")
         else:
             st.warning("⚠️ Please load the Vision LLM to perform detailed analysis.")
 
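For reference, a minimal standalone sketch of the patch-token-to-CAM step that the new GradCAM.generate performs. The grid size is an assumption: 14×14 corresponds to 197 tokens (1 class token + 196 patches, i.e. a 224-px input with 16-px patches), while openai/clip-vit-large-patch14 actually emits 257 tokens (a 16×16 patch grid), so the hard-coded reshape(14, 14) in the commit only uses the first 196 patch tokens. The helper below infers the grid from the sequence length instead; it is an illustration, not part of the commit.

# Illustrative sketch (not part of the commit): turning ViT patch-token
# activations into a square CAM, as GradCAM.generate does above.
import numpy as np

def token_activations_to_cam(activations: np.ndarray) -> np.ndarray:
    """activations: [1, seq_len, hidden_dim] from a ViT encoder layer
    (class token first). Returns a square grid normalized to [0, 1]."""
    seq_len = activations.shape[1]
    grid = int(np.sqrt(seq_len - 1))                     # assume a square patch grid after the class token
    patch_tokens = activations[0, 1:1 + grid * grid]     # drop the class token
    importance = np.mean(np.abs(patch_tokens), axis=1)   # mean |activation| per patch
    cam = importance.reshape(grid, grid)
    cam = np.maximum(cam, 0)
    return cam / cam.max() if cam.max() > 0 else cam

# 197 tokens (1 class + 14*14 patches) -> (14, 14); 257 tokens -> (16, 16)
dummy = np.random.rand(1, 197, 1024).astype(np.float32)
print(token_activations_to_cam(dummy).shape)  # (14, 14)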
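And a sketch of how the detection path added in this commit is meant to run end to end: CLIP ViT-L/14 vision encoder, a 2-class linear head on the 1024-dim pooler output, CLIP normalization stats. Note that load_clip_model initializes the head randomly and the diff shows no checkpoint for it, so real predictions would require loading trained head weights; "face.jpg" below is a placeholder path. This is an illustrative sketch under those assumptions, not the app's definitive inference code.

# Illustrative sketch (not part of the commit): intended inference path for the CLIP-based detector.
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers.models.clip import CLIPModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
model.classification_head = nn.Linear(1024, 2)  # 1024 = ViT-L/14 pooler width; untrained here, as in load_clip_model
model.eval().to(device)

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
])

image = Image.open("face.jpg").convert("RGB")        # placeholder test image
tensor = preprocess(image).unsqueeze(0).to(device)

with torch.no_grad():
    pooled = model.vision_model(pixel_values=tensor).pooler_output   # [1, 1024]
    probs = torch.softmax(model.classification_head(pooled), dim=1)[0]

pred_class = int(torch.argmax(probs))
print("Fake" if pred_class == 1 else "Real", f"{probs[pred_class]:.2%}")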