saakshigupta committed
Commit 99ef832 · verified · 1 Parent(s): 10433cf

Update app.py

Files changed (1)
  1. app.py +167 -91
app.py CHANGED
@@ -1,14 +1,14 @@
- # app.py
  import streamlit as st
  import torch
+ import os
  from PIL import Image
- import io
- from transformers import AutoProcessor, BitsAndBytesConfig, MllamaForCausalLM
+ from transformers import AutoProcessor, MllamaForCausalLM, BitsAndBytesConfig
  from peft import PeftModel
+ import gc

  # Page config
  st.set_page_config(
-     page_title="Deepfake Explainer",
+     page_title="Deepfake Image Analyzer",
      page_icon="🔍",
      layout="wide"
  )
@@ -17,50 +17,73 @@ st.set_page_config(
  st.title("Deepfake Image Analyzer")
  st.markdown("Upload an image to analyze it for possible deepfake manipulation")

+ # Function to free up memory
+ def free_memory():
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.ipc_collect()
+
+ # Helper functions
+ def init_device():
+     """Set the appropriate device and return it"""
+     if torch.cuda.is_available():
+         st.sidebar.success("✓ GPU available: Using CUDA")
+         return "cuda"
+     else:
+         st.sidebar.warning("⚠️ No GPU detected: Using CPU (analysis will be slow)")
+         return "cpu"
+
+ # Set device
+ device = init_device()
+
  @st.cache_resource
  def load_model():
-     """Load model and processor (cached to avoid reloading)"""
-     # Load base model
-     base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"
-     processor = AutoProcessor.from_pretrained(base_model_id)
-
-     # Configure 4-bit quantization
-     quantization_config = BitsAndBytesConfig(
-         load_in_4bit=True,
-         bnb_4bit_compute_dtype=torch.float16,
-         bnb_4bit_quant_type="nf4",
-         bnb_4bit_use_double_quant=True
-     )
-
-     # Use the specific model class MllamaForCausalLM instead of AutoModelForCausalLM
-     model = MllamaForCausalLM.from_pretrained(
-         base_model_id,
-         device_map="auto",
-         torch_dtype=torch.float16,
-         quantization_config=quantization_config
-     )
-
-     # Load adapter
-     adapter_id = "saakshigupta/deepfake-explainer-1"
-     model = PeftModel.from_pretrained(model, adapter_id)
-
-     return model, processor
+     """Load model and processor with caching to avoid reloading"""
+     try:
+         # Load base model
+         base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"
+         processor = AutoProcessor.from_pretrained(base_model_id)
+
+         # Configure 4-bit quantization with correct dtype
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True
+         )
+
+         # Load model with explicit dtype settings using MllamaForCausalLM
+         model = MllamaForCausalLM.from_pretrained(
+             base_model_id,
+             device_map="auto",
+             torch_dtype=torch.float16,
+             quantization_config=quantization_config
+         )
+
+         # Load adapter
+         adapter_id = "saakshigupta/deepfake-explainer-1"
+         model = PeftModel.from_pretrained(model, adapter_id)
+
+         return model, processor
+
+     except Exception as e:
+         st.error(f"Error loading model: {str(e)}")
+         return None, None

  # Function to fix cross-attention masks
  def fix_processor_outputs(inputs):
+     """Fix cross-attention mask dimensions if needed"""
      if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
          batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
-         visual_features = 6404  # The exact dimension we fixed in training
-         new_mask = torch.ones((batch_size, seq_len, visual_features, num_tiles),
-                               device=inputs['cross_attention_mask'].device)
+         visual_features = 6404  # The exact dimension used in training
+         new_mask = torch.ones(
+             (batch_size, seq_len, visual_features, num_tiles),
+             device=inputs['cross_attention_mask'].device
+         )
          inputs['cross_attention_mask'] = new_mask
-         st.write("✅ Fixed cross-attention mask dimensions")
-     return inputs
-
- # Load model on first run
- with st.spinner("Loading model... this may take a minute."):
-     model, processor = load_model()
-     st.success("Model loaded successfully!")
+         return True, inputs
+     return False, inputs

  # Create sidebar with options
  with st.sidebar:
@@ -76,65 +99,118 @@ with st.sidebar:
      )

  st.markdown("### About")
- st.markdown("This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.")
- st.markdown("Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)")
+ st.markdown("""
+ This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.
+
+ The analyzer looks for:
+ - Inconsistencies in facial features
+ - Unusual lighting or shadows
+ - Unnatural blur patterns
+ - Artifacts around edges
+ - Texture inconsistencies
+
+ Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)
+ """)
+
+ # Load model on app startup with a progress bar
+ if 'model_loaded' not in st.session_state:
+     progress_bar = st.progress(0)
+     st.info("Loading model... this may take a minute.")
+
+     for i in range(10):
+         # Simulate progress while model loads
+         progress_bar.progress((i + 1) * 10)
+         if i == 2:
+             # Start loading the model at 30% progress
+             model, processor = load_model()
+             if model is not None:
+                 st.session_state['model'] = model
+                 st.session_state['processor'] = processor
+                 st.session_state['model_loaded'] = True
+
+     progress_bar.empty()
+
+     if 'model_loaded' in st.session_state and st.session_state['model_loaded']:
+         st.success("Model loaded successfully!")
+     else:
+         st.error("Failed to load model. Try refreshing the page.")

  # Main content area - file uploader
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

- if uploaded_file is not None:
+ # Check if model is loaded
+ model_loaded = 'model_loaded' in st.session_state and st.session_state['model_loaded']
+
+ if uploaded_file is not None and model_loaded:
      # Display the image
      image = Image.open(uploaded_file).convert('RGB')
      st.image(image, caption="Uploaded Image", use_column_width=True)

      # Analyze button
      if st.button("Analyze Image"):
-         with st.spinner("Analyzing the image..."):
-             # Process the image
-             inputs = processor(text=custom_prompt, images=image, return_tensors="pt")
-
-             # Fix cross-attention mask
-             inputs = fix_processor_outputs(inputs)
-
-             # Move to device
-             inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
-
-             # Generate the analysis
-             with torch.no_grad():
-                 output_ids = model.generate(
-                     **inputs,
-                     max_new_tokens=max_length,
-                     temperature=temperature,
-                     top_p=0.9
-                 )
-
-             # Decode the output
-             response = processor.decode(output_ids[0], skip_special_tokens=True)
-
-             # Extract the actual response (removing the prompt)
-             if custom_prompt in response:
-                 result = response.split(custom_prompt)[-1].strip()
-             else:
-                 result = response
-
-             # Display result in a nice format
-             st.success("Analysis complete!")
-
-             # Show technical and non-technical explanations separately if they exist
-             if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
-                 technical, non_technical = result.split("Non-Technical Explanation:")
-                 technical = technical.replace("Technical Explanation:", "").strip()
-
-                 col1, col2 = st.columns(2)
-                 with col1:
-                     st.subheader("Technical Analysis")
-                     st.write(technical)
-
-                 with col2:
-                     st.subheader("Simple Explanation")
-                     st.write(non_technical)
-             else:
-                 st.subheader("Analysis Result")
-                 st.write(result)
+         with st.spinner("Analyzing the image... This may take 15-30 seconds"):
+             try:
+                 # Get components from session state
+                 model = st.session_state['model']
+                 processor = st.session_state['processor']
+
+                 # Process the image
+                 inputs = processor(text=custom_prompt, images=image, return_tensors="pt")
+
+                 # Fix cross-attention mask
+                 fixed, inputs = fix_processor_outputs(inputs)
+                 if fixed:
+                     st.info("Fixed cross-attention mask dimensions")
+
+                 # Move to device
+                 inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
+
+                 # Generate the analysis
+                 with torch.no_grad():
+                     output_ids = model.generate(
+                         **inputs,
+                         max_new_tokens=max_length,
+                         temperature=temperature,
+                         top_p=0.9
+                     )
+
+                 # Decode the output
+                 response = processor.decode(output_ids[0], skip_special_tokens=True)
+
+                 # Extract the actual response (removing the prompt)
+                 if custom_prompt in response:
+                     result = response.split(custom_prompt)[-1].strip()
+                 else:
+                     result = response
+
+                 # Display result in a nice format
+                 st.success("Analysis complete!")
+
+                 # Show technical and non-technical explanations separately if they exist
+                 if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
+                     technical, non_technical = result.split("Non-Technical Explanation:")
+                     technical = technical.replace("Technical Explanation:", "").strip()
+
+                     col1, col2 = st.columns(2)
+                     with col1:
+                         st.subheader("Technical Analysis")
+                         st.write(technical)
+
+                     with col2:
+                         st.subheader("Simple Explanation")
+                         st.write(non_technical)
+                 else:
+                     st.subheader("Analysis Result")
+                     st.write(result)
+
+                 # Free memory after analysis
+                 free_memory()
+
+             except Exception as e:
+                 st.error(f"Error analyzing image: {str(e)}")
  else:
-     st.info("Please upload an image to begin analysis")
+     st.info("Please upload an image to begin analysis")
+
+ # Add footer
+ st.markdown("---")
+ st.markdown("Deepfake Image Analyzer")
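As a quick check of the helper's new (fixed, inputs) return contract, the function can be exercised outside Streamlit with a dummy tensor dict. This is a minimal sketch assuming only a working torch install; the shapes below are hypothetical stand-ins for a degenerate processor output, not values captured from the real model:

import torch

# Standalone copy of the updated helper (same logic, Streamlit-free)
def fix_processor_outputs(inputs):
    """Fix cross-attention mask dimensions if needed"""
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404  # The exact dimension used in training
        inputs['cross_attention_mask'] = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device
        )
        return True, inputs
    return False, inputs

# A degenerate mask (third dimension collapsed to 0) gets rebuilt
broken = {'cross_attention_mask': torch.zeros((1, 8, 0, 4))}
fixed, broken = fix_processor_outputs(broken)
assert fixed and broken['cross_attention_mask'].shape == (1, 8, 6404, 4)

# A well-formed mask passes through unchanged
ok = {'cross_attention_mask': torch.ones((1, 8, 6404, 4))}
fixed, ok = fix_processor_outputs(ok)
assert not fixed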
 
 
 
 
 
 
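The new startup flow pairs two Streamlit mechanisms: @st.cache_resource memoizes the expensive load once per process, while st.session_state survives the full-script reruns Streamlit performs on every widget interaction, so the analyze path can read the model back rather than reload it. A minimal sketch of that pattern, with load_heavy_object as an illustrative stand-in for load_model:

import streamlit as st

@st.cache_resource
def load_heavy_object():
    """Stand-in for load_model(): the body runs once per process and is reused."""
    return {"weights": "..."}  # imagine the multi-GB model/processor pair here

# Streamlit reruns this whole script on every widget interaction;
# session_state survives those reruns, so the handle is fetched, not rebuilt.
if 'obj_loaded' not in st.session_state:
    st.session_state['obj'] = load_heavy_object()
    st.session_state['obj_loaded'] = st.session_state['obj'] is not None

st.write("ready:", st.session_state['obj_loaded'])

In the commit itself, the session-state flag additionally lets the UI branch between the "Model loaded successfully!" and failure messages.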