import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy

# Setup logging
logging.basicConfig(level=logging.INFO)

class LLaVAPhiModel:
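    """Gradio demo wrapper pairing a CLIP image encoder with the phi-1_5 causal LM.

    Note: CLIP features are extracted for uploaded images but are not projected into
    the language model, so responses are driven primarily by the text prompt.
    """
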
    def __init__(self, model_id="microsoft/phi-1_5"):  # Updated to match config
        self.device = "cuda"
        self.model_id = model_id
        logging.info(f"Initializing LLaVA-Phi model with {model_id}...")
        
        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        try:
            # Use CLIPProcessor with the correct model name from config
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            logging.info("Successfully loaded CLIP processor")
        except Exception as e:
            logging.error(f"Failed to load CLIP processor: {str(e)}")
            self.processor = None
        
        # Conversation history; recent turns are prepended to each prompt for context
        self.history = []
        self.model = None
        self.clip = None
        
        # Default generation parameters - can be updated from config
        self.temperature = 0.3
        self.top_p = 0.92
        self.top_k = 50
        self.repetition_penalty = 1.2
        
        # Set max length from config
        self.max_length = 512  # Default value, will be updated from config

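    # @spaces.GPU is the Hugging Face ZeroGPU decorator; it requests a GPU for the
    # duration of the decorated call when the app runs on Spaces hardware.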
    @spaces.GPU
    def ensure_models_loaded(self):
        """Ensure models are loaded in GPU context"""
        if self.model is None:
            # Use 4-bit quantization according to config
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,  # Changed to match config
                bnb_4bit_compute_dtype=torch.bfloat16,  # Changed to bfloat16 to match config's mixed_precision
                bnb_4bit_use_double_quant=False
            )
            
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True
                )
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
                logging.info(f"Successfully loaded main model: {self.model_id}")
            except Exception as e:
                logging.error(f"Failed to load main model: {str(e)}")
                raise

        if self.clip is None:
            try:
                # Load CLIP model from config
                clip_model_name = "openai/clip-vit-base-patch32"  # From config
                self.clip = CLIPModel.from_pretrained(clip_model_name).to(self.device)
                logging.info(f"Successfully loaded CLIP model: {clip_model_name}")
            except Exception as e:
                logging.error(f"Failed to load CLIP model: {str(e)}")
                self.clip = None

    def apply_lora_config(self, lora_params):
        """Apply LoRA configuration to the model - to be called during training"""
        from peft import LoraConfig, get_peft_model
        
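        # "Wqkv" and "out_proj" are the attention projection module names used by the
        # original phi-1_5 remote code; adjust target_modules if the checkpoint differs.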
        lora_config = LoraConfig(
            r=lora_params.get("r", 16),
            lora_alpha=lora_params.get("lora_alpha", 32),
            lora_dropout=lora_params.get("lora_dropout", 0.05),
            target_modules=lora_params.get("target_modules", ["Wqkv", "out_proj"]),
            bias="none",
            task_type="CAUSAL_LM"
        )
        
        # Convert model to PEFT/LoRA model
        self.model = get_peft_model(self.model, lora_config)
        logging.info("Applied LoRA configuration to the model")
        return self.model

    @spaces.GPU(duration=120)
    def generate_response(self, message, image=None):
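        """Generate a chat reply to `message`; `image` may be a PIL image, numpy array, or file path."""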
        try:
            self.ensure_models_loaded()
            
            # Prepare prompt based on whether we have an image
            has_image = image is not None
            
            # Process text input
            if has_image:
                # For image+text input
                prompt = f"human: <image>\n{message}\ngpt:"
                
                # Check if model has vision encoding capability
                if not hasattr(self.model, "encode_image") and not hasattr(self.model, "get_vision_tower"):
                    logging.warning("Model doesn't have standard image encoding methods")
                    has_image = False
                    prompt = f"human: {message}\ngpt:"
            else:
                # For text-only input
                prompt = f"human: {message}\ngpt:"
            
            # Include previous conversation context
            context = ""
            for turn in self.history[-5:]:  # include up to the last 5 turns
                context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
            
            full_prompt = context + prompt
            
            # Tokenize the input text
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            # LLaVA-Phi specific image handling
            if has_image:
                try:
                    # Convert image to correct format
                    if isinstance(image, str):
                        image = Image.open(image)
                    elif isinstance(image, numpy.ndarray):
                        image = Image.fromarray(image)
                    
                    # Ensure image is in RGB mode
                    if image.mode != 'RGB':
                        image = image.convert('RGB')
                    
                    # Process the image with CLIP processor
                    image_inputs = self.processor(images=image, return_tensors="pt")
                    image_features = self.clip.get_image_features(
                        pixel_values=image_inputs.pixel_values.to(self.device)
                    )
                        
                    # Note: image_features are computed above but are not passed into
                    # model.generate below; fusing them would require a projection layer
                    # and a language model that accepts visual embeddings.
                    # (All HF generation-capable models expose prepare_inputs_for_generation,
                    # so this check is informational only.)
                    if hasattr(self.model, "prepare_inputs_for_generation"):
                        logging.info("Model exposes prepare_inputs_for_generation")

                    # Generate from the text prompt (which still contains the <image> placeholder)
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=256,
                            min_length=20,
                            temperature=self.temperature,
                            do_sample=True,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            repetition_penalty=self.repetition_penalty,
                            no_repeat_ngram_size=3,
                            use_cache=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
                    
                except Exception as e:
                    logging.error(f"Error handling image: {str(e)}")
                    # Fall back to text-only generation
                    logging.info("Falling back to text-only generation")
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=256,
                            min_length=20,
                            temperature=self.temperature,
                            do_sample=True,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            repetition_penalty=self.repetition_penalty,
                            no_repeat_ngram_size=3,
                            use_cache=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
            else:
                # Text-only generation
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=200,
                        min_length=20,
                        temperature=self.temperature,
                        do_sample=True,
                        top_p=self.top_p,
                        top_k=self.top_k,
                        repetition_penalty=self.repetition_penalty,
                        no_repeat_ngram_size=4,
                        use_cache=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )
            
            # Decode and clean up the response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Clean up response
            if "gpt:" in response:
                response = response.split("gpt:")[-1].strip()
            if "human:" in response:
                response = response.split("human:")[0].strip()
            if "<image>" in response:
                response = response.replace("<image>", "").strip()
            
            self.history.append((message, response))
            return response
            
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            logging.error(f"Full traceback:", exc_info=True)
            return f"Error: {str(e)}"

    def clear_history(self):
        self.history = []
        return None

    # Add new function to control generation parameters
    def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
        """Update generation parameters to control hallucination tendency"""
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        return f"Generation parameters updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"

    # New method to apply config file settings
    def apply_config(self, config):
        """Apply settings from config file"""
        model_params = config.get("model_params", {})
        self.model_id = model_params.get("model_name", self.model_id)
        self.max_length = model_params.get("max_length", 512)
        
        # Update generation parameters if needed
        training_params = config.get("training_params", {})
        # Could add specific updates based on training_params if needed
        
        return f"Applied configuration. Model: {self.model_id}, Max Length: {self.max_length}"

def create_demo(config=None):
    try:
        # Initialize the model; config file settings are applied below if provided
        model = LLaVAPhiModel()
        
        if config:
            model.apply_config(config)
        
        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (Optimized for Accuracy)
                Chat with a vision-language model that can understand both text and images.
                """
            )
            
            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                with gr.Column(scale=0.7):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False
                    )
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.15, min_width=0):
                    submit = gr.Button("Submit", variant="primary")
            
            image = gr.Image(type="pil", label="Upload Image (Optional)")
            
            # Add generation parameter controls
            with gr.Accordion("Advanced Settings (Reduce Hallucinations)", open=False):
                gr.Markdown("Adjust these parameters to control hallucination tendency")
                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                update_params = gr.Button("Update Parameters")
                
                # Add debugging information box
                debug_info = gr.Textbox(label="Debug Info", interactive=False)
                
                # Add config information
                if config:
                    config_info = f"Model: {model.model_id}, Max Length: {model.max_length}"
                    gr.Markdown(f"**Current Configuration:** {config_info}")
            
            def respond(message, chat_history, image):
                if not message and image is None:
                    return chat_history, ""
                
                try:
                    response = model.generate_response(message, image)
                    chat_history.append((message, response))
                    debug_msg = "Response generated successfully"
                    return "", chat_history, debug_msg
                except Exception as e:
                    debug_msg = f"Error: {str(e)}"
                    return message, chat_history, debug_msg
            
            def clear_chat():
                model.clear_history()
                return None, None, "Chat history cleared"
            
            def update_params_fn(temp, top_p, top_k, rep_penalty):
                # update_generation_params already returns a status string for the debug box
                return model.update_generation_params(temp, top_p, top_k, rep_penalty)
            
            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot, debug_info],
            )
            
            clear.click(
                clear_chat,
                None,
                [chatbot, image, debug_info],
            )
            
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot, debug_info],
            )
            
            update_params.click(
                update_params_fn,
                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                [debug_info]
            )
            
        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise

if __name__ == "__main__":
    # Load config file
    import json
    
    try:
        with open("config.json", "r") as f:
            config = json.load(f)
        logging.info("Successfully loaded config file")
    except Exception as e:
        logging.error(f"Error loading config: {str(e)}")
        config = None
    
    demo = create_demo(config)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )