saakshigupta's picture
Update app.py
99ef832 verified
raw
history blame
8 kB
import gc
import os

import streamlit as st
import torch
from peft import PeftModel
from PIL import Image
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    MllamaForCausalLM,
    MllamaForConditionalGeneration,
)
# Page config: browser-tab title and icon, and a wide page layout.
# The icon literal was mis-encoded (UTF-8 bytes decoded as a legacy
# code page); it is the magnifying-glass emoji.
st.set_page_config(
    page_title="Deepfake Image Analyzer",
    page_icon="🔍",
    layout="wide",
)
# App title and one-line description shown at the top of the page
st.title("Deepfake Image Analyzer")
st.markdown("Upload an image to analyze it for possible deepfake manipulation")
# Release memory between analyses
def free_memory():
    """Run Python's garbage collector, then drop cached CUDA memory if a GPU exists.

    Returns:
        None. Called purely for its side effects.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # CPU-only: nothing further to release
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
# Helper functions
def init_device():
    """Show a sidebar notice about GPU availability and return the device name.

    The status strings were mis-encoded (UTF-8 emoji decoded as a legacy
    code page); they are restored to the check mark and warning sign.

    Returns:
        "cuda" if ``torch.cuda.is_available()`` is true, otherwise "cpu".
    """
    if torch.cuda.is_available():
        st.sidebar.success("✓ GPU available: Using CUDA")
        return "cuda"
    st.sidebar.warning("⚠️ No GPU detected: Using CPU (analysis will be slow)")
    return "cpu"


# Set device (also renders the sidebar status message).
# NOTE(review): `device` is not read anywhere else in the visible code;
# model placement is handled by device_map="auto" when the model is loaded.
device = init_device()
@st.cache_resource
def _load_model_cached(base_model_id, adapter_id):
    """Load the processor and the adapter-wrapped vision model once per process.

    Errors are deliberately NOT caught here. Streamlit's resource cache only
    stores successful return values, so a failed load is retried on the next
    call instead of a ``(None, None)`` failure being cached forever.

    Args:
        base_model_id: Hub id of the (pre-quantized) Llama 3.2 Vision checkpoint.
        adapter_id: Hub id of the PEFT adapter applied on top of it.

    Returns:
        ``(model, processor)``.
    """
    processor = AutoProcessor.from_pretrained(base_model_id)
    # 4-bit NF4 quantization with double quantization; matmuls run in fp16
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    # MllamaForConditionalGeneration includes the vision encoder and accepts
    # pixel_values; MllamaForCausalLM is the text-only decoder and cannot
    # consume the image tensors the processor produces.
    model = MllamaForConditionalGeneration.from_pretrained(
        base_model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=quantization_config,
    )
    # Wrap the base model with the fine-tuned deepfake-explainer adapter
    model = PeftModel.from_pretrained(model, adapter_id)
    return model, processor


def load_model(
    base_model_id="unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit",
    adapter_id="saakshigupta/deepfake-explainer-1",
):
    """Return ``(model, processor)``, or ``(None, None)`` after showing an error.

    The heavy lifting is cached by ``_load_model_cached``. Only successful
    loads are cached, so a later call (e.g. after a page refresh) retries a
    load that previously failed.

    Args:
        base_model_id: Hub id of the base vision checkpoint.
        adapter_id: Hub id of the PEFT adapter.
    """
    try:
        return _load_model_cached(base_model_id, adapter_id)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None
# Function to fix cross-attention masks
def fix_processor_outputs(inputs, visual_features=6404):
    """Replace a zero-sized ``cross_attention_mask`` with an all-ones mask.

    If ``inputs`` holds a ``cross_attention_mask`` with any zero-length
    dimension, it is replaced in place by a mask of ones. The new mask has
    shape ``(batch_size, seq_len, visual_features, num_tiles)`` and keeps the
    original mask's dtype and device. Masks without a zero-length dimension,
    and inputs without a mask, are left untouched.

    NOTE(review): a zero-sized mask presumably means the prompt carried no
    image token. Building the prompt with the processor's chat template
    should make this path unnecessary; confirm before relying on it.

    Args:
        inputs: Mutable mapping of processor outputs (e.g. a ``BatchFeature``).
            Modified in place when the mask is replaced.
        visual_features: Size of the replacement mask's third dimension.
            Defaults to 6404, which the original code describes as the exact
            dimension used in training.

    Returns:
        ``(fixed, inputs)``, where ``fixed`` is True iff the mask was replaced.
    """
    if 'cross_attention_mask' not in inputs:
        return False, inputs
    mask = inputs['cross_attention_mask']
    if 0 not in mask.shape:
        return False, inputs
    # Assumes a 4-D mask; the unpack below raises ValueError otherwise.
    batch_size, seq_len, _, num_tiles = mask.shape
    inputs['cross_attention_mask'] = torch.ones(
        (batch_size, seq_len, visual_features, num_tiles),
        dtype=mask.dtype,  # keep the processor's dtype instead of float32
        device=mask.device,
    )
    return True, inputs
# Sidebar: generation settings read by the analysis step, plus an "About" panel.
# Every call targets st.sidebar directly rather than using a `with` context.
st.sidebar.header("Options")
temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.1,
    max_value=1.0,
    value=0.7,
    step=0.1,
    help="Higher values make output more random, lower values more deterministic",
)
max_length = st.sidebar.slider(
    "Maximum response length",
    min_value=100,
    max_value=1000,
    value=500,
    step=50,
)
custom_prompt = st.sidebar.text_area(
    "Custom instruction (optional)",
    value="Analyze this image and determine if it's a deepfake. Provide both technical and non-technical explanations.",
    height=100,
)
st.sidebar.markdown("### About")
st.sidebar.markdown("""
This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.
The analyzer looks for:
- Inconsistencies in facial features
- Unusual lighting or shadows
- Unnatural blur patterns
- Artifacts around edges
- Texture inconsistencies
Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)
""")
# Load the model once per browser session, with a progress bar and a
# status message that are both cleared when loading finishes.
if 'model_loaded' not in st.session_state:
    progress_bar = st.progress(0)
    # Placeholder so the "loading" notice can be removed afterwards;
    # previously it stayed on screen after the model had loaded.
    status = st.empty()
    status.info("Loading model... this may take a minute.")
    # load_model() is a single blocking call, so the bar can only mark
    # "started" and "finished"; the old 10-step loop advanced instantly
    # around it and conveyed no real progress.
    progress_bar.progress(30)
    model, processor = load_model()
    progress_bar.progress(100)
    if model is not None:
        st.session_state['model'] = model
        st.session_state['processor'] = processor
        st.session_state['model_loaded'] = True
    progress_bar.empty()
    status.empty()

# Report load status on every rerun
if st.session_state.get('model_loaded'):
    st.success("Model loaded successfully!")
else:
    st.error("Failed to load model. Try refreshing the page.")
def _build_prompt(processor, instruction):
    """Wrap *instruction* in the model's chat template with one image slot.

    The ``{"type": "image"}`` entry makes the template emit the image
    placeholder token. The processor needs that token to pair the text with
    the uploaded image; without it the cross-attention mask is malformed.

    Returns:
        The templated prompt string, ending with the assistant turn header.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": instruction},
            ],
        }
    ]
    return processor.apply_chat_template(messages, add_generation_prompt=True)


