saakshigupta committed on
Commit
6e58aae
·
verified ·
1 Parent(s): eab7cdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -149
app.py CHANGED
@@ -1,174 +1,128 @@
1
  import streamlit as st
2
  import torch
3
  from PIL import Image
4
- import gc
5
- from transformers import AutoProcessor
6
  from peft import PeftModel
7
  from unsloth import FastVisionModel
 
 
8
 
9
- # Simple page config
10
  st.set_page_config(page_title="Deepfake Analyzer", layout="wide")
11
-
12
- # Minimal UI
13
  st.title("Deepfake Image Analyzer")
14
- st.markdown("This app analyzes images for signs of deepfake manipulation")
15
-
16
- # Function to free up memory
17
- def free_memory():
18
- gc.collect()
19
- if torch.cuda.is_available():
20
- torch.cuda.empty_cache()
21
 
22
  # Function to fix cross-attention masks
23
- def fix_processor_outputs(inputs):
24
- """Fix cross-attention mask dimensions if needed"""
25
  if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
26
  batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
27
- visual_features = 6404 # The exact dimension used in training
28
- new_mask = torch.ones(
29
- (batch_size, seq_len, visual_features, num_tiles),
30
- device=inputs['cross_attention_mask'].device
31
- )
32
  inputs['cross_attention_mask'] = new_mask
33
- return True, inputs
34
- return False, inputs
35
 
36
  # Load model function
37
  @st.cache_resource
38
  def load_model():
39
- """Load model using Unsloth approach (similar to Colab)"""
40
- try:
41
- base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"
42
-
43
- # Load processor
44
- processor = AutoProcessor.from_pretrained(base_model_id)
45
-
46
- # Load model using Unsloth's FastVisionModel
47
- model, _ = FastVisionModel.from_pretrained(
48
- base_model_id,
49
- load_in_4bit=True,
50
- torch_dtype=torch.float16,
51
- device_map="auto"
52
- )
53
-
54
- # Set to inference mode
55
- FastVisionModel.for_inference(model)
56
-
57
- # Load adapter
58
- adapter_id = "saakshigupta/deepfake-explainer-1"
59
- model = PeftModel.from_pretrained(model, adapter_id)
60
-
61
- return model, processor
62
-
63
- except Exception as e:
64
- st.error(f"Error loading model: {str(e)}")
65
- st.exception(e)
66
- return None, None
67
 
68
- # Minimal sidebar
69
- with st.sidebar:
70
- st.header("Settings")
71
- temperature = st.slider("Temperature", 0.1, 1.0, 0.7, 0.1)
72
- max_length = st.slider("Max length", 100, 500, 300, 50)
73
-
74
- # Instruction field
75
- prompt = st.text_area(
76
- "Analysis instruction",
77
- value="Analyze this image and determine if it's a deepfake. Provide your reasoning.",
78
- height=100
79
- )
80
 
81
- # Main content - two columns for clarity
82
- col1, col2 = st.columns([1, 2])
 
 
 
 
 
83
 
