saakshigupta committed
Commit eab7cdf · verified · 1 Parent(s): aa3f85c

Update app.py

Files changed (1):
  1. app.py +120 -168
app.py CHANGED
@@ -1,21 +1,17 @@
 import streamlit as st
 import torch
 from PIL import Image
-import os
 import gc
-from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import AutoProcessor
 from peft import PeftModel
+from unsloth import FastVisionModel
 
-# Page config
-st.set_page_config(
-    page_title="Deepfake Image Analyzer",
-    page_icon="🔍",
-    layout="wide"
-)
+# Simple page config
+st.set_page_config(page_title="Deepfake Analyzer", layout="wide")
 
-# App title and description
+# Minimal UI
 st.title("Deepfake Image Analyzer")
-st.markdown("Upload an image to analyze it for possible deepfake manipulation")
+st.markdown("This app analyzes images for signs of deepfake manipulation")
 
 # Function to free up memory
 def free_memory():
@@ -23,51 +19,41 @@ def free_memory():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-# Helper function to check CUDA
-def init_device():
-    if torch.cuda.is_available():
-        st.sidebar.success("✓ GPU available: Using CUDA")
-        return "cuda"
-    else:
-        st.sidebar.warning("⚠️ No GPU detected: Using CPU (analysis will be slow)")
-        return "cpu"
-
-# Set device
-device = init_device()
+# Function to fix cross-attention masks
+def fix_processor_outputs(inputs):
+    """Fix cross-attention mask dimensions if needed"""
+    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
+        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
+        visual_features = 6404  # The exact dimension used in training
+        new_mask = torch.ones(
+            (batch_size, seq_len, visual_features, num_tiles),
+            device=inputs['cross_attention_mask'].device
+        )
+        inputs['cross_attention_mask'] = new_mask
+        return True, inputs
+    return False, inputs
 
+# Load model function
 @st.cache_resource
 def load_model():
-    """Load pre-quantized model"""
+    """Load model using Unsloth approach (similar to Colab)"""
     try:
-        # Using your original base model
         base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"
 
         # Load processor
         processor = AutoProcessor.from_pretrained(base_model_id)
 
-        # Configure quantization settings for unsloth model
-        quantization_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_quant_storage=torch.float16,
-            llm_int8_skip_modules=["lm_head"],
-            llm_int8_enable_fp32_cpu_offload=True
-        )
-
-        # Load the pre-quantized model with unsloth settings
-        model = AutoModelForCausalLM.from_pretrained(
+        # Load model using Unsloth's FastVisionModel
+        model, _ = FastVisionModel.from_pretrained(
             base_model_id,
-            device_map="auto",
-            quantization_config=quantization_config,
+            load_in_4bit=True,
             torch_dtype=torch.float16,
-            trust_remote_code=True,
-            low_cpu_mem_usage=True,
-            use_cache=True,
-            offload_folder="offload"  # Enable disk offloading
+            device_map="auto"
         )
 
+        # Set to inference mode
+        FastVisionModel.for_inference(model)
+
        # Load adapter
        adapter_id = "saakshigupta/deepfake-explainer-1"
        model = PeftModel.from_pretrained(model, adapter_id)
@@ -79,144 +65,110 @@ def load_model():
         st.exception(e)
         return None, None
 
-# Function to fix cross-attention masks
-def fix_processor_outputs(inputs):
-    """Fix cross-attention mask dimensions if needed"""
-    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
-        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
-        visual_features = 6404  # The exact dimension used in training
-        new_mask = torch.ones(
-            (batch_size, seq_len, visual_features, num_tiles),
-            device=inputs['cross_attention_mask'].device
-        )
-        inputs['cross_attention_mask'] = new_mask
-        return True, inputs
-    return False, inputs
-
-# Create sidebar with options
+# Minimal sidebar
 with st.sidebar:
-    st.header("Options")
-    temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
-                            help="Higher values make output more random, lower values more deterministic")
-    max_length = st.slider("Maximum response length", min_value=100, max_value=1000, value=500, step=50)
+    st.header("Settings")
+    temperature = st.slider("Temperature", 0.1, 1.0, 0.7, 0.1)
+    max_length = st.slider("Max length", 100, 500, 300, 50)
 
