saakshigupta committed
Commit c692452 · verified · 1 Parent(s): bfe7823

Update app.py

Files changed (1): app.py +217 -193
app.py CHANGED
@@ -1,222 +1,246 @@
  import streamlit as st
  import torch
  from PIL import Image
- import os
- import gc
- from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
  from peft import PeftModel
-
- # Page config
  st.set_page_config(
-     page_title="Deepfake Image Analyzer",
-     page_icon="🔍",
-     layout="wide"
  )
-
- # App title and description
  st.title("Deepfake Image Analyzer")
  st.markdown("Upload an image to analyze it for possible deepfake manipulation")
-
- # Function to free up memory
- def free_memory():
-     gc.collect()
      if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-
- # Helper function to check CUDA
- def init_device():
-     if torch.cuda.is_available():
-         st.sidebar.success("✓ GPU available: Using CUDA")
-         return "cuda"
      else:
-         st.sidebar.warning("⚠️ No GPU detected: Using CPU (analysis will be slow)")
-         return "cpu"
-
- # Set device
- device = init_device()
-
- @st.cache_resource
- def load_model():
-     """Load pre-quantized model"""
-     try:
-         # Using your original base model
-         base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"
-
-         # Load processor
-         processor = AutoProcessor.from_pretrained(base_model_id)
-
-         # Configure quantization settings for unsloth model
-         quantization_config = BitsAndBytesConfig(
-             load_in_4bit=True,
-             bnb_4bit_compute_dtype=torch.float16,
-             bnb_4bit_use_double_quant=True,
-             bnb_4bit_quant_type="nf4",
-             bnb_4bit_quant_storage=torch.float16,
-             llm_int8_skip_modules=["lm_head"],
-             llm_int8_enable_fp32_cpu_offload=True
-         )
-
-         # Load the pre-quantized model with unsloth settings
-         model = AutoModelForCausalLM.from_pretrained(
-             base_model_id,
-             device_map="auto",
-             quantization_config=quantization_config,
-             torch_dtype=torch.float16,
-             trust_remote_code=True,
-             low_cpu_mem_usage=True,
-             use_cache=True,
-             offload_folder="offload"  # Enable disk offloading
-         )
-
-         # Load adapter
-         adapter_id = "saakshigupta/deepfake-explainer-1"
-         model = PeftModel.from_pretrained(model, adapter_id)
-
-         return model, processor
-
-     except Exception as e:
-         st.error(f"Error loading model: {str(e)}")
-         st.exception(e)
-         return None, None
-
  # Function to fix cross-attention masks
- def fix_processor_outputs(inputs):
-     """Fix cross-attention mask dimensions if needed"""
      if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
          batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
-         visual_features = 6404  # The exact dimension used in training
-         new_mask = torch.ones(
-             (batch_size, seq_len, visual_features, num_tiles),
-             device=inputs['cross_attention_mask'].device
-         )
          inputs['cross_attention_mask'] = new_mask
-         return True, inputs
-     return False, inputs
-
- # Create sidebar with options
- with st.sidebar:
-     st.header("Options")
-     temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
-                             help="Higher values make output more random, lower values more deterministic")
-     max_length = st.slider("Maximum response length", min_value=100, max_value=1000, value=500, step=50)
-
-     custom_prompt = st.text_area(
-         "Custom instruction (optional)",
-         value="Analyze this image and determine if it's a deepfake. Provide both technical and non-technical explanations.",
-         height=100
-     )
-
-     st.markdown("### About")
-     st.markdown("""
-     This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.
-
-     The analyzer looks for:
-     - Inconsistencies in facial features
-     - Unusual lighting or shadows
-     - Unnatural blur patterns
-     - Artifacts around edges
-     - Texture inconsistencies
-
-     Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)
-     """)
-
- # Load model button
- if st.button("Load Model"):
-     with st.spinner("Loading model... this may take several minutes"):
          try:
-             model, processor = load_model()
-             if model is not None and processor is not None:
-                 st.session_state['model'] = model
-                 st.session_state['processor'] = processor
-                 st.success("Model loaded successfully!")
-             else:
-                 st.error("Failed to load model.")
-         except Exception as e:
-             st.error(f"Error during model loading: {str(e)}")
-             st.exception(e)
-
- # Main content area - file uploader
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
-
- # Check if model is loaded
- model_loaded = 'model' in st.session_state and st.session_state['model'] is not None
-
- if uploaded_file is not None:
-     # Display the image
-     image = Image.open(uploaded_file).convert('RGB')
-     st.image(image, caption="Uploaded Image", use_column_width=True)
-
-     # Analyze button (only enabled if model is loaded)
-     if st.button("Analyze Image", disabled=not model_loaded):
-         if not model_loaded:
-             st.warning("Please load the model first by clicking the 'Load Model' button.")
          else:
-             with st.spinner("Analyzing the image... This may take 15-30 seconds"):
-                 try:
-                     # Get components from session state
-                     model = st.session_state['model']
-                     processor = st.session_state['processor']
-
-                     # Process the image using the processor
-                     inputs = processor(text=custom_prompt, images=image, return_tensors="pt")
-
-                     # Fix cross-attention mask if needed
-                     fixed, inputs = fix_processor_outputs(inputs)
-                     if fixed:
-                         st.info("Fixed cross-attention mask dimensions")
-
-                     # Move to device
-                     inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
-
-                     # Generate the analysis
-                     with torch.no_grad():
-                         output_ids = model.generate(
-                             **inputs,
-                             max_new_tokens=max_length,
-                             temperature=temperature,
-                             top_p=0.9
-                         )
-
-                     # Decode the output
-                     response = processor.decode(output_ids[0], skip_special_tokens=True)
-
-                     # Extract the actual response (removing the prompt)
-                     if custom_prompt in response:
-                         result = response.split(custom_prompt)[-1].strip()
-                     else:
-                         result = response
-
-                     # Display result in a nice format
-                     st.success("Analysis complete!")
-
-                     # Show technical and non-technical explanations separately if they exist
-                     if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
-                         technical, non_technical = result.split("Non-Technical Explanation:")
-                         technical = technical.replace("Technical Explanation:", "").strip()
-
-                         col1, col2 = st.columns(2)
-                         with col1:
-                             st.subheader("Technical Analysis")
-                             st.write(technical)
-
-                         with col2:
-                             st.subheader("Simple Explanation")
-                             st.write(non_technical)
-                     else:
-                         st.subheader("Analysis Result")
-                         st.write(result)
-
-                     # Free memory after analysis
-                     free_memory()
-
-                 except Exception as e:
-                     st.error(f"Error analyzing image: {str(e)}")
-                     st.exception(e)
-     elif not model_loaded:
-         st.warning("Please load the model first by clicking the 'Load Model' button at the top of the page.")
- else:
-     st.info("Please upload an image to begin analysis")
-
- # Add footer
- st.markdown("---")
- st.markdown("Deepfake Image Analyzer")

  import streamlit as st
  import torch
  from PIL import Image
+ import io
  from peft import PeftModel
+ from unsloth import FastVisionModel
+ import tempfile
+ import os
+
+ # App title and description
  st.set_page_config(
+     page_title="Deepfake Analyzer",
+     layout="wide",
+     page_icon="🔍"
  )
+
+ # Main title and description
  st.title("Deepfake Image Analyzer")
  st.markdown("Upload an image to analyze it for possible deepfake manipulation")
+
+ # Check for GPU availability
+ def check_gpu():
      if torch.cuda.is_available():
+         gpu_info = torch.cuda.get_device_properties(0)
+         st.sidebar.success(f"✅ GPU available: {gpu_info.name} ({gpu_info.total_memory / (1024**3):.2f} GB)")
+         return True
      else:
+         st.sidebar.warning("⚠️ No GPU detected. Analysis will be slower.")
+         return False
+
+ # Sidebar components
+ st.sidebar.title("Options")
+
+ # Temperature slider
+ temperature = st.sidebar.slider(
+     "Temperature",
+     min_value=0.1,
+     max_value=1.0,
+     value=0.7,
+     step=0.1,
+     help="Higher values make output more random, lower values more deterministic"
+ )
+
+ # Max response length slider
+ max_tokens = st.sidebar.slider(
+     "Maximum Response Length",
+     min_value=100,
+     max_value=1000,
+     value=500,
+     step=50,
+     help="The maximum number of tokens in the response"
+ )
+
+ # Custom instruction text area in sidebar
+ custom_instruction = st.sidebar.text_area(
+     "Custom Instructions (Advanced)",
+     value="Analyze for facial inconsistencies, lighting irregularities, mismatched shadows, and other signs of manipulation.",
+     help="Add specific instructions for the model"
+ )
+
+ # About section in sidebar
+ st.sidebar.markdown("---")
+ st.sidebar.subheader("About")
+ st.sidebar.markdown("""
+ This analyzer looks for:
+ - Facial inconsistencies
+ - Unnatural movements
+ - Lighting issues
+ - Texture anomalies
+ - Edge artifacts
+ - Blending problems
+
+ **Model**: Fine-tuned Llama 3.2 Vision
+ **Creator**: [Saakshi Gupta](https://huggingface.co/saakshigupta)
+ """)

  # Function to fix cross-attention masks
+ def fix_cross_attention_mask(inputs):
      if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
          batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
+         visual_features = 6404  # Critical dimension
+         new_mask = torch.ones((batch_size, seq_len, visual_features, num_tiles),
+                               device=inputs['cross_attention_mask'].device)
          inputs['cross_attention_mask'] = new_mask
+         st.success("Fixed cross-attention mask dimensions")
+     return inputs

+ # Load model function
+ @st.cache_resource
+ def load_model():
+     with st.spinner("Loading model... This may take a few minutes. Please be patient..."):
          try:
+             # Check for GPU
+             has_gpu = check_gpu()
+
+             # Load base model and tokenizer using Unsloth
+             base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
+             model, tokenizer = FastVisionModel.from_pretrained(
+                 base_model_id,
+                 load_in_4bit=True,
+             )
+
+             # Load the adapter
+             adapter_id = "saakshigupta/deepfake-explainer-1"
+             model = PeftModel.from_pretrained(model, adapter_id)
+
+             # Set to inference mode
+             FastVisionModel.for_inference(model)
+
+             return model, tokenizer
+         except Exception as e:
+             st.error(f"Error loading model: {str(e)}")
+             return None, None

+ # Analyze image function
+ def analyze_image(image, question, model, tokenizer, temperature=0.7, max_tokens=500, custom_instruction=""):
+     # Combine question with custom instruction if provided
+     if custom_instruction.strip():
+         full_prompt = f"{question}\n\nAdditional instructions: {custom_instruction}"
+     else:
+         full_prompt = question
+
+     # Format the message
+     messages = [
+         {"role": "user", "content": [
+             {"type": "image"},
+             {"type": "text", "text": full_prompt}
+         ]}
+     ]
+
+     # Apply chat template
+     input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+     # Process with image
+     inputs = tokenizer(
+         image,
+         input_text,
+         add_special_tokens=False,
+         return_tensors="pt",
+     ).to(model.device)
+
+     # Fix cross-attention mask if needed
+     inputs = fix_cross_attention_mask(inputs)
+
+     # Generate response
+     with st.spinner("Analyzing image... (this may take 15-30 seconds)"):
+         with torch.no_grad():
+             output_ids = model.generate(
+                 **inputs,
+                 max_new_tokens=max_tokens,
+                 use_cache=True,
+                 temperature=temperature,
+                 top_p=0.9
+             )
+
+         # Decode the output
+         response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+         # Try to extract just the model's response (after the prompt)
+         if full_prompt in response:
+             result = response.split(full_prompt)[-1].strip()
          else:
+             result = response
+
+         return result

+ # Main app
+ def main():
+     # Create a button to load the model
+     if 'model_loaded' not in st.session_state:
+         st.session_state.model_loaded = False
+         st.session_state.model = None
+         st.session_state.tokenizer = None
+
+     # Load model button
+     if not st.session_state.model_loaded:
+         if st.button("📥 Load Deepfake Analysis Model", type="primary"):
+             model, tokenizer = load_model()
+             if model is not None and tokenizer is not None:
+                 st.session_state.model = model
+                 st.session_state.tokenizer = tokenizer
+                 st.session_state.model_loaded = True
+                 st.success("✅ Model loaded successfully! You can now analyze images.")
+             else:
+                 st.error("❌ Failed to load model. Please check the logs for errors.")
+     else:
+         st.success("✅ Model loaded successfully! You can now analyze images.")
+
+     # Image upload section
+     st.subheader("Upload an Image")
+     uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+     # Default question with option to customize
+     default_question = "Analyze this image and tell me if it's a deepfake. Provide both technical and non-technical explanations."
+     question = st.text_area("Question/Prompt:", value=default_question, height=100)
+
+     if uploaded_file is not None:
+         # Display the uploaded image
+         image = Image.open(uploaded_file).convert("RGB")
+         st.image(image, caption="Uploaded Image", use_column_width=True)
+
+         # Analyze button - only enabled if model is loaded
+         if st.session_state.model_loaded:
+             if st.button("🔍 Analyze Image", type="primary"):
+                 result = analyze_image(
+                     image,
+                     question,
+                     st.session_state.model,
+                     st.session_state.tokenizer,
+                     temperature=temperature,
+                     max_tokens=max_tokens,
+                     custom_instruction=custom_instruction
+                 )
+
+                 # Display results
+                 st.success("✅ Analysis complete!")
+
+                 # Check if the result contains both technical and non-technical explanations
+                 if "Technical" in result and "Non-Technical" in result:
+                     # Split the result into technical and non-technical sections
+                     parts = result.split("Non-Technical")
+                     technical = parts[0]
+                     non_technical = "Non-Technical" + parts[1]

+                     # Display in two columns
+                     col1, col2 = st.columns(2)
+                     with col1:
+                         st.subheader("Technical Analysis")
+                         st.markdown(technical)

+                     with col2:
+                         st.subheader("Simple Explanation")
+                         st.markdown(non_technical)
+                 else:
+                     # Just display the whole result
+                     st.subheader("Analysis Result")
+                     st.markdown(result)
+         else:
+             st.warning("⚠️ Please load the model first before analyzing images.")
+
+     # Footer
+     st.markdown("---")
+     st.caption("Deepfake Image Analyzer")
+
+ if __name__ == "__main__":
+     main()
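
For reference, the cross-attention-mask repair that both versions of app.py depend on can be exercised standalone. Below is a minimal sketch; the degenerate input tensor is made up for illustration, and 6404 is the visual-feature dimension the fine-tuned adapter expects, per the comments in app.py:

import torch

# Same repair as fix_cross_attention_mask in app.py: when the processor
# emits a cross-attention mask whose visual-feature axis has collapsed
# to 0, rebuild it as an all-ones mask with the expected dimension.
def fix_cross_attention_mask(inputs):
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404  # dimension the fine-tuned adapter expects
        inputs['cross_attention_mask'] = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device
        )
    return inputs

# Hypothetical degenerate mask: batch=1, seq_len=24, 0 visual features, 4 tiles
inputs = {'cross_attention_mask': torch.zeros(1, 24, 0, 4)}
inputs = fix_cross_attention_mask(inputs)
print(inputs['cross_attention_mask'].shape)  # torch.Size([1, 24, 6404, 4])

Rebuilding the mask as all-ones lets every text position attend to every visual feature of every tile, which is what both fix_processor_outputs (old version) and fix_cross_attention_mask (new version) do.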