File size: 8,519 Bytes
3680540
 
 
c692452
3680540
c692452
 
 
3680540
c692452
3680540
c692452
 
 
3680540
 
c692452
3680540
 
 
c692452
 
3680540
c692452
 
 
3680540
c692452
 
3680540
c692452
 
3680540
c692452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3680540
 
c692452
3680540
 
c692452
 
 
3680540
c692452
 
3680540
c692452
 
 
 
3680540
c692452
 
 
 
 
 
 
 
 
3680540
c692452
 
 
3680540
c692452
 
 
 
 
 
 
3680540
c692452
 
 
 
 
 
 
3680540
c692452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3680540
c692452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3680540
c692452
 
 
 
 
3680540
c692452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import streamlit as st
import torch
from PIL import Image
import io
from peft import PeftModel
from unsloth import FastVisionModel
import tempfile
import os

# Page configuration. st.set_page_config is called before any other Streamlit
# call in this script. The page icon is the magnifying-glass emoji; the
# previous literal was a mis-encoded (mojibake) rendering of the same
# character and is not a valid emoji.
st.set_page_config(
    page_title="Deepfake Analyzer",
    layout="wide",
    page_icon="🔍",
)

# Main title and one-line usage hint shown at the top of the page
st.title("Deepfake Image Analyzer")
st.markdown("Upload an image to analyze it for possible deepfake manipulation")

# Check for GPU availability
def check_gpu():
    """Report CUDA availability in the sidebar.

    Shows a success message with the name and total memory (in GiB) of
    CUDA device 0 when CUDA is available, otherwise a warning.

    Returns:
        bool: True when ``torch.cuda.is_available()`` is true, else False.
    """
    if not torch.cuda.is_available():
        st.sidebar.warning("⚠️ No GPU detected. Analysis will be slower.")
        return False

    gpu_info = torch.cuda.get_device_properties(0)
    total_gib = gpu_info.total_memory / (1024**3)  # bytes -> GiB
    st.sidebar.success(f"✅ GPU available: {gpu_info.name} ({total_gib:.2f} GB)")
    return True

# Sidebar components. Everything created inside this context manager is
# rendered in the sidebar, in the order written.
with st.sidebar:
    st.title("Options")

    # Sampling temperature chosen by the user (0.1 - 1.0)
    temperature = st.slider(
        "Temperature",
        min_value=0.1,
        max_value=1.0,
        value=0.7,
        step=0.1,
        help="Higher values make output more random, lower values more deterministic",
    )

    # Upper bound on the number of tokens in the response (100 - 1000)
    max_tokens = st.slider(
        "Maximum Response Length",
        min_value=100,
        max_value=1000,
        value=500,
        step=50,
        help="The maximum number of tokens in the response",
    )

    # Free-form extra instructions supplied by the user
    custom_instruction = st.text_area(
        "Custom Instructions (Advanced)",
        value="Analyze for facial inconsistencies, lighting irregularities, mismatched shadows, and other signs of manipulation.",
        help="Add specific instructions for the model",
    )

    # Static "About" panel. The markdown body stays flush-left so the
    # rendered text is not indented.
    st.markdown("---")
    st.subheader("About")
    st.markdown("""
This analyzer looks for:
- Facial inconsistencies
- Unnatural movements
- Lighting issues
- Texture anomalies
- Edge artifacts
- Blending problems

**Model**: Fine-tuned Llama 3.2 Vision
**Creator**: [Saakshi Gupta](https://huggingface.co/saakshigupta)
""")

# Function to fix cross-attention masks
def fix_cross_attention_mask(inputs, visual_features=6404):
    """Replace a cross-attention mask that has an empty dimension.

    If ``inputs`` has a ``'cross_attention_mask'`` entry and any of its
    dimensions has size 0, the entry is overwritten with an all-ones mask of
    shape ``(dim0, dim1, visual_features, dim3)``, where ``dim0``, ``dim1``
    and ``dim3`` are taken from the existing (4-D) mask. The replacement
    keeps the original mask's dtype and device. A Streamlit success message
    is shown when a replacement happens. In every other case ``inputs`` is
    returned untouched.

    Args:
        inputs: Mapping-like model inputs; modified in place.
        visual_features: Size used for the third dimension of the
            replacement mask. The default (6404) is the value the original
            code hard-coded as "critical"; NOTE(review): confirm it matches
            the vision model in use.

    Returns:
        The same ``inputs`` object.
    """
    if 'cross_attention_mask' not in inputs:
        return inputs

    mask = inputs['cross_attention_mask']
    if 0 not in mask.shape:
        return inputs

    batch_size, seq_len, _, num_tiles = mask.shape
    inputs['cross_attention_mask'] = torch.ones(
        (batch_size, seq_len, visual_features, num_tiles),
        dtype=mask.dtype,  # keep the dtype of the mask being replaced
        device=mask.device,
    )
    st.success("Fixed cross-attention mask dimensions")
    return inputs

# Load model function
@st.cache_resource
def _load_model_cached():
    """Load the base model, attach the PEFT adapter, and return both objects.

    Cached with ``st.cache_resource``. Errors are deliberately NOT caught
    here: only a successful return value should be cached, so a failed load
    can be retried.

    Returns:
        tuple: ``(model, tokenizer)`` as returned by
        ``FastVisionModel.from_pretrained``, with the adapter applied to the
        model and the model switched to inference mode.
    """
    # Load base model and tokenizer using Unsloth
    base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
    model, tokenizer = FastVisionModel.from_pretrained(
        base_model_id,
        load_in_4bit=True,
    )

    # Load the adapter
    adapter_id = "saakshigupta/deepfake-explainer-1"
    model = PeftModel.from_pretrained(model, adapter_id)

    # Set to inference mode
    FastVisionModel.for_inference(model)

    return model, tokenizer


def load_model():
    """Load the model and tokenizer, reporting failures in the UI.

    Shows a spinner while loading and reports GPU availability in the
    sidebar. Any exception raised while loading is displayed with
    ``st.error`` and converted into a ``(None, None)`` result. Because the
    failure is handled outside the cached loader, it is not cached and the
    next call attempts the load again.

    Returns:
        tuple: ``(model, tokenizer)`` on success, ``(None, None)`` on failure.
    """
    with st.spinner("Loading model... This may take a few minutes. Please be patient..."):
        try:
            # Check for GPU (only reports status in the sidebar)
            check_gpu()
            return _load_model_cached()
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            return None, None

# Analyze image function
def analyze_image(image, question, model, tokenizer, temperature=0.7, max_tokens=500, custom_instruction=""):
    """Ask the model a question about an image and return its answer text.

    Args:
        image: Image passed as the first argument to ``tokenizer(...)``.
        question: The user's question / prompt text.
        model: Object providing ``device`` and ``generate(...)``.
        tokenizer: Object providing ``apply_chat_template``, ``decode`` and
            being callable with ``(image, text, ...)`` to build model inputs.
        temperature: Sampling temperature passed to ``generate``.
        max_tokens: Maximum number of newly generated tokens.
        custom_instruction: Optional extra instructions appended to the
            question; ignored when empty, whitespace-only or None.

    Returns:
        str: The decoded text of the newly generated tokens, stripped of
        leading/trailing whitespace.
    """
    # Combine question with custom instruction if provided
    if (custom_instruction or "").strip():
        full_prompt = f"{question}\n\nAdditional instructions: {custom_instruction}"
    else:
        full_prompt = question

    # Format the message
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": full_prompt}
        ]}
    ]

    # Apply chat template
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    # Process with image
    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    # Fix cross-attention mask if needed
    inputs = fix_cross_attention_mask(inputs)

    # Number of prompt tokens; generate() returns the prompt followed by the
    # newly generated tokens, so this is where the answer starts.
    prompt_length = inputs["input_ids"].shape[-1]

    # Generate response
    with st.spinner("Analyzing image... (this may take 15-30 seconds)"):
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                use_cache=True,
                # temperature / top_p only apply when sampling is enabled;
                # request it explicitly instead of relying on model defaults.
                do_sample=True,
                temperature=temperature,
                top_p=0.9,
            )

        # Decode only the generated continuation. Slicing by token count
        # avoids searching for the prompt in the decoded text, which left
        # chat-template remnants in the result and failed whenever decoding
        # did not reproduce the prompt verbatim.
        generated_ids = output_ids[0][prompt_length:]
        return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

# Main app
def _split_explanations(result):
    """Split *result* at the first "Non-Technical" marker.

    Returns:
        tuple | None: ``(technical, non_technical)`` where ``non_technical``
        starts with the marker and contains everything after it, or None
        when the marker does not occur in *result*.
    """
    technical, marker, remainder = result.partition("Non-Technical")
    if not marker:
        return None
    return technical, marker + remainder


def _show_result(result):
    """Render the analysis text, in two columns when it has both sections."""
    sections = _split_explanations(result)
    if sections is None:
        # Just display the whole result
        st.subheader("Analysis Result")
        st.markdown(result)
        return

    technical, non_technical = sections
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Technical Analysis")
        st.markdown(technical)
    with col2:
        st.subheader("Simple Explanation")
        st.markdown(non_technical)


def main():
    """Render the page: model loading, image upload, analysis and footer.

    The loaded model and tokenizer are kept in ``st.session_state`` so they
    survive Streamlit reruns. The analyze button is only offered once a model
    has been loaded and an image has been uploaded.
    """
    # Initialise session state on the first run
    if 'model_loaded' not in st.session_state:
        st.session_state.model_loaded = False
        st.session_state.model = None
        st.session_state.tokenizer = None

    # Load model button (replaced by a status message once loaded)
    if st.session_state.model_loaded:
        st.success("✅ Model loaded successfully! You can now analyze images.")
    elif st.button("📥 Load Deepfake Analysis Model", type="primary"):
        model, tokenizer = load_model()
        if model is not None and tokenizer is not None:
            st.session_state.model = model
            st.session_state.tokenizer = tokenizer
            st.session_state.model_loaded = True
            st.success("✅ Model loaded successfully! You can now analyze images.")
        else:
            st.error("❌ Failed to load model. Please check the logs for errors.")

    # Image upload section
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Default question with option to customize
    default_question = "Analyze this image and tell me if it's a deepfake. Provide both technical and non-technical explanations."
    question = st.text_area("Question/Prompt:", value=default_question, height=100)

    if uploaded_file is not None:
        # Display the uploaded image
        # NOTE(review): use_column_width is deprecated in recent Streamlit
        # releases; switch to use_container_width once the minimum supported
        # version is confirmed.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Analyze button - only offered if the model is loaded
        if not st.session_state.model_loaded:
            st.warning("⚠️ Please load the model first before analyzing images.")
        elif st.button("🔍 Analyze Image", type="primary"):
            # temperature, max_tokens and custom_instruction are the
            # module-level sidebar values.
            result = analyze_image(
                image,
                question,
                st.session_state.model,
                st.session_state.tokenizer,
                temperature=temperature,
                max_tokens=max_tokens,
                custom_instruction=custom_instruction,
            )

            # Display results
            st.success("✅ Analysis complete!")
            _show_result(result)

    # Footer
    st.markdown("---")
    st.caption("Deepfake Image Analyzer")

# Entry point: build the page only when this file is run as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()