84
- with col1:
85
- # Load model button
86
- if st.button("1. Load Model"):
87
- with st.spinner("Loading model... (this may take a minute)"):
88
- model, processor = load_model()
89
- if model is not None and processor is not None:
90
- st.session_state['model'] = model
91
- st.session_state['processor'] = processor
92
- st.success("✓ Model loaded successfully!")
93
- else:
94
- st.error("Failed to load model")
95
-
96
- # File uploader
97
- uploaded_file = st.file_uploader("2. Upload an image", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- # Display uploaded image
100
- if uploaded_file is not None:
101
- image = Image.open(uploaded_file).convert('RGB')
102
- st.image(image, caption="Uploaded Image", use_column_width=True)
 
 
103
 
104
- # Only enable analysis if model is loaded
105
- model_loaded = 'model' in st.session_state and st.session_state['model'] is not None
 
106
 
107
- if st.button("3. Analyze Image", disabled=not model_loaded):
108
- if not model_loaded:
109
- st.warning("Please load the model first")
110
- else:
111
- col2.subheader("Analysis Results")
112
- with col2.spinner("Analyzing image..."):
113
- try:
114
- # Get model components
115
- model = st.session_state['model']
116
- processor = st.session_state['processor']
117
-
118
- # Format message for analysis
119
- messages = [
120
- {"role": "user", "content": [
121
- {"type": "image"},
122
- {"type": "text", "text": prompt}
123
- ]}
124
- ]
125
-
126
- # Apply chat template
127
- input_text = processor.tokenizer.apply_chat_template(
128
- messages,
129
- add_generation_prompt=True
130
- )
131
-
132
- # Process with image
133
- inputs = processor(
134
- images=image,
135
- text=input_text,
136
- add_special_tokens=False,
137
- return_tensors="pt"
138
- ).to(model.device)
139
-
140
- # Apply the fix
141
- fixed, inputs = fix_processor_outputs(inputs)
142
- if fixed:
143
- col2.info("Fixed cross-attention mask dimensions")
144
-
145
- # Generate analysis
146
- with torch.no_grad():
147
- output_ids = model.generate(
148
- **inputs,
149
- max_new_tokens=max_length,
150
- temperature=temperature,
151
- top_p=0.9
152
- )
153
-
154
- # Decode the output
155
- response = processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)
156
-
157
- # Display results
158
- col2.success("Analysis complete!")
159
- col2.markdown(response)
160
-
161
- # Free memory
162
- free_memory()
163
-
164
- except Exception as e:
165
- col2.error(f"Error analyzing image: {str(e)}")
166
- col2.exception(e)
167
- elif not model_loaded:
168
- st.info("Please load the model first (Step 1)")
169
  else:
170
- st.info("Please upload an image (Step 2)")
171
 
172
- with col2:
173
- if 'model' not in st.session_state:
174
- st.info("👈 Follow the steps on the left to analyze an image")
 
1
  import streamlit as st
2
  import torch
3
  from PIL import Image
4
+ import io
 
5
  from peft import PeftModel
6
  from unsloth import FastVisionModel
7
+ import tempfile
8
+ import os
9
 
10
+ # App title and description
11
  st.set_page_config(page_title="Deepfake Analyzer", layout="wide")
 
 
12
  st.title("Deepfake Image Analyzer")
13
+ st.markdown("Upload an image to analyze it for potential deepfake manipulation")
 
 
 
 
 
 
14
 
15
  # Function to fix cross-attention masks
16
def fix_cross_attention_mask(inputs, visual_features=6404):
    """Replace a degenerate cross-attention mask with an all-ones mask.

    When ``inputs`` contains a ``cross_attention_mask`` whose shape has a
    zero-sized dimension, the mask is replaced by a tensor of ones shaped
    ``(batch_size, seq_len, visual_features, num_tiles)`` on the same device
    and with the same dtype as the mask it replaces. Otherwise ``inputs`` is
    left untouched.

    Args:
        inputs: Mapping-like processor output; updated in place when a fix
            is applied.
        visual_features: Size of the third dimension of the rebuilt mask.
            The default of 6404 is the value the original code hard-coded.
            NOTE(review): presumably tied to the fine-tuning setup; confirm
            before changing it.

    Returns:
        The same ``inputs`` object that was passed in.

    Raises:
        ValueError: If the degenerate mask does not have exactly four
            dimensions (the shape unpacking fails).
    """
    if 'cross_attention_mask' not in inputs:
        return inputs

    mask = inputs['cross_attention_mask']
    if 0 not in mask.shape:
        # Mask is already usable; nothing to repair.
        return inputs

    batch_size, seq_len, _, num_tiles = mask.shape
    # Keep the processor's dtype and device so the rebuilt mask is a drop-in
    # replacement rather than silently switching to the default float dtype.
    inputs['cross_attention_mask'] = torch.ones(
        (batch_size, seq_len, visual_features, num_tiles),
        dtype=mask.dtype,
        device=mask.device,
    )
    st.success("Fixed cross-attention mask dimensions")
    return inputs
25
 
26
  # Load model function
27
@st.cache_resource
def _load_model_cached():
    """Load the base vision model plus the deepfake adapter (cached).

    Exceptions are deliberately allowed to propagate: a call that raises
    stores nothing in the resource cache, so a failed load is retried on the
    next rerun instead of pinning a ``(None, None)`` result in the cache.
    """
    # Load base model and tokenizer using Unsloth
    base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
    model, tokenizer = FastVisionModel.from_pretrained(
        base_model_id,
        load_in_4bit=True,
    )

    # Load the adapter
    adapter_id = "saakshigupta/deepfake-explainer-1"
    model = PeftModel.from_pretrained(model, adapter_id)

    # Set to inference mode
    FastVisionModel.for_inference(model)

    return model, tokenizer


def load_model():
    """Return ``(model, tokenizer)`` for the app, or ``(None, None)`` on failure.

    Successful loads are served from Streamlit's resource cache. Failures are
    reported in the UI and logged to the console, and are not cached, so a
    transient error (network, out-of-memory) does not require restarting the
    app to retry.

    Returns:
        A ``(model, tokenizer)`` tuple, or ``(None, None)`` if loading raised.
    """
    import logging

    with st.spinner("Loading model... This may take a minute or two..."):
        try:
            return _load_model_cached()
        except Exception as e:
            # UI boundary: surface the error to the user and keep the full
            # traceback in the console log instead of swallowing it.
            logging.getLogger(__name__).exception("Error loading model")
            st.error(f"Error loading model: {str(e)}")
            return None, None


# Keep the cache-clearing hook that the decorated function used to expose.
load_model.clear = _load_model_cached.clear
49
 
50
+ # Analyze image function
51
def analyze_image(image, question, model, tokenizer):
    """Ask the model ``question`` about ``image`` and return only its answer.

    Builds a single-turn chat message with one image and the question text,
    runs generation, and decodes just the newly generated tokens. Slicing off
    the prompt tokens (rather than searching the decoded text for the
    question) keeps the chat-template role header out of the result and stays
    correct even if the answer repeats the question verbatim.

    Args:
        image: Image to analyze (the caller passes an RGB ``PIL.Image``).
        question: Instruction or question text sent with the image.
        model: Loaded model exposing ``generate`` and ``device``.
        tokenizer: Object returned alongside the model; used for the chat
            template, for encoding the image and text, and for decoding.

    Returns:
        The model's answer as a stripped string.
    """
    # Format the message
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": question},
        ]}
    ]

    # Apply chat template
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    # Process with image
    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    # Fix cross-attention mask if needed
    inputs = fix_cross_attention_mask(inputs)

    # Remember where the prompt ends so only the completion is decoded.
    prompt_length = inputs["input_ids"].shape[-1]

    # Generate response
    # NOTE(review): temperature/top_p only influence the output when sampling
    # is enabled; confirm the model's generation config sets do_sample.
    with st.spinner("Analyzing image... (this may take a moment)"):
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                use_cache=True,
                temperature=0.7,
                top_p=0.9,
            )

    # The returned sequence starts with the prompt tokens; drop them.
    generated_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
95
+
96
+ # Main app
97
def main():
    """Render the app: load the model, accept an image, show the analysis."""
    model, tokenizer = load_model()

    # Without a model there is nothing to offer; warn and stop rendering.
    if model is None or tokenizer is None:
        st.warning("Failed to load the model. Please check the console for errors.")
        return

    st.success("✅ Model loaded successfully! You can now analyze images.")

    # Image upload section
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Prompt box, pre-filled with the default question
    question = st.text_area(
        "Question/Prompt:",
        value="Analyze this image and tell me if it's a deepfake. Provide both technical and non-technical explanations.",
        height=100,
    )

    if uploaded_file is None:
        return

    # Preview the upload before offering the analysis button
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if not st.button("Analyze Image"):
        return

    result = analyze_image(image, question, model, tokenizer)

    # Display results
    st.subheader("Analysis Results")
    st.markdown(result)


if __name__ == "__main__":
    main()