import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy
# Setup logging
logging.basicConfig(level=logging.INFO)
class LLaVAPhiModel:
    def __init__(self, model_id="microsoft/phi-1_5"):  # Updated to match config
        self.device = "cuda"
        self.model_id = model_id
        logging.info(f"Initializing LLaVA-Phi model with {model_id}...")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        try:
            # Use CLIPProcessor with the correct model name from config
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            logging.info("Successfully loaded CLIP processor")
        except Exception as e:
            logging.error(f"Failed to load CLIP processor: {str(e)}")
            self.processor = None

        # Conversation history kept for building context on follow-up turns
        self.history = []
        self.model = None
        self.clip = None

        # Default generation parameters - can be updated from config
        self.temperature = 0.3
        self.top_p = 0.92
        self.top_k = 50
        self.repetition_penalty = 1.2

        # Set max length from config
        self.max_length = 512  # Default value, will be updated from config

    @spaces.GPU
    def ensure_models_loaded(self):
        """Ensure models are loaded in GPU context"""
        if self.model is None:
            # Use 4-bit quantization according to config
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,  # Changed to match config
                bnb_4bit_compute_dtype=torch.bfloat16,  # Changed to bfloat16 to match config's mixed_precision
                bnb_4bit_use_double_quant=False
            )
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True
                )
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
                logging.info(f"Successfully loaded main model: {self.model_id}")
            except Exception as e:
                logging.error(f"Failed to load main model: {str(e)}")
                raise

        if self.clip is None:
            try:
                # Load CLIP model from config
                clip_model_name = "openai/clip-vit-base-patch32"  # From config
                self.clip = CLIPModel.from_pretrained(clip_model_name).to(self.device)
                logging.info(f"Successfully loaded CLIP model: {clip_model_name}")
            except Exception as e:
                logging.error(f"Failed to load CLIP model: {str(e)}")
                self.clip = None

    def apply_lora_config(self, lora_params):
        """Apply LoRA configuration to the model - to be called during training"""
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=lora_params.get("r", 16),
            lora_alpha=lora_params.get("lora_alpha", 32),
            lora_dropout=lora_params.get("lora_dropout", 0.05),
            target_modules=lora_params.get("target_modules", ["Wqkv", "out_proj"]),
            bias="none",
            task_type="CAUSAL_LM"
        )

        # Convert model to PEFT/LoRA model
        self.model = get_peft_model(self.model, lora_config)
        logging.info("Applied LoRA configuration to the model")
        return self.model
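
    # Minimal usage sketch for apply_lora_config(); the parameter values below are
    # illustrative assumptions, not values taken from this project's config:
    #
    #   model = LLaVAPhiModel()
    #   model.ensure_models_loaded()
    #   model.apply_lora_config({"r": 16, "lora_alpha": 32, "lora_dropout": 0.05})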

    @spaces.GPU(duration=120)
    def generate_response(self, message, image=None):
        try:
            self.ensure_models_loaded()

            # Prepare prompt based on whether we have an image
            has_image = image is not None

            # Process text input
            if has_image:
                # For image+text input
                prompt = f"human: <image>\n{message}\ngpt:"
                # Check if model has vision encoding capability
                if not hasattr(self.model, "encode_image") and not hasattr(self.model, "get_vision_tower"):
                    logging.warning("Model doesn't have standard image encoding methods")
                    has_image = False
                    prompt = f"human: {message}\ngpt:"
            else:
                # For text-only input
                prompt = f"human: {message}\ngpt:"

            # Include previous conversation context
            context = ""
            for turn in self.history[-5:]:  # Include 5 previous turns
                context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
            full_prompt = context + prompt

            # Tokenize the input text
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # LLaVA-Phi specific image handling
            if has_image:
                try:
                    # Convert image to correct format
                    if isinstance(image, str):
                        image = Image.open(image)
                    elif isinstance(image, numpy.ndarray):
                        image = Image.fromarray(image)

                    # Ensure image is in RGB mode
                    if image.mode != 'RGB':
                        image = image.convert('RGB')

                    # Process the image with CLIP processor
                    image_inputs = self.processor(images=image, return_tensors="pt")
                    image_features = self.clip.get_image_features(
                        pixel_values=image_inputs.pixel_values.to(self.device)
                    )
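                    # NOTE: the CLIP image features above are computed but never passed
                    # to self.model.generate() below, so the language model's output is
                    # still conditioned on the text prompt alone.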

                    # Some LLaVA models have a prepare_inputs_for_generation method
                    if hasattr(self.model, "prepare_inputs_for_generation"):
                        logging.info("Using model's prepare_inputs_for_generation for image handling")

                    # Generate with image context
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=256,
                            min_length=20,
                            temperature=self.temperature,
                            do_sample=True,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            repetition_penalty=self.repetition_penalty,
                            no_repeat_ngram_size=3,
                            use_cache=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
                except Exception as e:
                    logging.error(f"Error handling image: {str(e)}")
                    # Fall back to text-only generation
                    logging.info("Falling back to text-only generation")
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=256,
                            min_length=20,
                            temperature=self.temperature,
                            do_sample=True,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            repetition_penalty=self.repetition_penalty,
                            no_repeat_ngram_size=3,
                            use_cache=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
            else:
                # Text-only generation
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=200,
                        min_length=20,
                        temperature=self.temperature,
                        do_sample=True,
                        top_p=self.top_p,
                        top_k=self.top_k,
                        repetition_penalty=self.repetition_penalty,
                        no_repeat_ngram_size=4,
                        use_cache=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )

            # Decode and clean up the response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Clean up response
            if "gpt:" in response:
                response = response.split("gpt:")[-1].strip()
            if "human:" in response:
                response = response.split("human:")[0].strip()
            if "<image>" in response:
                response = response.replace("<image>", "").strip()

            self.history.append((message, response))
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            logging.error("Full traceback:", exc_info=True)
            return f"Error: {str(e)}"

    def clear_history(self):
        self.history = []
        return None

    # Add new function to control generation parameters
    def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
        """Update generation parameters to control hallucination tendency"""
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        return f"Generation parameters updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"

    # New method to apply config file settings
    def apply_config(self, config):
        """Apply settings from config file"""
        model_params = config.get("model_params", {})
        self.model_id = model_params.get("model_name", self.model_id)
        self.max_length = model_params.get("max_length", 512)

        # Update generation parameters if needed
        training_params = config.get("training_params", {})
        # Could add specific updates based on training_params if needed

        return f"Applied configuration. Model: {self.model_id}, Max Length: {self.max_length}"

def create_demo(config=None):
    try:
        # Initialize with config file settings
        model = LLaVAPhiModel()
        if config:
            model.apply_config(config)

        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (Optimized for Accuracy)
                Chat with a vision-language model that can understand both text and images.
                """
            )

            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                with gr.Column(scale=0.7):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False
                    )
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.15, min_width=0):
                    submit = gr.Button("Submit", variant="primary")

            image = gr.Image(type="pil", label="Upload Image (Optional)")

            # Add generation parameter controls
            with gr.Accordion("Advanced Settings (Reduce Hallucinations)", open=False):
                gr.Markdown("Adjust these parameters to control hallucination tendency")
                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                update_params = gr.Button("Update Parameters")

            # Add debugging information box
            debug_info = gr.Textbox(label="Debug Info", interactive=False)

            # Add config information
            if config:
                config_info = f"Model: {model.model_id}, Max Length: {model.max_length}"
                gr.Markdown(f"**Current Configuration:** {config_info}")

            def respond(message, chat_history, image):
                if not message and image is None:
                    # Nothing to process; keep the textbox and chat unchanged
                    return message, chat_history, "Please enter a message and/or upload an image."
                try:
                    response = model.generate_response(message, image)
                    chat_history.append((message, response))
                    debug_msg = "Response generated successfully"
                    return "", chat_history, debug_msg
                except Exception as e:
                    debug_msg = f"Error: {str(e)}"
                    return message, chat_history, debug_msg

            def clear_chat():
                model.clear_history()
                return None, None, "Chat history cleared"

            def update_params_fn(temp, top_p, top_k, rep_penalty):
                return model.update_generation_params(temp, top_p, top_k, rep_penalty)

            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot, debug_info],
            )
            clear.click(
                clear_chat,
                None,
                [chatbot, image, debug_info],
            )
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot, debug_info],
            )
            update_params.click(
                update_params_fn,
                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                [debug_info]
            )

        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise

if __name__ == "__main__":
    # Load config file
    import json
    try:
        with open("config.json", "r") as f:
            config = json.load(f)
        logging.info("Successfully loaded config file")
    except Exception as e:
        logging.error(f"Error loading config: {str(e)}")
        config = None

    demo = create_demo(config)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )