# app.py
import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from peft import PeftModel

# Page config
st.set_page_config(
    page_title="Deepfake Explainer",
    page_icon="🔍",
    layout="wide"
)

# App title and description
st.title("Deepfake Image Analyzer")
st.markdown("Upload an image to analyze it for possible deepfake manipulation")

@st.cache_resource
def load_model():
    """Load model and processor (cached to avoid reloading)"""
    # Load base model
    base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
    processor = AutoProcessor.from_pretrained(base_model_id)
    # Llama 3.2 Vision is a multimodal (Mllama) checkpoint; load it with the
    # vision-capable class rather than a text-only causal-LM head
    model = MllamaForConditionalGeneration.from_pretrained(
        base_model_id,
        device_map="auto",
        torch_dtype=torch.float16
    )
    
    # Load adapter
    adapter_id = "saakshigupta/deepfake-explainer-1"
    model = PeftModel.from_pretrained(model, adapter_id)
    model.eval()  # inference only
    
    return model, processor
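
# Optional: a PEFT adapter adds a small overhead on every forward pass. If memory
# allows, it can be folded into the base weights after loading via PEFT's standard
# merge method (untested with this particular checkpoint, so treat as a sketch):
#
#     model = model.merge_and_unload()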

# Function to fix cross-attention masks
def fix_processor_outputs(inputs):
    """Work around a processor edge case: the cross-attention mask can come back
    with a zero-sized visual dimension, which breaks generation. Replace it with
    an all-ones mask sized to the visual-feature dimension used in training."""
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404  # the exact dimension we fixed in training
        new_mask = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device
        )
        inputs['cross_attention_mask'] = new_mask
        st.write("✅ Fixed cross-attention mask dimensions")
    return inputs
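
# Illustrative (hypothetical) shapes for the fix above: a mask arriving as
# (1, 32, 0, 4) from the processor would be replaced by an all-ones mask of
# shape (1, 32, 6404, 4), matching what the adapter saw during fine-tuning.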

# Load model on first run
with st.spinner("Loading model... this may take a minute."):
    model, processor = load_model()
    st.success("Model loaded successfully!")

# Create sidebar with options
with st.sidebar:
    st.header("Options")
    temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
                           help="Higher values make the output more random; lower values make it more deterministic")
    max_length = st.slider("Maximum response length", min_value=100, max_value=1000, value=500, step=50)
    
    custom_prompt = st.text_area(
        "Custom instruction (optional)",
        value="Analyze this image and determine if it's a deepfake. Provide both technical and non-technical explanations.",
        height=100
    )
    
    st.markdown("### About")
    st.markdown("This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.")
    st.markdown("Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)")

# Main content area - file uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Display the image
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption="Uploaded Image", use_container_width=True)
    
    # Analyze button
    if st.button("Analyze Image"):
        with st.spinner("Analyzing the image..."):
            # Build the prompt with the processor's chat template so the
            # required <|image|> token precedes the instruction text
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": custom_prompt},
                ],
            }]
            prompt_text = processor.apply_chat_template(messages, add_generation_prompt=True)
            inputs = processor(text=prompt_text, images=image, return_tensors="pt")
            
            # Fix cross-attention mask
            inputs = fix_processor_outputs(inputs)
            
            # Move to device
            inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
            
            # Generate the analysis
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    do_sample=True,  # required for temperature/top_p to take effect
                    temperature=temperature,
                    top_p=0.9
                )
            
            # Decode the output
            response = processor.decode(output_ids[0], skip_special_tokens=True)
            
            # Extract the actual response (removing the prompt)
            if custom_prompt in response:
                result = response.split(custom_prompt)[-1].strip()
            else:
                result = response
        
        # Display result in a nice format
        st.success("Analysis complete!")
        
        # Show technical and non-technical explanations separately if they exist
        if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
            technical, non_technical = result.split("Non-Technical Explanation:")
            technical = technical.replace("Technical Explanation:", "").strip()
            
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("Technical Analysis")
                st.write(technical)
            
            with col2:
                st.subheader("Simple Explanation")
                st.write(non_technical)
        else:
            st.subheader("Analysis Result")
            st.write(result)
else:
    st.info("Please upload an image to begin analysis")
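
# To launch the app locally with the standard Streamlit CLI:
#   streamlit run app.py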