import io
import os
import traceback

import torch
from PIL import Image, UnidentifiedImageError

from .model_loader import ModelManager
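# Interface assumed of ModelManager (inferred from its usage in this module;
# the actual model_loader implementation may differ):
#   ModelManager(cache_dir=...)  -> manager instance
#   manager.get_model(name)      -> (processor, model) tuple
#   manager.device               -> torch device used for inference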


class VQAInference:
    """
    Class to perform inference with Visual Question Answering models.
    """

    def __init__(self, model_name="blip", cache_dir=None):
        """
        Initialize the VQA inference.

        Args:
            model_name (str, optional): Name of the model to use. Defaults to "blip".
            cache_dir (str, optional): Directory to cache models. Defaults to None.
        """
        self.model_name = model_name
        self.model_manager = ModelManager(cache_dir=cache_dir)
        self.processor, self.model = self.model_manager.get_model(model_name)
        self.device = self.model_manager.device
    def predict(self, image, question):
        """
        Perform VQA prediction on an image with a question.

        Args:
            image (PIL.Image.Image or str): Image to analyze, or path to an image file.
            question (str): Question to ask about the image.

        Returns:
            str: Answer to the question.
        """
        # Handle image input - could be a file path or a PIL Image
        if isinstance(image, str):
            # Check existence outside the try block so FileNotFoundError
            # propagates as-is instead of being swallowed into a ValueError
            if not os.path.exists(image):
                raise FileNotFoundError(f"Image file not found: {image}")
            try:
                # Try the standard loading approach first
                try:
                    image = Image.open(image).convert("RGB")
                    print(f"Successfully opened image: {image.size}, mode: {image.mode}")
                except Exception as img_err:
                    print(
                        f"Standard image loading failed: {img_err}, "
                        f"trying alternative method..."
                    )
                    # Fall back to reading the raw bytes and decoding from memory
                    with open(image, "rb") as img_file:
                        img_data = img_file.read()
                    image = Image.open(io.BytesIO(img_data)).convert("RGB")
                    print(
                        f"Alternative image loading succeeded: {image.size}, "
                        f"mode: {image.mode}"
                    )
            except UnidentifiedImageError as e:
                # Raised when PIL cannot identify the image format
                raise ValueError(f"Cannot identify image format: {e}") from e
            except Exception as e:
                # Log the full traceback before re-raising with a clearer message
                print(f"Error details: {traceback.format_exc()}")
                raise ValueError(f"Could not open image file: {e}") from e

        # Make sure the image is a PIL Image at this point
        if not isinstance(image, Image.Image):
            raise ValueError("Image must be a PIL Image or a file path")

        # Dispatch based on model type
        if self.model_name.lower() == "blip":
            return self._predict_with_blip(image, question)
        elif self.model_name.lower() == "vilt":
            return self._predict_with_vilt(image, question)
        else:
            raise ValueError(f"Prediction not implemented for model: {self.model_name}")
    def _predict_with_blip(self, image, question):
        """
        Perform prediction with the BLIP model.

        Args:
            image (PIL.Image.Image): Image to analyze.
            question (str): Question to ask about the image.

        Returns:
            str: Answer to the question.
        """
        try:
            # Process image and text inputs and move them to the target device
            inputs = self.processor(
                images=image, text=question, return_tensors="pt"
            ).to(self.device)

            # Generate the answer tokens (no gradients needed for inference)
            with torch.no_grad():
                outputs = self.model.generate(**inputs)

            # Decode the generated token IDs back to text
            answer = self.processor.decode(outputs[0], skip_special_tokens=True)
            return answer
        except Exception as e:
            print(f"Error in BLIP prediction: {e}")
            print(f"Error details: {traceback.format_exc()}")
            raise RuntimeError(f"BLIP model prediction failed: {e}") from e
    def _predict_with_vilt(self, image, question):
        """
        Perform prediction with the ViLT model.

        Args:
            image (PIL.Image.Image): Image to analyze.
            question (str): Question to ask about the image.

        Returns:
            str: Answer to the question.
        """
        try:
            # Process image and text inputs
            encoding = self.processor(images=image, text=question, return_tensors="pt")

            # Move all input tensors to the target device
            encoding = {k: v.to(self.device) for k, v in encoding.items()}

            # Forward pass through the classification head (no gradients needed)
            with torch.no_grad():
                outputs = self.model(**encoding)
                logits = outputs.logits

            # ViLT VQA is a classification task: pick the highest-scoring answer
            idx = logits.argmax(-1).item()
            answer = self.model.config.id2label[idx]
            return answer
        except Exception as e:
            print(f"Error in ViLT prediction: {e}")
            print(f"Error details: {traceback.format_exc()}")
            raise RuntimeError(f"ViLT model prediction failed: {e}") from e