saakshigupta committed (verified)
Commit 9938609 · Parent(s): 34b363c

Update app.py

Files changed (1): app.py (+44 -36)
app.py CHANGED

@@ -1,10 +1,9 @@
 import streamlit as st
 import torch
-import os
 from PIL import Image
-from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
 from peft import PeftModel
 import gc
+import os
 
 # Page config
 st.set_page_config(
@@ -38,35 +37,30 @@ device = init_device()
 
 @st.cache_resource
 def load_model():
-    """Load model and processor with proper dtype settings"""
+    """Load model using Unsloth, similar to your notebook code"""
     try:
-        # Load base model
-        base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"
+        # Import Unsloth here to ensure it's loaded when needed
+        from unsloth import FastVisionModel
 
-        # Load processor first
-        processor = AutoProcessor.from_pretrained(base_model_id)
+        st.info("Loading base model and tokenizer using Unsloth...")
 
-        # Configure quantization explicitly with float16
-        quantization_config = BitsAndBytesConfig(
+        # Use the same model ID and loading approach that worked in your notebook
+        base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"
+        model, tokenizer = FastVisionModel.from_pretrained(
+            base_model_id,
             load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_use_double_quant=True
+            torch_dtype=torch.float16,
         )
 
-        # Load model with explicit dtype settings
-        model = AutoModelForCausalLM.from_pretrained(
-            base_model_id,
-            device_map="auto",
-            torch_dtype=torch.float16,  # Explicit float16
-            quantization_config=quantization_config
-        )
+        # Set to inference mode
+        FastVisionModel.for_inference(model)
 
-        # Load adapter
+        # Load the fine-tuned adapter
+        st.info("Loading adapter...")
         adapter_id = "saakshigupta/deepfake-explainer-1"
         model = PeftModel.from_pretrained(model, adapter_id)
 
-        return model, processor
+        return model, tokenizer
 
     except Exception as e:
         st.error(f"Error loading model: {str(e)}")
@@ -117,10 +111,10 @@ with st.sidebar:
     # Load model on startup
     with st.spinner("Loading model... this may take a minute."):
         try:
-            model, processor = load_model()
-            if model is not None and processor is not None:
+            model, tokenizer = load_model()
+            if model is not None and tokenizer is not None:
                 st.session_state['model'] = model
-                st.session_state['processor'] = processor
+                st.session_state['tokenizer'] = tokenizer
                 st.success("Model loaded successfully!")
             else:
                 st.error("Failed to load model.")
@@ -145,20 +139,33 @@ if uploaded_file is not None and model_loaded:
     try:
         # Get components from session state
         model = st.session_state['model']
-        processor = st.session_state['processor']
+        tokenizer = st.session_state['tokenizer']
+
+        # Format the message for Unsloth - same as your notebook
+        messages = [
+            {"role": "user", "content": [
+                {"type": "image"},
+                {"type": "text", "text": custom_prompt}
+            ]}
+        ]
 
-        # Process the image
-        inputs = processor(text=custom_prompt, images=image, return_tensors="pt")
+        # Apply chat template
+        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
 
-        # Fix cross-attention mask
+        # Process with image
+        inputs = tokenizer(
+            image,
+            input_text,
+            add_special_tokens=False,
+            return_tensors="pt",
+        ).to(model.device)
+
+        # Apply the cross-attention fix
        fixed, inputs = fix_processor_outputs(inputs)
         if fixed:
             st.info("Fixed cross-attention mask dimensions")
 
-        # Move to device
-        inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
-
-        # Generate the analysis
+        # Generate analysis
         with torch.no_grad():
             output_ids = model.generate(
                 **inputs,
@@ -168,11 +175,12 @@ if uploaded_file is not None and model_loaded:
             )
 
         # Decode the output
-        response = processor.decode(output_ids[0], skip_special_tokens=True)
+        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
-        # Extract the actual response (removing the prompt)
-        if custom_prompt in response:
-            result = response.split(custom_prompt)[-1].strip()
+        # Extract the model's response
+        # Format might be different from processor.decode, check the output
+        if "assistant" in response:
+            result = response.split("assistant")[-1].strip()
         else:
             result = response
 
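Note: both the old and new processing paths call fix_processor_outputs(), whose definition sits outside the changed hunks and is not shown in this diff. The sketch below is a hypothetical reconstruction of what such a helper typically does for Llama 3.2 Vision, padding the cross-attention mask to the 4-D shape (batch, seq_len, num_images, num_tiles) that generate() expects; the function name matches app.py, but the body is an assumption, not the committed implementation.

def fix_processor_outputs(inputs):
    # Hypothetical reconstruction; the committed helper is not shown in this diff.
    # Llama 3.2 Vision expects cross_attention_mask with shape
    # (batch, seq_len, num_images, num_tiles); add the tiles axis if it is missing.
    mask = inputs.get("cross_attention_mask")
    if mask is not None and mask.dim() == 3:
        inputs["cross_attention_mask"] = mask.unsqueeze(-1)
        return True, inputs  # matches the (fixed, inputs) contract used in app.py
    return False, inputs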
 
 
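For reference, the new load-and-generate path condenses to the script below when run outside Streamlit. It is a sketch assembled from the committed code: the model and adapter IDs are the ones in the diff, while the prompt text, the sample.jpg path, and max_new_tokens are placeholders (the real generation arguments fall outside the changed hunks).

import torch
from PIL import Image
from unsloth import FastVisionModel
from peft import PeftModel

# Load the 4-bit base model and its tokenizer/processor, then switch to inference mode
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit",
    load_in_4bit=True,
    torch_dtype=torch.float16,
)
FastVisionModel.for_inference(model)

# Attach the fine-tuned deepfake-explainer adapter
model = PeftModel.from_pretrained(model, "saakshigupta/deepfake-explainer-1")

image = Image.open("sample.jpg")  # placeholder path
messages = [{"role": "user", "content": [
    {"type": "image"},
    {"type": "text", "text": "Is this image a deepfake? Explain."},  # placeholder prompt
]}]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
inputs = tokenizer(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=256)  # placeholder settings
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))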
 
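One caveat the new code flags itself ("Format might be different from processor.decode, check the output"): splitting the decoded string on the literal "assistant" is brittle, since that word can also appear inside the model's answer. A sturdier variant, assuming inputs and output_ids as produced above, decodes only the newly generated tokens:

# Skip the prompt tokens instead of string-splitting on "assistant"
prompt_len = inputs["input_ids"].shape[1]
result = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()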