def _split_explanations(text):
    """Split model output into ``(technical, non_technical)`` sections.

    Returns None when the "Non-Technical Explanation:" marker is absent.
    Only the first marker is split on, so a repeated marker can no longer
    raise the ValueError the old two-way unpack produced. Note that
    "Technical Explanation:" is a substring of the non-technical marker, so
    the old check for both markers reduced to checking for this one.
    """
    technical, marker, non_technical = text.partition("Non-Technical Explanation:")
    if not marker:
        return None
    return technical.replace("Technical Explanation:", "").strip(), non_technical.strip()


# Main content area - file uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
# Check if model is loaded
model_loaded = 'model_loaded' in st.session_state and st.session_state['model_loaded']

if uploaded_file is not None and model_loaded:
    # Display the image
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Analyze Image"):
        with st.spinner("Analyzing the image... This may take 15-30 seconds"):
            try:
                model = st.session_state['model']
                processor = st.session_state['processor']

                # The chat template already contains the BOS token, so the
                # processor must not add special tokens a second time.
                prompt_text = _build_prompt(processor, custom_prompt)
                inputs = processor(
                    images=image,
                    text=prompt_text,
                    add_special_tokens=False,
                    return_tensors="pt",
                )

                # Legacy safety net; with the image token present it should be a no-op
                fixed, inputs = fix_processor_outputs(inputs)
                if fixed:
                    st.info("Fixed cross-attention mask dimensions")

                inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
                prompt_len = inputs["input_ids"].shape[-1]

                with torch.no_grad():
                    output_ids = model.generate(
                        **inputs,
                        max_new_tokens=max_length,
                        # Without do_sample=True, generate() decodes greedily and
                        # ignores temperature/top_p, making the slider a no-op.
                        do_sample=True,
                        temperature=temperature,
                        top_p=0.9,
                    )

                # generate() returns prompt + continuation; decode only the new
                # tokens rather than string-matching the prompt in the text.
                result = processor.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()

                st.success("Analysis complete!")

                sections = _split_explanations(result)
                if sections is not None:
                    technical, non_technical = sections
                    col1, col2 = st.columns(2)
                    with col1:
                        st.subheader("Technical Analysis")
                        st.write(technical)
                    with col2:
                        st.subheader("Simple Explanation")
                        st.write(non_technical)
                else:
                    st.subheader("Analysis Result")
                    st.write(result)
            except Exception as e:
                st.error(f"Error analyzing image: {str(e)}")
            finally:
                # Release memory whether or not the analysis succeeded
                free_memory()
elif uploaded_file is None:
    # If an image is uploaded but the model failed to load, the load-failure
    # error above already explains why nothing happens.
    st.info("Please upload an image to begin analysis")
# Footer: a horizontal rule followed by the app name
for _footer_line in ("---", "Deepfake Image Analyzer"):
    st.markdown(_footer_line)