saakshigupta committed · Commit cf310a7 · verified · 1 Parent(s): af6361c

Update app.py

Files changed (1):
  1. app.py (+75 -45)
app.py CHANGED
@@ -533,7 +533,7 @@ def load_llm_model():
     )
 
     # Load the adapter
-    adapter_id = "saakshigupta/deepfake-explainer-1"
+    adapter_id = "saakshigupta/deepfake-explainer-2"
     model = PeftModel.from_pretrained(model, adapter_id)
 
     # Set to inference mode
@@ -552,50 +552,77 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
     else:
         full_prompt = f"{question}\n\nThe image has been processed with GradCAM and classified as {pred_label} with confidence {confidence:.2f}. Focus on the highlighted regions in red/yellow which show the areas the detection model found suspicious."
 
-    # Format the message to include both the original image and the GradCAM visualization
-    messages = [
-        {"role": "user", "content": [
-            {"type": "image", "image": image},  # Original image
-            {"type": "image", "image": gradcam_overlay},  # GradCAM overlay
-            {"type": "text", "text": full_prompt}
-        ]}
-    ]
-
-    # Apply chat template
-    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-
-    # Process with image
-    inputs = tokenizer(
-        [image, gradcam_overlay],  # Send both images
-        input_text,
-        add_special_tokens=False,
-        return_tensors="pt",
-    ).to(model.device)
-
-    # Fix cross-attention mask if needed
-    inputs = fix_cross_attention_mask(inputs)
-
-    # Generate response
-    with st.spinner("Generating detailed analysis... (this may take 15-30 seconds)"):
-        with torch.no_grad():
-            output_ids = model.generate(
-                **inputs,
-                max_new_tokens=max_tokens,
-                use_cache=True,
-                temperature=temperature,
-                top_p=0.9
-            )
-
-    # Decode the output
-    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
-    # Try to extract just the model's response (after the prompt)
-    if full_prompt in response:
-        result = response.split(full_prompt)[-1].strip()
-    else:
-        result = response
-
-    return result
+    try:
+        # Format the message to include all available images
+        message_content = [{"type": "text", "text": full_prompt}]
+
+        # Add original image
+        message_content.insert(0, {"type": "image", "image": image})
+
+        # Add GradCAM overlay
+        message_content.insert(1, {"type": "image", "image": gradcam_overlay})
+
+        # Add comparison image if available
+        if hasattr(st.session_state, 'comparison_image'):
+            message_content.insert(2, {"type": "image", "image": st.session_state.comparison_image})
+
+        messages = [{"role": "user", "content": message_content}]
+
+        # Apply chat template
+        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+        # Create list of images to process
+        image_list = [image, gradcam_overlay]
+        if hasattr(st.session_state, 'comparison_image'):
+            image_list.append(st.session_state.comparison_image)
+
+        try:
+            # Try with multiple images first
+            inputs = tokenizer(
+                image_list,
+                input_text,
+                add_special_tokens=False,
+                return_tensors="pt",
+            ).to(model.device)
+        except Exception as e:
+            st.warning(f"Multiple image analysis encountered an issue: {str(e)}")
+            st.info("Falling back to single image analysis")
+            # Fallback to single image
+            inputs = tokenizer(
+                image,
+                input_text,
+                add_special_tokens=False,
+                return_tensors="pt",
+            ).to(model.device)
+
+        # Fix cross-attention mask if needed
+        inputs = fix_cross_attention_mask(inputs)
+
+        # Generate response
+        with st.spinner("Generating detailed analysis... (this may take 15-30 seconds)"):
+            with torch.no_grad():
+                output_ids = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    use_cache=True,
+                    temperature=temperature,
+                    top_p=0.9
+                )
+
+        # Decode the output
+        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+        # Try to extract just the model's response (after the prompt)
+        if full_prompt in response:
+            result = response.split(full_prompt)[-1].strip()
+        else:
+            result = response
+
+        return result
+
+    except Exception as e:
+        st.error(f"Error during LLM analysis: {str(e)}")
+        return f"Error analyzing image: {str(e)}"
 
 # Main app
 def main():
@@ -818,7 +845,7 @@ def main():
         caption_text += f"\n\nGradCAM Analysis:\n{st.session_state.gradcam_caption}"
 
     # Default question with option to customize
-    default_question = f"This image has been classified as {st.session_state.current_pred_label}. Analyze the key features that led to this classification, focusing on the highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
+    default_question = f"This image has been classified as {{pred_label}}. Analyze all the provided images (original, GradCAM visualization, and comparison) to determine if this is a deepfake. Focus on highlighted areas in the GradCAM visualization. Provide both a technical explanation for experts and a simple explanation for non-technical users."
 
     # User input for new question
     new_question = st.text_area("Ask a question about the image:", value=default_question if not st.session_state.chat_history else "", height=100)
@@ -902,5 +929,8 @@ def main():
     # Footer
     st.markdown("---")
 
+    # Add model version indicator in sidebar
+    st.sidebar.info("Using deepfake-explainer-2 model")
+
 if __name__ == "__main__":
     main()
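
For readers skimming the first hunk: load_llm_model() is only partially visible here, so the sketch below shows the general base-model-plus-adapter loading pattern that hunk relies on. The base checkpoint, dtype, and processor class are assumptions for illustration; only the adapter id saakshigupta/deepfake-explainer-2 comes from this commit.

# Minimal sketch of loading a vision-language base model and attaching the PEFT adapter.
# Assumptions (not from this diff): the base checkpoint name, fp16 weights, AutoProcessor.
import torch
from transformers import AutoProcessor, MllamaForConditionalGeneration
from peft import PeftModel

base_id = "unsloth/Llama-3.2-11B-Vision-Instruct"   # assumed base model
adapter_id = "saakshigupta/deepfake-explainer-2"    # adapter introduced in this commit

processor = AutoProcessor.from_pretrained(base_id)
model = MllamaForConditionalGeneration.from_pretrained(
    base_id,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, adapter_id)  # load the fine-tuned adapter on top
model.eval()                                          # inference mode, as in load_llm_model()

The app itself calls its tokenizer directly with images, so treat this only as the shape of the loading step, not the app's exact code.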
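
The largest hunk rebuilds the chat message so that up to three images precede the text prompt and adds a single-image fallback. A small standalone sketch of just the ordering logic, with placeholder values rather than objects from app.py:

# Standalone sketch: the text part is created first and images are insert()-ed ahead of it,
# so the final order is [original, GradCAM overlay, optional comparison, text prompt].
# All values below are placeholders, not objects from app.py.
full_prompt = "Is this image a deepfake?"
image, gradcam_overlay, comparison_image = "original.png", "gradcam.png", "comparison.png"

message_content = [{"type": "text", "text": full_prompt}]
message_content.insert(0, {"type": "image", "image": image})
message_content.insert(1, {"type": "image", "image": gradcam_overlay})
if comparison_image is not None:  # stands in for the hasattr() check on st.session_state
    message_content.insert(2, {"type": "image", "image": comparison_image})

messages = [{"role": "user", "content": message_content}]
print([part["type"] for part in messages[0]["content"]])
# -> ['image', 'image', 'image', 'text']

If the multi-image tokenizer call fails, the new code catches the exception, surfaces a warning in the UI, and retries with only the original image, so the function degrades gracefully instead of erroring out.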
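
One detail worth noting in the updated default_question: inside an f-string, doubled braces are escape sequences rather than substitutions, so the prompt text will contain the literal characters {pred_label} instead of the current prediction. A quick check:

# f-string brace escaping, shown with a placeholder value
pred_label = "Fake"
print(f"This image has been classified as {{pred_label}}.")
# -> This image has been classified as {pred_label}.
print(f"This image has been classified as {pred_label}.")
# -> This image has been classified as Fake.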