import streamlit as st
import torch
from PIL import Image
import os
import gc
from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Page config
st.set_page_config(
    page_title="Deepfake Image Analyzer",
    page_icon="🔍",
    layout="wide"
)

# App title and description
st.title("Deepfake Image Analyzer")
st.markdown("Upload an image to analyze it for possible deepfake manipulation")

# Function to free up memory
def free_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Helper function to check CUDA
def init_device():
    if torch.cuda.is_available():
        st.sidebar.success("✓ GPU available: Using CUDA")
        return "cuda"
    else:
        st.sidebar.warning("⚠️ No GPU detected: Using CPU (analysis will be slow)")
        return "cpu"

# Set device
device = init_device()
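
# Note: the `device` value is informational only; actual tensor placement is
# handled by device_map="auto" at load time and by model.device at inference.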

def load_model():
    """Load the pre-quantized base model and attach the fine-tuned adapter."""
    try:
        # Pre-quantized 4-bit base model
        base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"

        # Load processor
        processor = AutoProcessor.from_pretrained(base_model_id)

        # Configure quantization settings for the unsloth model
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_quant_storage=torch.float16,
            llm_int8_skip_modules=["lm_head"],
            llm_int8_enable_fp32_cpu_offload=True
        )

        # Load the pre-quantized model with unsloth settings
        model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            device_map="auto",
            quantization_config=quantization_config,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            use_cache=True,
            offload_folder="offload"  # Enable disk offloading
        )

        # Load the fine-tuned deepfake-explainer adapter
        adapter_id = "saakshigupta/deepfake-explainer-1"
        model = PeftModel.from_pretrained(model, adapter_id)

        return model, processor
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.exception(e)
        return None, None
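
# Note: Streamlit re-runs this script on every interaction, so the app keeps the
# loaded model in st.session_state (see the "Load Model" button below). Wrapping
# load_model with @st.cache_resource would be an alternative way to persist it
# across reruns.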

# Function to fix cross-attention masks
def fix_processor_outputs(inputs):
    """Fix cross-attention mask dimensions if needed."""
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404  # The exact dimension used in training
        new_mask = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device
        )
        inputs['cross_attention_mask'] = new_mask
        return True, inputs
    return False, inputs
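
# A minimal sanity check for the fix above (illustrative only; the shape assumes
# the processor's (batch, seq_len, num_images, num_tiles) mask layout):
#     dummy = {'cross_attention_mask': torch.zeros(1, 10, 0, 4)}
#     fixed, dummy = fix_processor_outputs(dummy)
#     assert fixed and dummy['cross_attention_mask'].shape == (1, 10, 6404, 4)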

# Create sidebar with options
with st.sidebar:
    st.header("Options")
    temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
                            help="Higher values make output more random, lower values more deterministic")
    max_length = st.slider("Maximum response length", min_value=100, max_value=1000, value=500, step=50)
    custom_prompt = st.text_area(
        "Custom instruction (optional)",
        value="Analyze this image and determine if it's a deepfake. Provide both technical and non-technical explanations.",
        height=100
    )
    st.markdown("### About")
    st.markdown("""
    This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.

    The analyzer looks for:
    - Inconsistencies in facial features
    - Unusual lighting or shadows
    - Unnatural blur patterns
    - Artifacts around edges
    - Texture inconsistencies

    Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)
    """)

# Load model button
if st.button("Load Model"):
    with st.spinner("Loading model... this may take several minutes"):
        try:
            model, processor = load_model()
            if model is not None and processor is not None:
                st.session_state['model'] = model
                st.session_state['processor'] = processor
                st.success("Model loaded successfully!")
            else:
                st.error("Failed to load model.")
        except Exception as e:
            st.error(f"Error during model loading: {str(e)}")
            st.exception(e)

# Main content area - file uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# Check if model is loaded
model_loaded = 'model' in st.session_state and st.session_state['model'] is not None

if uploaded_file is not None:
    # Display the image
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Analyze button (only enabled if model is loaded)
    if st.button("Analyze Image", disabled=not model_loaded):
        if not model_loaded:
            st.warning("Please load the model first by clicking the 'Load Model' button.")
        else:
            with st.spinner("Analyzing the image... This may take 15-30 seconds"):
                try:
                    # Get components from session state
                    model = st.session_state['model']
                    processor = st.session_state['processor']

                    # Process the image using the processor
                    inputs = processor(text=custom_prompt, images=image, return_tensors="pt")

                    # Fix cross-attention mask if needed
                    fixed, inputs = fix_processor_outputs(inputs)
                    if fixed:
                        st.info("Fixed cross-attention mask dimensions")

                    # Move tensors to the model's device
                    inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

                    # Generate the analysis (do_sample=True so the temperature
                    # and top_p settings actually take effect)
                    with torch.no_grad():
                        output_ids = model.generate(
                            **inputs,
                            max_new_tokens=max_length,
                            do_sample=True,
                            temperature=temperature,
                            top_p=0.9
                        )

                    # Decode the output
                    response = processor.decode(output_ids[0], skip_special_tokens=True)

                    # Extract the actual response (removing the prompt)
                    if custom_prompt in response:
                        result = response.split(custom_prompt)[-1].strip()
                    else:
                        result = response

                    # Display result in a nice format
                    st.success("Analysis complete!")

                    # Show technical and non-technical explanations separately if they exist
                    if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
                        technical, non_technical = result.split("Non-Technical Explanation:")
                        technical = technical.replace("Technical Explanation:", "").strip()
                        col1, col2 = st.columns(2)
                        with col1:
                            st.subheader("Technical Analysis")
                            st.write(technical)
                        with col2:
                            st.subheader("Simple Explanation")
                            st.write(non_technical)
                    else:
                        st.subheader("Analysis Result")
                        st.write(result)

                    # Free memory after analysis
                    free_memory()
                except Exception as e:
                    st.error(f"Error analyzing image: {str(e)}")
                    st.exception(e)
elif not model_loaded:
    st.warning("Please load the model first by clicking the 'Load Model' button at the top of the page.")
else:
    st.info("Please upload an image to begin analysis")

# Add footer
st.markdown("---")
st.markdown("Deepfake Image Analyzer")