Update app.py

app.py  CHANGED
@@ -533,7 +533,7 @@ def load_llm_model():
     )
 
     # Load the adapter
-    adapter_id = "saakshigupta/deepfake-explainer-
+    adapter_id = "saakshigupta/deepfake-explainer-2"
     model = PeftModel.from_pretrained(model, adapter_id)
 
     # Set to inference mode
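For context on the hunk above: PeftModel.from_pretrained() wraps an already-loaded base model with the LoRA adapter weights pulled from the Hub, so switching explainer versions only means changing the adapter id. A minimal sketch of that pattern, where the base-model loader and id are placeholders rather than the ones app.py actually uses:

    from transformers import AutoModelForVision2Seq  # placeholder loader; app.py may load its base model differently
    from peft import PeftModel

    base_model = AutoModelForVision2Seq.from_pretrained("base-model-id")  # placeholder base model id
    model = PeftModel.from_pretrained(base_model, "saakshigupta/deepfake-explainer-2")
    model.eval()  # inference mode, matching the "Set to inference mode" step that follows in app.py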
@@ -552,50 +552,77 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
     else:
         full_prompt = f"{question}\n\nThe image has been processed with GradCAM and classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show the areas the detection model found suspicious."
 
-
-
-    {"
-        {"type": "image", "image": image}, # Original image
-        {"type": "image", "image": gradcam_overlay}, # GradCAM overlay
-        {"type": "text", "text": full_prompt}
-    ]}
-    ]
-
-    # Apply chat template
-    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-
-    # Process with image
-    inputs = tokenizer(
-        [image, gradcam_overlay], # Send both images
-        input_text,
-        add_special_tokens=False,
-        return_tensors="pt",
-    ).to(model.device)
-
-    # Fix cross-attention mask if needed
-    inputs = fix_cross_attention_mask(inputs)
-
-    # Generate response
-    with st.spinner("Generating detailed analysis... (this may take 15-30 seconds)"):
-        with torch.no_grad():
-            output_ids = model.generate(
-                **inputs,
-                max_new_tokens=max_tokens,
-                use_cache=True,
-                temperature=temperature,
-                top_p=0.9
-            )
-
-    #
-
-
-    #
-
-
-
-
-
-
+    try:
+        # Format the message to include all available images
+        message_content = [{"type": "text", "text": full_prompt}]
+
+        # Add original image
+        message_content.insert(0, {"type": "image", "image": image})
+
+        # Add GradCAM overlay
+        message_content.insert(1, {"type": "image", "image": gradcam_overlay})
+
+        # Add comparison image if available
+        if hasattr(st.session_state, 'comparison_image'):
+            message_content.insert(2, {"type": "image", "image": st.session_state.comparison_image})
+
+        messages = [{"role": "user", "content": message_content}]
+
+        # Apply chat template
+        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+        # Create list of images to process
+        image_list = [image, gradcam_overlay]
+        if hasattr(st.session_state, 'comparison_image'):
+            image_list.append(st.session_state.comparison_image)
+
+        try:
+            # Try with multiple images first
+            inputs = tokenizer(
+                image_list,
+                input_text,
+                add_special_tokens=False,
+                return_tensors="pt",
+            ).to(model.device)
+        except Exception as e:
+            st.warning(f"Multiple image analysis encountered an issue: {str(e)}")
+            st.info("Falling back to single image analysis")
+            # Fallback to single image
+            inputs = tokenizer(
+                image,
+                input_text,
+                add_special_tokens=False,
+                return_tensors="pt",
+            ).to(model.device)
+
+        # Fix cross-attention mask if needed
+        inputs = fix_cross_attention_mask(inputs)
+
+        # Generate response
+        with st.spinner("Generating detailed analysis... (this may take 15-30 seconds)"):
+            with torch.no_grad():
+                output_ids = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    use_cache=True,
+                    temperature=temperature,
+                    top_p=0.9
+                )
+
+        # Decode the output
+        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+        # Try to extract just the model's response (after the prompt)
+        if full_prompt in response:
+            result = response.split(full_prompt)[-1].strip()
+        else:
+            result = response
+
+        return result
+
+    except Exception as e:
+        st.error(f"Error during LLM analysis: {str(e)}")
+        return f"Error analyzing image: {str(e)}"
 
 # Main app
 def main():
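One note on the decoding step added above: it strips the prompt by splitting the decoded string on full_prompt, which only works when the prompt text survives decoding verbatim. An alternative sketch (not what app.py does) slices the prompt tokens off by length before decoding:

    # Assumes inputs and output_ids as produced in the hunk above;
    # output_ids[0] holds the prompt tokens followed by the newly generated tokens.
    prompt_length = inputs["input_ids"].shape[1]
    generated_only = output_ids[0][prompt_length:]
    result = tokenizer.decode(generated_only, skip_special_tokens=True)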
@@ -818,7 +845,7 @@ def main():
         caption_text += f"\n\nGradCAM Analysis:\n{st.session_state.gradcam_caption}"
 
         # Default question with option to customize
-        default_question = f"This image has been classified as {
+        default_question = f"This image has been classified as {pred_label}. Analyze all the provided images (original, GradCAM visualization, and comparison) to determine if this is a deepfake. Focus on highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
 
         # User input for new question
         new_question = st.text_area("Ask a question about the image:", value=default_question if not st.session_state.chat_history else "", height=100)
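A note on the default_question f-string above: inside an f-string, doubled braces are escapes that produce literal braces, so only the single-brace form actually substitutes the classifier's label into the question:

    pred_label = "Fake"
    f"classified as {pred_label}"    # -> 'classified as Fake'
    f"classified as {{pred_label}}"  # -> 'classified as {pred_label}' (literal braces, no substitution)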
@@ -902,5 +929,8 @@ def main():
     # Footer
     st.markdown("---")
 
+    # Add model version indicator in sidebar
+    st.sidebar.info("Using deepfake-explainer-2 model")
+
 if __name__ == "__main__":
     main()
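The adapter id in load_llm_model() and the sidebar label added here both hard-code the model version string. A small refactor sketch (constant and helper names are hypothetical, not taken from app.py) that keeps the two in sync:

    import streamlit as st
    from peft import PeftModel

    ADAPTER_ID = "saakshigupta/deepfake-explainer-2"  # hypothetical single source of truth

    def attach_adapter(base_model):
        # Attach the fine-tuned explainer adapter to an already-loaded base model
        return PeftModel.from_pretrained(base_model, ADAPTER_ID)

    def show_model_badge():
        # Mirror the loaded adapter version in the sidebar so the UI label always matches it
        st.sidebar.info(f"Using {ADAPTER_ID.split('/')[-1]} model")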