saakshigupta committed
Commit 42fa481 · verified · 1 Parent(s): 8027591

Update app.py

Files changed (1):
  1. app.py +28 -32
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
 import torch
 import os
 from PIL import Image
-from transformers import AutoProcessor, MllamaForCausalLM, BitsAndBytesConfig
+from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
 from peft import PeftModel
 import gc
 
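A note on the class swap: `AutoModelForCausalLM` resolves the concrete implementation from the checkpoint's `config.json` instead of hard-coding `MllamaForCausalLM`. Which `Auto*` factory covers a given vision checkpoint varies by transformers version, so a cheap probe before loading weights can be useful; a minimal sketch, using the model ID from the diff:

```python
from transformers import AutoConfig

base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"

# AutoConfig only fetches config.json, so this shows which architecture the
# Auto* factories will try to instantiate without downloading any weights.
config = AutoConfig.from_pretrained(base_model_id)
print(config.model_type)      # architecture family recorded in the config
print(config.architectures)   # concrete class names, if the checkpoint lists them
```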
@@ -22,9 +22,8 @@ def free_memory():
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
 
-# Helper functions
+# Helper function to check CUDA
 def init_device():
     """Set the appropriate device and return it"""
     if torch.cuda.is_available():
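Dropping `torch.cuda.ipc_collect()` is reasonable here: it reclaims CUDA memory held through inter-process (IPC) handles, which a single-process Streamlit app never creates, while `gc.collect()` plus `torch.cuda.empty_cache()` covers the common case. For context, a self-contained sketch of the resulting helpers; the body of `init_device()` beyond the CUDA check is outside the diff, so the CPU fallback is an assumption:

```python
import gc
import torch

def free_memory():
    """Drop Python references first, then return cached CUDA blocks to the driver."""
    gc.collect()                   # collect unreachable tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # release PyTorch's caching-allocator pool

def init_device():
    """Set the appropriate device and return it."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")     # assumed fallback; not visible in the diff
```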
@@ -39,13 +38,15 @@ device = init_device()
 
 @st.cache_resource
 def load_model():
-    """Load model and processor with caching to avoid reloading"""
+    """Load model and processor with proper dtype settings"""
     try:
         # Load base model
         base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"
+
+        # Load processor first
         processor = AutoProcessor.from_pretrained(base_model_id)
 
-        # Configure 4-bit quantization with correct dtype
+        # Configure quantization explicitly with float16
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
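`@st.cache_resource` memoizes `load_model()` for the lifetime of the server process, so reruns and new sessions reuse the same weights. The quantization settings in isolation, for reference; `bnb_4bit_quant_type` falls on a diff line that is not shown between the hunks, so the `"nf4"` value below is an assumption (it is the usual companion to double quantization):

```python
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store linear-layer weights in 4 bits
    bnb_4bit_compute_dtype=torch.float16,  # dequantize to fp16 for the matmuls
    bnb_4bit_quant_type="nf4",             # assumed; not visible in the diff
    bnb_4bit_use_double_quant=True,        # also quantize the scaling constants
)
```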
@@ -53,11 +54,11 @@ def load_model():
             bnb_4bit_use_double_quant=True
         )
 
-        # Load model with explicit dtype settings using MllamaForCausalLM
-        model = MllamaForCausalLM.from_pretrained(
+        # Load model with explicit dtype settings
+        model = AutoModelForCausalLM.from_pretrained(
             base_model_id,
             device_map="auto",
-            torch_dtype=torch.float16,
+            torch_dtype=torch.float16,  # Explicit float16
             quantization_config=quantization_config
         )
 
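With `device_map="auto"`, accelerate decides layer placement, and `torch_dtype=torch.float16` applies to the modules bitsandbytes leaves unquantized (norms, embeddings). Continuing from the `model` built in this hunk, a short sketch for verifying what actually loaded; both attributes are standard on transformers models loaded this way:

```python
# `model` as produced by AutoModelForCausalLM.from_pretrained(...) above.
# hf_device_map records the placement accelerate chose under device_map="auto".
print(model.hf_device_map)

# Rough size of the materialized weights; with 4-bit quantization an 11B model
# should land well under its fp16 footprint.
print(f"~{model.get_memory_footprint() / 1e9:.1f} GB")
```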
@@ -66,9 +67,10 @@ def load_model():
         model = PeftModel.from_pretrained(model, adapter_id)
 
         return model, processor
-
+
     except Exception as e:
         st.error(f"Error loading model: {str(e)}")
+        st.exception(e)
         return None, None
 
 # Function to fix cross-attention masks
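The `adapter_id` assignment sits between the hunks and is not shown; the sidebar credits `saakshigupta/deepfake-explainer-1`, so that repo is assumed in the sketch below. The added `st.exception(e)` renders the full traceback in the page rather than only the message string. Continuing from the `model` above:

```python
from peft import PeftModel

# Assumed from the sidebar link; the actual adapter_id line is outside the diff.
adapter_id = "saakshigupta/deepfake-explainer-1"

# Wrap the 4-bit base model with the LoRA adapter; adapter weights stay in
# higher precision and are applied on top of the frozen quantized layers.
model = PeftModel.from_pretrained(model, adapter_id)
```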
@@ -112,34 +114,25 @@ with st.sidebar:
     Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)
     """)
 
-    # Load model on app startup with a progress bar
-    if 'model_loaded' not in st.session_state:
-        progress_bar = st.progress(0)
-        st.info("Loading model... this may take a minute.")
-
-        for i in range(10):
-            # Simulate progress while model loads
-            progress_bar.progress((i + 1) * 10)
-            if i == 2:
-                # Start loading the model at 30% progress
-                model, processor = load_model()
-                if model is not None:
-                    st.session_state['model'] = model
-                    st.session_state['processor'] = processor
-                    st.session_state['model_loaded'] = True
-
-        progress_bar.empty()
-
-    if 'model_loaded' in st.session_state and st.session_state['model_loaded']:
-        st.success("Model loaded successfully!")
-    else:
-        st.error("Failed to load model. Try refreshing the page.")
+    # Load model on startup
+    with st.spinner("Loading model... this may take a minute."):
+        try:
+            model, processor = load_model()
+            if model is not None and processor is not None:
+                st.session_state['model'] = model
+                st.session_state['processor'] = processor
+                st.success("Model loaded successfully!")
+            else:
+                st.error("Failed to load model.")
+        except Exception as e:
+            st.error(f"Error during model loading: {str(e)}")
+            st.exception(e)
 
 # Main content area - file uploader
 uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
 # Check if model is loaded
-model_loaded = 'model_loaded' in st.session_state and st.session_state['model_loaded']
+model_loaded = 'model' in st.session_state and st.session_state['model'] is not None
 
 if uploaded_file is not None and model_loaded:
     # Display the image
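This is the substantive fix in the commit: the old code simulated progress percentages, only invoked `load_model()` when the loop counter hit 2, and tracked readiness in a separate `'model_loaded'` boolean that could disagree with reality after a failed load. The rewrite blocks on a spinner and derives readiness from the cached object itself. The pattern in isolation:

```python
import streamlit as st

@st.cache_resource           # loaded once per server process, reused on rerun
def load_model():
    ...                      # heavy model/processor construction goes here

with st.spinner("Loading model... this may take a minute."):
    model = load_model()     # blocks this script run; no fake progress ticks

# Readiness is derived from the object, not from a separate flag that can
# go stale when loading fails.
model_loaded = model is not None
```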
@@ -208,6 +201,9 @@ if uploaded_file is not None and model_loaded:
 
     except Exception as e:
         st.error(f"Error analyzing image: {str(e)}")
+        st.exception(e)
+elif not model_loaded and uploaded_file is not None:
+    st.warning("Model not loaded correctly. Try refreshing the page.")
 else:
     st.info("Please upload an image to begin analysis")
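One simplification the new branch invites: by the time the `elif` runs, the first condition has already failed, so whenever `uploaded_file is not None` holds there, `model_loaded` must be false and the explicit `not model_loaded` test is redundant. An equivalent, slightly tighter form:

```python
if uploaded_file is not None and model_loaded:
    ...  # analysis path
elif uploaded_file is not None:
    # The first branch failed, so the model must be missing here.
    st.warning("Model not loaded correctly. Try refreshing the page.")
else:
    st.info("Please upload an image to begin analysis")
```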
209