Update app.py
app.py
CHANGED
@@ -3,8 +3,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from torchvision import transforms
-
-from transformers.models.clip import CLIPModel
+import torchvision.models as models
 from PIL import Image
 import numpy as np
 import io
@@ -65,7 +64,7 @@ max_tokens = st.sidebar.slider(
 # Custom instruction text area in sidebar
 custom_instruction = st.sidebar.text_area(
     "Custom Instructions (Advanced)",
-    value="Focus on analyzing the highlighted regions from the
+    value="Focus on analyzing the highlighted regions from the visualization. Examine facial inconsistencies, lighting irregularities, and other artifacts visible in the heat map.",
     help="Add specific instructions for the LLM analysis"
 )

@@ -74,8 +73,8 @@ st.sidebar.markdown("---")
 st.sidebar.subheader("About")
 st.sidebar.markdown("""
 This analyzer performs multi-stage detection:
-1. **Initial Detection**:
-2. **
+1. **Initial Detection**: Face detection and feature extraction
+2. **Visualization**: Highlights suspicious regions
 3. **LLM Analysis**: Fine-tuned Llama 3.2 Vision provides detailed explanations

 The system looks for:
@@ -87,7 +86,7 @@ The system looks for:
 - Blending problems
 """)

-# -----
+# ----- Face Detection and Image Processing -----

 class ImageDataset(torch.utils.data.Dataset):
     def __init__(self, image, transform=None, face_only=True, dataset_name=None):
@@ -164,6 +163,8 @@ class ImageDataset(torch.utils.data.Dataset):

         return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name

+# ----- GradCAM Implementation with ResNet -----
+
 class GradCAM:
     def __init__(self, model, target_layer):
         self.model = model
@@ -174,75 +175,42 @@ class GradCAM:

     def _register_hooks(self):
         def forward_hook(module, input, output):
-
-                self.activations = output[0]
-            else:
-                self.activations = output
+            self.activations = output

         def backward_hook(module, grad_in, grad_out):
-
-                self.gradients = grad_out[0]
-            else:
-                self.gradients = grad_out
+            self.gradients = grad_out[0]

-
-
-
+        target_layer = self.target_layer
+        target_layer.register_forward_hook(forward_hook)
+        target_layer.register_backward_hook(backward_hook)

-    def generate(self, input_tensor, class_idx):
+    def generate(self, input_tensor, class_idx=0):
         self.model.zero_grad()

         try:
-            #
-
+            # Forward pass
+            output = self.model(input_tensor)

-            #
-
-
-            # Create a dummy gradient for the feature based on the class idx
-            one_hot = torch.zeros_like(features)
+            # Create a one-hot tensor for the desired class
+            one_hot = torch.zeros(output.size(), device=input_tensor.device)
             one_hot[0, class_idx] = 1

-            #
-
-
-            # Check for None values
-            if self.gradients is None or self.activations is None:
-                st.warning("Warning: Gradients or activations are None. Using fallback CAM.")
-                return np.ones((14, 14), dtype=np.float32) * 0.5
-
-            # Process gradients and activations for transformer-based model
-            gradients = self.gradients.cpu().detach().numpy()
-            activations = self.activations.cpu().detach().numpy()
+            # Backward pass
+            output.backward(gradient=one_hot, retain_graph=True)

-
-
-
-
-
-
-
-
-
-
-
-
-
-            side_len = int(np.sqrt(seq_len))
-            # Use the mean across features as importance
-            token_importance = np.mean(np.abs(activations[0]), axis=1)
-            # Create as square-like shape as possible
-            cam = np.zeros((side_len, side_len))
-            # Fill the cam with available values
-            flat_cam = cam.flatten()
-            flat_cam[:min(len(token_importance), len(flat_cam))] = token_importance[:min(len(token_importance), len(flat_cam))]
-            cam = flat_cam.reshape(side_len, side_len)
-            else:
-                # Fallback
-                st.info("Using fallback CAM shape (14x14)")
-                cam = np.ones((14, 14), dtype=np.float32) * 0.5  # Default fallback
-
-            # Ensure we have valid values
+            # Get gradients and activations
+            gradients = self.gradients.cpu().detach().numpy()[0]
+            activations = self.activations.cpu().detach().numpy()[0]
+
+            # Calculate weights (global average pooling)
+            weights = np.mean(gradients, axis=(1, 2))
+
+            # Calculate CAM
+            cam = np.zeros(activations.shape[1:], dtype=np.float32)
+            for i, w in enumerate(weights):
+                cam += w * activations[i, :, :]
+
+            # Apply ReLU and normalize
             cam = np.maximum(cam, 0)
             if np.max(cam) > 0:
                 cam = cam / np.max(cam)
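The rewritten generate is a standard Grad-CAM: global-average-pool the gradients per channel, weight the activation maps by those pooled gradients, then ReLU and normalize. A minimal standalone sketch of how the class above would be driven, assuming a stock torchvision ResNet-50 outside of Streamlit (the "face.jpg" path is hypothetical):

import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image

# Stock ResNet-50; layer4[-1] emits a (1, 2048, 7, 7) feature map for 224x224 input
model = models.resnet50(pretrained=True)
model.eval()

# GradCAM is the class defined in the diff above
cam_extractor = GradCAM(model, model.layer4[-1])

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = Image.open("face.jpg").convert("RGB")  # hypothetical input image
input_tensor = transform(img).unsqueeze(0)   # shape (1, 3, 224, 224)

cam = cam_extractor.generate(input_tensor, class_idx=0)
print(cam.shape)                # expected: (7, 7)
print(cam.min(), cam.max())     # normalized into [0, 1] when any activation survives the ReLU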
@@ -251,7 +219,8 @@ class GradCAM:

         except Exception as e:
             st.error(f"Error in GradCAM.generate: {str(e)}")
-
+            # Return a default heatmap
+            return np.ones((7, 7), dtype=np.float32) * 0.5

 def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
     """Overlay the CAM on the image"""
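The new 7x7 fallback matches the shape a successful CAM would have: ResNet-50 downsamples a 224x224 input by a factor of 32, so layer4 emits 7x7 feature maps. A quick shape check, assuming only torch and torchvision:

import torch
import torchvision.models as models

model = models.resnet50(pretrained=True).eval()
feats = {}
# Capture the output of the target layer on one dummy forward pass
model.layer4[-1].register_forward_hook(lambda m, i, o: feats.update(out=o))

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))

print(feats["out"].shape)  # torch.Size([1, 2048, 7, 7]) -> the CAM is 7x7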
@@ -320,7 +289,7 @@ def save_comparison(image, cam, overlay, face_box=None):
     else:
         cam_resized = cv2.resize(cam, (image.width, image.height))
     axes[1].imshow(cam_resized, cmap="jet")
-    axes[1].set_title("
+    axes[1].set_title("Heatmap")
     axes[1].axis("off")

     # Overlay
@@ -337,31 +306,27 @@ def save_comparison(image, cam, overlay, face_box=None):
     buf.seek(0)
     return Image.open(buf)

-# Function to load
+# Function to load ResNet model for feature extraction
 @st.cache_resource
-def
-    with st.spinner("Loading
-
-
-        #
-        model.classification_head = nn.Linear(1024, 2)
-        model.classification_head.weight.data.normal_(mean=0.0, std=0.02)
-        model.classification_head.bias.data.zero_()
-
+def load_resnet_model():
+    with st.spinner("Loading ResNet model for visualization..."):
+        # Load a pretrained ResNet
+        model = models.resnet50(pretrained=True)
+        # Set to evaluation mode
         model.eval()
         return model

-def
-    """Get the target layer for GradCAM"""
-    return
+def get_target_layer_resnet(model):
+    """Get the target layer for GradCAM in ResNet"""
+    return model.layer4[-1]

-def process_image_with_gradcam(image, model, device,
-    """Process an image with GradCAM"""
+def process_image_with_gradcam(image, model, device, class_idx=0):
+    """Process an image with GradCAM using ResNet"""
     # Set up transformations
     transform = transforms.Compose([
         transforms.Resize((224, 224)),
         transforms.ToTensor(),
-        transforms.Normalize(mean=[0.
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
     ])

     # Create dataset for the single image
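The hard-coded mean/std are the standard ImageNet statistics that torchvision's pretrained ResNets expect; mismatched normalization would distort both the logits and the resulting heatmap. On torchvision 0.13 or newer, the same preprocessing can be taken from the weights object instead of being re-typed; a sketch, assuming that newer API is available:

import torchvision.models as models
from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V2
model = models.resnet50(weights=weights).eval()

# The bundled transform resizes, center-crops to 224, and applies the
# same ImageNet mean/std the weights were trained with
preprocess = weights.transforms()
print(preprocess)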
@@ -396,11 +361,11 @@ def process_image_with_gradcam(image, model, device, pred_class):

     try:
         # Create GradCAM extractor
-        target_layer =
+        target_layer = get_target_layer_resnet(model)
         cam_extractor = GradCAM(model, target_layer)

         # Generate CAM
-        cam = cam_extractor.generate(input_tensor,
+        cam = cam_extractor.generate(input_tensor, class_idx)

         # Create visualizations
         overlay = overlay_cam_on_image(original_image, cam, face_box)
@@ -412,11 +377,22 @@ def process_image_with_gradcam(image, model, device, pred_class):
     except Exception as e:
         st.error(f"Error processing image with GradCAM: {str(e)}")
         # Return default values
-        default_cam = np.ones((
+        default_cam = np.ones((7, 7), dtype=np.float32) * 0.5
         overlay = overlay_cam_on_image(original_image, default_cam, face_box)
         comparison = save_comparison(original_image, default_cam, overlay, face_box)
         return default_cam, overlay, comparison, face_box

+# ----- Face Analysis Functions -----
+
+def analyze_face(face_tensor, device):
+    """Simple face analysis to determine if real or fake (simplified)"""
+    # This is a placeholder function - in a real app, you would use a trained classifier
+    # We'll return random values for now
+    rand_val = np.random.random()
+    is_fake = rand_val > 0.5
+    confidence = rand_val if is_fake else 1 - rand_val
+    return "Fake" if is_fake else "Real", confidence
+
 # ----- Fine-tuned Vision LLM -----

 # Function to fix cross-attention masks
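analyze_face is explicitly a stand-in: it draws one random number, so the Real/Fake label and its confidence carry no signal and exist only to exercise the UI. A sketch of the shape a trained replacement could take, assuming a fine-tuned binary head (the DeepfakeClassifier class and its weights are hypothetical, not part of this commit):

import torch
import torch.nn as nn
import torchvision.models as models

class DeepfakeClassifier(nn.Module):
    """Hypothetical ResNet-50 backbone with a 2-class head (0 = Real, 1 = Fake)."""
    def __init__(self):
        super().__init__()
        self.backbone = models.resnet50(pretrained=True)
        # Replace the 1000-class ImageNet head with a binary one
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, 2)

    def forward(self, x):
        return self.backbone(x)

def analyze_face(face_tensor, device, classifier):
    """Classify a (1, 3, 224, 224) tensor; mirrors the placeholder's return type."""
    classifier = classifier.to(device).eval()
    with torch.no_grad():
        probs = torch.softmax(classifier(face_tensor.to(device)), dim=1)[0]
    pred_class = int(torch.argmax(probs))
    return ("Fake" if pred_class == 1 else "Real"), probs[pred_class].item()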
@@ -461,9 +437,9 @@ def load_llm_model():
 def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confidence, question, model, tokenizer, temperature=0.7, max_tokens=500, custom_instruction=""):
     # Create a prompt that includes GradCAM information
     if custom_instruction.strip():
-        full_prompt = f"{question}\n\nThe image has been
+        full_prompt = f"{question}\n\nThe image has been preliminarily classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show areas of interest for analysis.\n\n{custom_instruction}"
     else:
-        full_prompt = f"{question}\n\nThe image has been
+        full_prompt = f"{question}\n\nThe image has been preliminarily classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show areas of interest for analysis."

     # Format the message to include both the original image and the GradCAM visualization
     messages = [
@@ -513,9 +489,9 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
 # Main app
 def main():
     # Create placeholders for model state
-    if '
-        st.session_state.
-        st.session_state.
+    if 'resnet_model_loaded' not in st.session_state:
+        st.session_state.resnet_model_loaded = False
+        st.session_state.resnet_model = None

     if 'llm_model_loaded' not in st.session_state:
         st.session_state.llm_model_loaded = False
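Streamlit reruns the entire script on every widget interaction, so these not-in-st.session_state guards are what keep the loaded models alive between reruns. With several flags, a small helper can keep the defaults in one place; a sketch using the same keys as above:

import streamlit as st

def init_session_state(defaults):
    """Set each key once; existing values survive Streamlit's reruns."""
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

init_session_state({
    "resnet_model_loaded": False,
    "resnet_model": None,
    "llm_model_loaded": False,
})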
@@ -524,22 +500,22 @@ def main():

     # Create expanders for each stage
     with st.expander("Stage 1: Model Loading", expanded=True):
-        # Button for loading
-
+        # Button for loading models
+        resnet_col, llm_col = st.columns(2)

-        with
-            if not st.session_state.
-                if st.button("📥 Load
-                    # Load
-                    model =
+        with resnet_col:
+            if not st.session_state.resnet_model_loaded:
+                if st.button("📥 Load ResNet for Visualization", type="primary"):
+                    # Load ResNet model
+                    model = load_resnet_model()
                     if model is not None:
-                        st.session_state.
-                        st.session_state.
-                        st.success("✅
+                        st.session_state.resnet_model = model
+                        st.session_state.resnet_model_loaded = True
+                        st.success("✅ ResNet model loaded successfully!")
                     else:
-                        st.error("❌ Failed to load
+                        st.error("❌ Failed to load ResNet model.")
             else:
-                st.success("✅
+                st.success("✅ ResNet model loaded and ready!")

         with llm_col:
             if not st.session_state.llm_model_loaded:
@@ -557,7 +533,7 @@ def main():
                 st.success("✅ Vision LLM loaded and ready!")

     # Image upload section
-    with st.expander("Stage 2: Image Upload & Initial
+    with st.expander("Stage 2: Image Upload & Initial Analysis", expanded=True):
         st.subheader("Upload an Image")
         uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

@@ -566,14 +542,14 @@ def main():
             image = Image.open(uploaded_file).convert("RGB")
             st.image(image, caption="Uploaded Image", use_column_width=True)

-            #
-            if st.session_state.
-                with st.spinner("Analyzing image
-                    # Preprocess image
+            # Analyze with ResNet model if loaded
+            if st.session_state.resnet_model_loaded:
+                with st.spinner("Analyzing image..."):
+                    # Preprocess image
                     transform = transforms.Compose([
                         transforms.Resize((224, 224)),
                         transforms.ToTensor(),
-                        transforms.Normalize(mean=[0.
+                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                     ])

                     # Create a simple dataset for the image
@@ -584,34 +560,27 @@ def main():
                     # Get device
                     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-                    #
-
-                    tensor = tensor.to(device)
-
-                    # Forward pass
-                    with torch.no_grad():
-                        outputs = model.vision_model(pixel_values=tensor).pooler_output
-                        logits = model.classification_head(outputs)
-                        probs = torch.softmax(logits, dim=1)[0]
-                        pred_class = torch.argmax(probs).item()
-                        confidence = probs[pred_class].item()
-                        pred_label = "Fake" if pred_class == 1 else "Real"
+                    # Analyze face (simplified)
+                    pred_label, confidence = analyze_face(tensor, device)

                     # Display results
                     result_col1, result_col2 = st.columns(2)
                     with result_col1:
-                        st.metric("
+                        st.metric("Initial Classification", pred_label)
                     with result_col2:
                         st.metric("Confidence", f"{confidence:.2%}")

                     # GradCAM visualization
-                    st.subheader("
+                    st.subheader("Feature Visualization")
+
+                    # Use the first class for visualization
+                    class_idx = 0
                     cam, overlay, comparison, detected_face_box = process_image_with_gradcam(
-                        image,
+                        image, st.session_state.resnet_model, device, class_idx
                     )

                     # Display GradCAM results
-                    st.image(comparison, caption="Original |
+                    st.image(comparison, caption="Original | Heatmap | Overlay", use_column_width=True)

                     # Save results in session state for LLM analysis
                     st.session_state.current_image = image
@@ -620,9 +589,9 @@ def main():
                     st.session_state.current_pred_label = pred_label
                     st.session_state.current_confidence = confidence

-                    st.success("✅ Initial
+                    st.success("✅ Initial analysis and visualization complete!")
             else:
-                st.warning("⚠️ Please load the
+                st.warning("⚠️ Please load the ResNet model first to perform initial analysis.")

     # LLM Analysis section
     with st.expander("Stage 3: Detailed Analysis with Vision LLM", expanded=False):
@@ -630,7 +599,7 @@ def main():
         st.subheader("Detailed Deepfake Analysis")

         # Default question with option to customize
-        default_question = f"This image has been classified as {st.session_state.current_pred_label}. Analyze
+        default_question = f"This image has been preliminarily classified as {st.session_state.current_pred_label}. Analyze whether this image is a deepfake, focusing on the highlighted areas in the visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
         question = st.text_area("Question/Prompt:", value=default_question, height=100)

         # Analyze button
@@ -673,7 +642,7 @@ def main():
         st.subheader("Analysis Result")
         st.markdown(result)
     elif not hasattr(st.session_state, 'current_image'):
-        st.warning("⚠️ Please upload an image and complete the initial
+        st.warning("⚠️ Please upload an image and complete the initial analysis first.")
     else:
         st.warning("⚠️ Please load the Vision LLM to perform detailed analysis.")
