import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy

# Setup logging
logging.basicConfig(level=logging.INFO)

class LLaVAPhiModel:
    def __init__(self, model_id="microsoft/phi-1_5"):  # default model name, matches config
        self.device = "cuda"
        self.model_id = model_id
        logging.info(f"Initializing LLaVA-Phi model with {model_id}...")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        try:
            # Use the CLIP processor named in the config
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            logging.info("Successfully loaded CLIP processor")
        except Exception as e:
            logging.error(f"Failed to load CLIP processor: {str(e)}")
            self.processor = None

        # Conversation history, used to carry context across turns
        self.history = []

        # Heavy models are loaded lazily in ensure_models_loaded()
        self.model = None
        self.clip = None

        # Default generation parameters - can be updated from config
        self.temperature = 0.3
        self.top_p = 0.92
        self.top_k = 50
        self.repetition_penalty = 1.2

        # Max input length; default value, updated from config via apply_config()
        self.max_length = 512

    def ensure_models_loaded(self):
        """Ensure models are loaded lazily inside the GPU context."""
        if self.model is None:
            # Use 4-bit quantization, as specified in the config
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,  # matches the config's mixed_precision setting
                bnb_4bit_use_double_quant=False
            )
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True
                )
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
                logging.info(f"Successfully loaded main model: {self.model_id}")
            except Exception as e:
                logging.error(f"Failed to load main model: {str(e)}")
                raise

        if self.clip is None:
            try:
                # Load the CLIP vision encoder named in the config
                clip_model_name = "openai/clip-vit-base-patch32"
                self.clip = CLIPModel.from_pretrained(clip_model_name).to(self.device)
                logging.info(f"Successfully loaded CLIP model: {clip_model_name}")
            except Exception as e:
                logging.error(f"Failed to load CLIP model: {str(e)}")
                self.clip = None

    def apply_lora_config(self, lora_params):
        """Apply a LoRA configuration to the model - intended to be called during training."""
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=lora_params.get("r", 16),
            lora_alpha=lora_params.get("lora_alpha", 32),
            lora_dropout=lora_params.get("lora_dropout", 0.05),
            target_modules=lora_params.get("target_modules", ["Wqkv", "out_proj"]),
            bias="none",
            task_type="CAUSAL_LM"
        )

        # Wrap the base model as a PEFT/LoRA model
        self.model = get_peft_model(self.model, lora_config)
        logging.info("Applied LoRA configuration to the model")
        return self.model
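
    # Hypothetical training-time usage (not exercised by this demo; the values
    # below are illustrative defaults, not read from any config file):
    #   model = LLaVAPhiModel()
    #   model.ensure_models_loaded()
    #   peft_model = model.apply_lora_config({"r": 16, "lora_alpha": 32, "lora_dropout": 0.05})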

    # Assumption: the otherwise unused `spaces` import suggests this Space targets
    # ZeroGPU, so the GPU decorator is applied here; drop it on dedicated GPU hardware.
    @spaces.GPU
    def generate_response(self, message, image=None):
        try:
            self.ensure_models_loaded()

            # Prepare the prompt based on whether we have an image
            has_image = image is not None
            if has_image:
                # Image + text input
                prompt = f"human: <image>\n{message}\ngpt:"
                # Check whether the model exposes standard vision-encoding hooks
                if not hasattr(self.model, "encode_image") and not hasattr(self.model, "get_vision_tower"):
                    logging.warning("Model doesn't have standard image encoding methods")
                    has_image = False
                    prompt = f"human: {message}\ngpt:"
            else:
                # Text-only input
                prompt = f"human: {message}\ngpt:"

            # Include up to 5 previous turns of conversation as context
            context = ""
            for turn in self.history[-5:]:
                context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
            full_prompt = context + prompt

            # Tokenize the input text
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # LLaVA-Phi specific image handling
            if has_image:
                try:
                    # Convert the image to a PIL RGB image
                    if isinstance(image, str):
                        image = Image.open(image)
                    elif isinstance(image, numpy.ndarray):
                        image = Image.fromarray(image)
                    if image.mode != 'RGB':
                        image = image.convert('RGB')

                    # Encode the image with the CLIP processor and vision model.
                    # NOTE: image_features are computed but never passed to generate()
                    # below; the base phi model has no visual projector, so the image
                    # currently only affects the prompt template.
                    image_inputs = self.processor(images=image, return_tensors="pt")
                    image_features = self.clip.get_image_features(
                        pixel_values=image_inputs.pixel_values.to(self.device)
                    )

                    # Some LLaVA models expose a prepare_inputs_for_generation hook
                    if hasattr(self.model, "prepare_inputs_for_generation"):
                        logging.info("Using model's prepare_inputs_for_generation for image handling")

                    # Generate with image context
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=256,
                            min_length=20,
                            temperature=self.temperature,
                            do_sample=True,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            repetition_penalty=self.repetition_penalty,
                            no_repeat_ngram_size=3,
                            use_cache=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
                except Exception as e:
                    logging.error(f"Error handling image: {str(e)}")
                    # Fall back to text-only generation
                    logging.info("Falling back to text-only generation")
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=256,
                            min_length=20,
                            temperature=self.temperature,
                            do_sample=True,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            repetition_penalty=self.repetition_penalty,
                            no_repeat_ngram_size=3,
                            use_cache=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
            else:
                # Text-only generation
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=200,
                        min_length=20,
                        temperature=self.temperature,
                        do_sample=True,
                        top_p=self.top_p,
                        top_k=self.top_k,
                        repetition_penalty=self.repetition_penalty,
                        no_repeat_ngram_size=4,
                        use_cache=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )

            # Decode and clean up the response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if "gpt:" in response:
                response = response.split("gpt:")[-1].strip()
            if "human:" in response:
                response = response.split("human:")[0].strip()
            if "<image>" in response:
                response = response.replace("<image>", "").strip()

            self.history.append((message, response))
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            logging.error("Full traceback:", exc_info=True)
            return f"Error: {str(e)}"

    def clear_history(self):
        self.history = []
        return None

    def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
        """Update generation parameters to control the hallucination tendency."""
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        return f"Generation parameters updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"

    def apply_config(self, config):
        """Apply settings from the config file."""
        model_params = config.get("model_params", {})
        self.model_id = model_params.get("model_name", self.model_id)
        self.max_length = model_params.get("max_length", 512)

        # training_params are read but not used yet; generation parameters could be
        # updated from them here if needed
        training_params = config.get("training_params", {})

        return f"Applied configuration. Model: {self.model_id}, Max Length: {self.max_length}"

def create_demo(config=None):
    try:
        # Initialize with config file settings
        model = LLaVAPhiModel()
        if config:
            model.apply_config(config)

        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (Optimized for Accuracy)
                Chat with a vision-language model that can understand both text and images.
                """
            )

            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                with gr.Column(scale=0.7):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False
                    )
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.15, min_width=0):
                    submit = gr.Button("Submit", variant="primary")

            image = gr.Image(type="pil", label="Upload Image (Optional)")

            # Add generation parameter controls
            with gr.Accordion("Advanced Settings (Reduce Hallucinations)", open=False):
                gr.Markdown("Adjust these parameters to control hallucination tendency")
                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                update_params = gr.Button("Update Parameters")

            # Add debugging information box
            debug_info = gr.Textbox(label="Debug Info", interactive=False)

            # Add config information
            if config:
                config_info = f"Model: {model.model_id}, Max Length: {model.max_length}"
                gr.Markdown(f"**Current Configuration:** {config_info}")

            def respond(message, chat_history, image):
                if not message and image is None:
                    # Nothing to process; leave the chat unchanged
                    return "", chat_history, ""
                try:
                    response = model.generate_response(message, image)
                    chat_history.append((message, response))
                    return "", chat_history, "Response generated successfully"
                except Exception as e:
                    return message, chat_history, f"Error: {str(e)}"

            def clear_chat():
                model.clear_history()
                return None, None, "Chat history cleared"

            def update_params_fn(temp, top_p, top_k, rep_penalty):
                return model.update_generation_params(temp, top_p, top_k, rep_penalty)

            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot, debug_info],
            )
            clear.click(
                clear_chat,
                None,
                [chatbot, image, debug_info],
            )
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot, debug_info],
            )
            update_params.click(
                update_params_fn,
                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                [debug_info]
            )

        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise

if __name__ == "__main__":
    # Load config file
    import json
    try:
        with open("config.json", "r") as f:
            config = json.load(f)
        logging.info("Successfully loaded config file")
    except Exception as e:
        logging.error(f"Error loading config: {str(e)}")
        config = None

    demo = create_demo(config)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )