saakshigupta committed on
Commit
fa29b79
·
verified ·
1 Parent(s): b89d72c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -94
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import streamlit as st
2
  import torch
3
  from PIL import Image
4
- from peft import PeftModel
5
- import gc
6
  import os
 
7
 
8
  # Page config
9
  st.set_page_config(
@@ -39,7 +38,8 @@ device = init_device()
39
  def load_model():
40
  """Load model using Unsloth, similar to your notebook code"""
41
  try:
42
- # Import Unsloth here to ensure it's loaded when needed
 
43
  from unsloth import FastVisionModel
44
 
45
  st.info("Loading base model and tokenizer using Unsloth...")
@@ -108,19 +108,20 @@ with st.sidebar:
108
  Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)
109
  """)
110
 
111
- # Load model on startup
112
- with st.spinner("Loading model... this may take a minute."):
113
- try:
114
- model, tokenizer = load_model()
115
- if model is not None and tokenizer is not None:
116
- st.session_state['model'] = model
117
- st.session_state['tokenizer'] = tokenizer
118
- st.success("Model loaded successfully!")
119
- else:
120
- st.error("Failed to load model.")
121
- except Exception as e:
122
- st.error(f"Error during model loading: {str(e)}")
123
- st.exception(e)
 
124
 
125
  # Main content area - file uploader
126
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
@@ -128,90 +129,92 @@ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png
128
  # Check if model is loaded
129
  model_loaded = 'model' in st.session_state and st.session_state['model'] is not None
130
 
131
- if uploaded_file is not None and model_loaded:
132
  # Display the image
133
  image = Image.open(uploaded_file).convert('RGB')
134
  st.image(image, caption="Uploaded Image", use_column_width=True)
135
 
136
- # Analyze button
137
- if st.button("Analyze Image"):
138
- with st.spinner("Analyzing the image... This may take 15-30 seconds"):
139
- try:
140
- # Get components from session state
141
- model = st.session_state['model']
142
- tokenizer = st.session_state['tokenizer']
143
-
144
- # Format the message for Unsloth - same as your notebook
145
- messages = [
146
- {"role": "user", "content": [
147
- {"type": "image"},
148
- {"type": "text", "text": custom_prompt}
149
- ]}
150
- ]
151
-
152
- # Apply chat template
153
- input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
154
-
155
- # Process with image
156
- inputs = tokenizer(
157
- image,
158
- input_text,
159
- add_special_tokens=False,
160
- return_tensors="pt",
161
- ).to(model.device)
162
-
163
- # Apply the cross-attention fix
164
- fixed, inputs = fix_processor_outputs(inputs)
165
- if fixed:
166
- st.info("Fixed cross-attention mask dimensions")
167
-
168
- # Generate analysis
169
- with torch.no_grad():
170
- output_ids = model.generate(
171
- **inputs,
172
- max_new_tokens=max_length,
173
- temperature=temperature,
174
- top_p=0.9
175
- )
176
-
177
- # Decode the output
178
- response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
179
-
180
- # Extract the model's response
181
- # Format might be different from processor.decode, check the output
182
- if "assistant" in response:
183
- result = response.split("assistant")[-1].strip()
184
- else:
185
- result = response
186
-
187
- # Display result in a nice format
188
- st.success("Analysis complete!")
189
-
190
- # Show technical and non-technical explanations separately if they exist
191
- if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
192
- technical, non_technical = result.split("Non-Technical Explanation:")
193
- technical = technical.replace("Technical Explanation:", "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- col1, col2 = st.columns(2)
196
- with col1:
197
- st.subheader("Technical Analysis")
198
- st.write(technical)
199
 
200
- with col2:
201
- st.subheader("Simple Explanation")
202
- st.write(non_technical)
203
- else:
204
- st.subheader("Analysis Result")
205
- st.write(result)
206
-
207
- # Free memory after analysis
208
- free_memory()
209
-
210
- except Exception as e:
211
- st.error(f"Error analyzing image: {str(e)}")
212
- st.exception(e)
213
- elif not model_loaded and uploaded_file is not None:
214
- st.warning("Model not loaded correctly. Try refreshing the page.")
215
  else:
216
  st.info("Please upload an image to begin analysis")
217
 
 
1
  import streamlit as st
2
  import torch
3
  from PIL import Image
 
 
4
  import os
5
+ import gc
6
 
7
  # Page config
8
  st.set_page_config(
 
38
  def load_model():
39
  """Load model using Unsloth, similar to your notebook code"""
40
  try:
41
+ # Import libraries here to ensure they're loaded when needed
42
+ from peft import PeftModel
43
  from unsloth import FastVisionModel
44
 
45
  st.info("Loading base model and tokenizer using Unsloth...")
 
108
  Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)
109
  """)
110
 
111
+ # Load model button
112
+ if st.button("Load Model"):
113
+ with st.spinner("Loading model... this may take a minute."):
114
+ try:
115
+ model, tokenizer = load_model()
116
+ if model is not None and tokenizer is not None:
117
+ st.session_state['model'] = model
118
+ st.session_state['tokenizer'] = tokenizer
119
+ st.success("Model loaded successfully!")
120
+ else:
121
+ st.error("Failed to load model.")
122
+ except Exception as e:
123
+ st.error(f"Error during model loading: {str(e)}")
124
+ st.exception(e)
125
 
126
  # Main content area - file uploader
127
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
129
  # Check if model is loaded
130
  model_loaded = 'model' in st.session_state and st.session_state['model'] is not None
131
 
132
+ if uploaded_file is not None:
133
  # Display the image
134
  image = Image.open(uploaded_file).convert('RGB')
135
  st.image(image, caption="Uploaded Image", use_column_width=True)
136
 
137
+ # Analyze button (only enabled if model is loaded)
138
+ if st.button("Analyze Image", disabled=not model_loaded):
139
+ if not model_loaded:
140
+ st.warning("Please load the model first by clicking the 'Load Model' button.")
141
+ else:
142
+ with st.spinner("Analyzing the image... This may take 15-30 seconds"):
143
+ try:
144
+ # Get components from session state
145
+ model = st.session_state['model']
146
+ tokenizer = st.session_state['tokenizer']
147
+
148
+ # Format the message for Unsloth - same as your notebook
149
+ messages = [
150
+ {"role": "user", "content": [
151
+ {"type": "image"},
152
+ {"type": "text", "text": custom_prompt}
153
+ ]}
154
+ ]
155
+
156
+ # Apply chat template
157
+ input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
158
+
159
+ # Process with image
160
+ inputs = tokenizer(
161
+ image,
162
+ input_text,
163
+ add_special_tokens=False,
164
+ return_tensors="pt",
165
+ ).to(model.device)
166
+
167
+ # Apply the cross-attention fix
168
+ fixed, inputs = fix_processor_outputs(inputs)
169
+ if fixed:
170
+ st.info("Fixed cross-attention mask dimensions")
171
+
172
+ # Generate analysis
173
+ with torch.no_grad():
174
+ output_ids = model.generate(
175
+ **inputs,
176
+ max_new_tokens=max_length,
177
+ temperature=temperature,
178
+ top_p=0.9
179
+ )
180
+
181
+ # Decode the output
182
+ response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
183
+
184
+ # Extract the model's response
185
+ if "assistant" in response:
186
+ result = response.split("assistant")[-1].strip()
187
+ else:
188
+ result = response
189
+
190
+ # Display result in a nice format
191
+ st.success("Analysis complete!")
192
+
193
+ # Show technical and non-technical explanations separately if they exist
194
+ if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
195
+ technical, non_technical = result.split("Non-Technical Explanation:")
196
+ technical = technical.replace("Technical Explanation:", "").strip()
197
+
198
+ col1, col2 = st.columns(2)
199
+ with col1:
200
+ st.subheader("Technical Analysis")
201
+ st.write(technical)
202
+
203
+ with col2:
204
+ st.subheader("Simple Explanation")
205
+ st.write(non_technical)
206
+ else:
207
+ st.subheader("Analysis Result")
208
+ st.write(result)
209
 
210
+ # Free memory after analysis
211
+ free_memory()
 
 
212
 
213
+ except Exception as e:
214
+ st.error(f"Error analyzing image: {str(e)}")
215
+ st.exception(e)
216
+ elif not model_loaded:
217
+ st.warning("Please load the model first by clicking the 'Load Model' button at the top of the page.")
 
 
 
 
 
 
 
 
 
 
218
  else:
219
  st.info("Please upload an image to begin analysis")
220