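"""Helper for image question-answering with LLaVA.

Loads a LLaVA checkpoint through the llava package's model builder, working
around loader-signature differences between llava versions, and exposes a
small LLaVAHelper class that answers free-form questions about a local
image file or an image URL.
"""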
import sys
import logging
import os

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# First try to import from llava
try:
    from llava.model.builder import load_pretrained_model
    from llava.mm_utils import process_images, tokenizer_image_token
    logger.info("Successfully imported llava modules")
except ImportError as e:
    logger.error(f"Failed to import llava modules: {e}")
    sys.exit(1)

# Then import other dependencies
try:
    from transformers import AutoTokenizer, AutoConfig
    import torch
    import requests
    from PIL import Image
    from io import BytesIO
    logger.info("Successfully imported other required modules")
except ImportError as e:
    logger.error(f"Failed to import dependency: {e}")
    sys.exit(1)

class LLaVAHelper:
    def __init__(self, model_name="llava-hf/llava-1.5-7b-hf"):
        """
        Initialize the LLaVA model for image-text processing
        """
        logger.info(f"Initializing LLaVAHelper with model: {model_name}")
        
        # Create the cache directory if it doesn't already exist
        os.makedirs("./model_cache", exist_ok=True)
        logger.info("Model cache directory is ready")
        
        # Try loading just the config to ensure the model is valid
        try:
            AutoConfig.from_pretrained(model_name)
            logger.info(f"Successfully loaded config for {model_name}")
        except Exception as e:
            logger.warning(f"Error loading model config: {e}")
            # Try a different model version as fallback
            model_name = "llava-hf/llava-1.5-13b-hf"
            logger.info(f"Trying alternative model: {model_name}")

        try:
            # Use specific tokenizer class to avoid issues
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir="./model_cache",
                use_fast=False,  # Use the Python implementation instead of the Rust one
                trust_remote_code=True
            )
            logger.info("Tokenizer loaded successfully")
            
            # Inspect the load_pretrained_model function to understand its parameters
            import inspect
            logger.info(f"load_pretrained_model signature: {inspect.signature(load_pretrained_model)}")
            
            # Try loading with different parameter combinations
            logger.info("Loading model...")
            try:
                # First attempt - standard keyword parameters. Recent llava
                # builders return (tokenizer, model, image_processor,
                # context_len), so unpack four values.
                _, self.model, self.image_processor, _ = load_pretrained_model(
                    model_path=model_name,
                    model_base=None,
                    cache_dir="./model_cache",
                )
            except Exception as e1:
                logger.warning(f"First attempt to load model failed: {e1}")
                try:
                    # Second attempt - try with model_name parameter
                    _, self.model, self.image_processor, _ = load_pretrained_model(
                        model_name=model_name,
                        model_path=model_name,
                        model_base=None,
                        cache_dir="./model_cache",
                    )
                except Exception as e2:
                    logger.warning(f"Second attempt to load model failed: {e2}")
                    # Third attempt - minimal positional parameters
                    # (model_path, model_base, model_name)
                    _, self.model, self.image_processor, _ = load_pretrained_model(
                        model_name,
                        None,
                        model_name,
                    )
            
            logger.info("Model loaded successfully")
            self.model.eval()
            
            # Move model to appropriate device
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            if self.device == "cpu":
                # If using CPU, make sure model is in the right place
                self.model = self.model.to(self.device)
            
            logger.info(f"Model successfully loaded on {self.device}")
        except Exception as e:
            logger.error(f"Detailed initialization error: {e}")
            logger.error("Stack trace:", exc_info=True)
            raise
            
    def generate_answer(self, image, question):
        """
        Generate a response to a question about an image
        
        Args:
            image: PIL Image or path to image
            question: String question about the image
        
        Returns:
            String response from the model
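
        Example (hypothetical file name):
            helper = LLaVAHelper()
            print(helper.generate_answer("street.jpg", "How many cars are visible?"))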
        """
        try:
            # Handle image input (either PIL Image or path/URL)
            if isinstance(image, str):
                if image.startswith(('http://', 'https://')):
                    response = requests.get(image, timeout=30)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                else:
                    image = Image.open(image)
            # Most vision processors expect 3-channel RGB input
            image = image.convert("RGB")
            
            # Preprocess image; cast to the model's device and dtype
            # (e.g. float16 when the model was loaded in half precision)
            model_dtype = next(self.model.parameters()).dtype
            image_tensor = process_images(
                [image],
                self.image_processor,
                self.model.config
            )[0].unsqueeze(0).to(self.device, dtype=model_dtype)
            
            # Format prompt with question. These are the v0-style
            # "###Human:/###Assistant:" separators; v1.5 checkpoints
            # generally expect a "USER:"/"ASSISTANT:" template instead.
            prompt = f"###Human: <image>\n{question}\n###Assistant:"
            
            # Tokenize prompt; tokenizer_image_token returns a 1-D tensor
            # of token ids, so add a batch dimension for generate()
            input_ids = tokenizer_image_token(
                prompt,
                self.tokenizer,
                return_tensors="pt"
            ).unsqueeze(0).to(self.device)
            
            # Generate response
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids,
                    images=image_tensor,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                )
            
            # Decode and extract response
            output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            return output.split("###Assistant:")[-1].strip()
        except Exception as e:
            return f"Error generating answer: {str(e)}"

# Example usage
if __name__ == "__main__":
    try:
        # Initialize model
        llava = LLaVAHelper()
        
        # Example with a local file
        # response = llava.generate_answer("path/to/your/image.jpg", "What's in this image?")
        
        # Example with a URL
        # image_url = "https://example.com/image.jpg"
        # response = llava.generate_answer(image_url, "Describe this image in detail.")
        
        # print(response)
        print("LLaVA model initialized successfully. Ready to process images.")
    except Exception as e:
        print(f"Error initializing LLaVA: {e}")