saakshigupta commited on
Commit
3680540
·
verified ·
1 Parent(s): c2b50c2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -0
app.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ import os
5
+ import gc
6
+ from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
7
+ from peft import PeftModel
8
+
9
+ # Page config
10
+ st.set_page_config(
11
+ page_title="Deepfake Image Analyzer",
12
+ page_icon="🔍",
13
+ layout="wide"
14
+ )
15
+
16
+ # App title and description
17
+ st.title("Deepfake Image Analyzer")
18
+ st.markdown("Upload an image to analyze it for possible deepfake manipulation")
19
+
20
+ # Function to free up memory
21
+ def free_memory():
22
+ gc.collect()
23
+ if torch.cuda.is_available():
24
+ torch.cuda.empty_cache()
25
+
26
+ # Helper function to check CUDA
27
+ def init_device():
28
+ if torch.cuda.is_available():
29
+ st.sidebar.success("✓ GPU available: Using CUDA")
30
+ return "cuda"
31
+ else:
32
+ st.sidebar.warning("⚠️ No GPU detected: Using CPU (analysis will be slow)")
33
+ return "cpu"
34
+
35
+ # Set device
36
+ device = init_device()
37
+
38
+ @st.cache_resource
39
+ def load_model():
40
+ """Load pre-quantized model"""
41
+ try:
42
+ # Using your original base model
43
+ base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"
44
+
45
+ # Load processor
46
+ processor = AutoProcessor.from_pretrained(base_model_id)
47
+
48
+ # Configure quantization settings for unsloth model
49
+ quantization_config = BitsAndBytesConfig(
50
+ load_in_4bit=True,
51
+ bnb_4bit_compute_dtype=torch.float16,
52
+ bnb_4bit_use_double_quant=True,
53
+ bnb_4bit_quant_type="nf4",
54
+ bnb_4bit_quant_storage=torch.float16,
55
+ llm_int8_skip_modules=["lm_head"],
56
+ llm_int8_enable_fp32_cpu_offload=True
57
+ )
58
+
59
+ # Load the pre-quantized model with unsloth settings
60
+ model = AutoModelForCausalLM.from_pretrained(
61
+ base_model_id,
62
+ device_map="auto",
63
+ quantization_config=quantization_config,
64
+ torch_dtype=torch.float16,
65
+ trust_remote_code=True,
66
+ low_cpu_mem_usage=True,
67
+ use_cache=True,
68
+ offload_folder="offload" # Enable disk offloading
69
+ )
70
+
71
+ # Load adapter
72
+ adapter_id = "saakshigupta/deepfake-explainer-1"
73
+ model = PeftModel.from_pretrained(model, adapter_id)
74
+
75
+ return model, processor
76
+
77
+ except Exception as e:
78
+ st.error(f"Error loading model: {str(e)}")
79
+ st.exception(e)
80
+ return None, None
81
+
82
+ # Function to fix cross-attention masks
83
+ def fix_processor_outputs(inputs):
84
+ """Fix cross-attention mask dimensions if needed"""
85
+ if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
86
+ batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
87
+ visual_features = 6404 # The exact dimension used in training
88
+ new_mask = torch.ones(
89
+ (batch_size, seq_len, visual_features, num_tiles),
90
+ device=inputs['cross_attention_mask'].device
91
+ )
92
+ inputs['cross_attention_mask'] = new_mask
93
+ return True, inputs
94
+ return False, inputs
95
+
96
+ # Create sidebar with options
97
+ with st.sidebar:
98
+ st.header("Options")
99
+ temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
100
+ help="Higher values make output more random, lower values more deterministic")
101
+ max_length = st.slider("Maximum response length", min_value=100, max_value=1000, value=500, step=50)
102
+
103
+ custom_prompt = st.text_area(
104
+ "Custom instruction (optional)",
105
+ value="Analyze this image and determine if it's a deepfake. Provide both technical and non-technical explanations.",
106
+ height=100
107
+ )
108
+
109
+ st.markdown("### About")
110
+ st.markdown("""
111
+ This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.
112
+
113
+ The analyzer looks for:
114
+ - Inconsistencies in facial features
115
+ - Unusual lighting or shadows
116
+ - Unnatural blur patterns
117
+ - Artifacts around edges
118
+ - Texture inconsistencies
119
+
120
+ Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)
121
+ """)
122
+
123
+ # Load model button
124
+ if st.button("Load Model"):
125
+ with st.spinner("Loading model... this may take several minutes"):
126
+ try:
127
+ model, processor = load_model()
128
+ if model is not None and processor is not None:
129
+ st.session_state['model'] = model
130
+ st.session_state['processor'] = processor
131
+ st.success("Model loaded successfully!")
132
+ else:
133
+ st.error("Failed to load model.")
134
+ except Exception as e:
135
+ st.error(f"Error during model loading: {str(e)}")
136
+ st.exception(e)
137
+
138
+ # Main content area - file uploader
139
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
140
+
141
+ # Check if model is loaded
142
+ model_loaded = 'model' in st.session_state and st.session_state['model'] is not None
143
+
144
+ if uploaded_file is not None:
145
+ # Display the image
146
+ image = Image.open(uploaded_file).convert('RGB')
147
+ st.image(image, caption="Uploaded Image", use_column_width=True)
148
+
149
+ # Analyze button (only enabled if model is loaded)
150
+ if st.button("Analyze Image", disabled=not model_loaded):
151
+ if not model_loaded:
152
+ st.warning("Please load the model first by clicking the 'Load Model' button.")
153
+ else:
154
+ with st.spinner("Analyzing the image... This may take 15-30 seconds"):
155
+ try:
156
+ # Get components from session state
157
+ model = st.session_state['model']
158
+ processor = st.session_state['processor']
159
+
160
+ # Process the image using the processor
161
+ inputs = processor(text=custom_prompt, images=image, return_tensors="pt")
162
+
163
+ # Fix cross-attention mask if needed
164
+ fixed, inputs = fix_processor_outputs(inputs)
165
+ if fixed:
166
+ st.info("Fixed cross-attention mask dimensions")
167
+
168
+ # Move to device
169
+ inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
170
+
171
+ # Generate the analysis
172
+ with torch.no_grad():
173
+ output_ids = model.generate(
174
+ **inputs,
175
+ max_new_tokens=max_length,
176
+ temperature=temperature,
177
+ top_p=0.9
178
+ )
179
+
180
+ # Decode the output
181
+ response = processor.decode(output_ids[0], skip_special_tokens=True)
182
+
183
+ # Extract the actual response (removing the prompt)
184
+ if custom_prompt in response:
185
+ result = response.split(custom_prompt)[-1].strip()
186
+ else:
187
+ result = response
188
+
189
+ # Display result in a nice format
190
+ st.success("Analysis complete!")
191
+
192
+ # Show technical and non-technical explanations separately if they exist
193
+ if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
194
+ technical, non_technical = result.split("Non-Technical Explanation:")
195
+ technical = technical.replace("Technical Explanation:", "").strip()
196
+
197
+ col1, col2 = st.columns(2)
198
+ with col1:
199
+ st.subheader("Technical Analysis")
200
+ st.write(technical)
201
+
202
+ with col2:
203
+ st.subheader("Simple Explanation")
204
+ st.write(non_technical)
205
+ else:
206
+ st.subheader("Analysis Result")
207
+ st.write(result)
208
+
209
+ # Free memory after analysis
210
+ free_memory()
211
+
212
+ except Exception as e:
213
+ st.error(f"Error analyzing image: {str(e)}")
214
+ st.exception(e)
215
+ elif not model_loaded:
216
+ st.warning("Please load the model first by clicking the 'Load Model' button at the top of the page.")
217
+ else:
218
+ st.info("Please upload an image to begin analysis")
219
+
220
+ # Add footer
221
+ st.markdown("---")
222
+ st.markdown("Deepfake Image Analyzer")