-    custom_prompt = st.text_area(
-        "Custom instruction (optional)",
-        value="Analyze this image and determine if it's a deepfake. Provide both technical and non-technical explanations.",
+    # Instruction field
+    prompt = st.text_area(
+        "Analysis instruction",
+        value="Analyze this image and determine if it's a deepfake. Provide your reasoning.",
         height=100
     )
-
-    st.markdown("### About")
-    st.markdown("""
-    This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.
-
-    The analyzer looks for:
-    - Inconsistencies in facial features
-    - Unusual lighting or shadows
-    - Unnatural blur patterns
-    - Artifacts around edges
-    - Texture inconsistencies
-
-    Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)
-    """)
 
-# Load model button
-if st.button("Load Model"):
-    with st.spinner("Loading model... this may take several minutes"):
-        try:
+# Main content - two columns for clarity
+col1, col2 = st.columns([1, 2])
+
+with col1:
+    # Load model button
+    if st.button("1. Load Model"):
+        with st.spinner("Loading model... (this may take a minute)"):
             model, processor = load_model()
             if model is not None and processor is not None:
                 st.session_state['model'] = model
                 st.session_state['processor'] = processor
-                st.success("Model loaded successfully!")
+                st.success("✓ Model loaded successfully!")
             else:
-                st.error("Failed to load model.")
-        except Exception as e:
-            st.error(f"Error during model loading: {str(e)}")
-            st.exception(e)
-
-# Main content area - file uploader
-uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
-
-# Check if model is loaded
-model_loaded = 'model' in st.session_state and st.session_state['model'] is not None
-
-if uploaded_file is not None:
-    # Display the image
-    image = Image.open(uploaded_file).convert('RGB')
-    st.image(image, caption="Uploaded Image", use_column_width=True)
+                st.error("Failed to load model")
+
+    # File uploader
+    uploaded_file = st.file_uploader("2. Upload an image", type=["jpg", "jpeg", "png"])
 
-    # Analyze button (only enabled if model is loaded)
-    if st.button("Analyze Image", disabled=not model_loaded):
-        if not model_loaded:
-            st.warning("Please load the model first by clicking the 'Load Model' button.")
-        else:
-            with st.spinner("Analyzing the image... This may take 15-30 seconds"):
-                try:
-                    # Get components from session state
-                    model = st.session_state['model']
-                    processor = st.session_state['processor']
-
-                    # Process the image using the processor
-                    inputs = processor(text=custom_prompt, images=image, return_tensors="pt")
-
-                    # Fix cross-attention mask if needed
-                    fixed, inputs = fix_processor_outputs(inputs)
-                    if fixed:
-                        st.info("Fixed cross-attention mask dimensions")
-
-                    # Move to device
-                    inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
-
-                    # Generate the analysis
-                    with torch.no_grad():
-                        output_ids = model.generate(
-                            **inputs,
-                            max_new_tokens=max_length,
-                            temperature=temperature,
-                            top_p=0.9
+    # Display uploaded image
+    if uploaded_file is not None:
+        image = Image.open(uploaded_file).convert('RGB')
+        st.image(image, caption="Uploaded Image", use_column_width=True)
+
+    # Only enable analysis if model is loaded
+    model_loaded = 'model' in st.session_state and st.session_state['model'] is not None
+
+    if st.button("3. Analyze Image", disabled=not model_loaded):
+        if not model_loaded:
+            st.warning("Please load the model first")
+        else:
+            col2.subheader("Analysis Results")
+            with col2, st.spinner("Analyzing image..."):  # st.spinner isn't a column method; enter the column first
+                try:
+                    # Get model components
+                    model = st.session_state['model']
+                    processor = st.session_state['processor']
+
+                    # Format message for analysis
+                    messages = [
+                        {"role": "user", "content": [
+                            {"type": "image"},
+                            {"type": "text", "text": prompt}
+                        ]}
+                    ]
+
+                    # Apply chat template
+                    input_text = processor.tokenizer.apply_chat_template(
+                        messages,
+                        add_generation_prompt=True, tokenize=False  # return the prompt as text, not token IDs
                     )
-
-                    # Decode the output
-                    response = processor.decode(output_ids[0], skip_special_tokens=True)
-
-                    # Extract the actual response (removing the prompt)
-                    if custom_prompt in response:
-                        result = response.split(custom_prompt)[-1].strip()
-                    else:
-                        result = response
-
-                    # Display result in a nice format
-                    st.success("Analysis complete!")
-
-                    # Show technical and non-technical explanations separately if they exist
-                    if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
-                        technical, non_technical = result.split("Non-Technical Explanation:")
-                        technical = technical.replace("Technical Explanation:", "").strip()
 
-                        col1, col2 = st.columns(2)
-                        with col1:
-                            st.subheader("Technical Analysis")
-                            st.write(technical)
+                    # Process with image
+                    inputs = processor(
+                        images=image,
+                        text=input_text,
+                        add_special_tokens=False,
+                        return_tensors="pt"
+                    ).to(model.device)
 
-                        with col2:
-                            st.subheader("Simple Explanation")
-                            st.write(non_technical)
-                    else:
-                        st.subheader("Analysis Result")
-                        st.write(result)
-
-                    # Free memory after analysis
-                    free_memory()
-
-                except Exception as e:
-                    st.error(f"Error analyzing image: {str(e)}")
-                    st.exception(e)
-elif not model_loaded:
-    st.warning("Please load the model first by clicking the 'Load Model' button at the top of the page.")
-else:
-    st.info("Please upload an image to begin analysis")
+                    # Apply the fix
+                    fixed, inputs = fix_processor_outputs(inputs)
+                    if fixed:
+                        col2.info("Fixed cross-attention mask dimensions")
+
+                    # Generate analysis
+                    with torch.no_grad():
+                        output_ids = model.generate(
+                            **inputs,
+                            max_new_tokens=max_length,
+                            temperature=temperature,
+                            top_p=0.9
+                        )
+
+                    # Decode the output
+                    response = processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+                    # Display results
+                    col2.success("Analysis complete!")
+                    col2.markdown(response)
+
+                    # Free memory
+                    free_memory()
+
+                except Exception as e:
+                    col2.error(f"Error analyzing image: {str(e)}")
+                    col2.exception(e)
+    elif not model_loaded:
+        st.info("Please load the model first (Step 1)")
+    else:
+        st.info("Please upload an image (Step 2)")
 
-# Add footer
-st.markdown("---")
-st.markdown("Deepfake Image Analyzer")
+with col2:
+    if 'model' not in st.session_state:
+        st.info("👈 Follow the steps on the left to analyze an image")