"""CLIP-backed helpers that turn text and images into embedding vectors."""

import os
import sys

import torch
from transformers import AutoProcessor, CLIPModel, AutoTokenizer

# The project-local packages imported below are resolved through the ``src``
# directory located two levels above this file, so it must be on ``sys.path``
# before those imports run.
src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
sys.path.append(src_directory)

from data import request_images  # noqa: E402
from utils import logger  # noqa: E402

# Rebinds the name ``logger`` from the imported module to the logger object
# returned by ``get_logger()``; every use below refers to that object.
logger = logger.get_logger()


class ClipModel:
    """Compute CLIP embeddings for text and images.

    Loaded model/processor/tokenizer bundles are cached on the class so that
    several instances requesting the same configuration share one copy.
    """

    # Class-level cache shared by all instances.
    # Key: (model_name, tokenizer_name); value: the dict built by load_models().
    # The tokenizer name is part of the key so that two instances using the
    # same model with different tokenizers never share a tokenizer.
    _models = {}

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32",
                 tokenizer_name: str = "openai/clip-vit-large-patch14"):
        """Remember the requested names and load the bundle if not cached yet.

        Args:
            model_name: Identifier passed to ``CLIPModel.from_pretrained`` and
                ``AutoProcessor.from_pretrained``.
            tokenizer_name: Identifier passed to
                ``AutoTokenizer.from_pretrained``.

        Raises:
            Exception: Whatever ``load_models`` raises when loading fails.
        """
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        cache_key = (model_name, tokenizer_name)
        if cache_key not in ClipModel._models:
            ClipModel._models[cache_key] = self.load_models()

    def load_models(self):
        """Load the model, processor and tokenizer named on this instance.

        Returns:
            dict: ``{'model': ..., 'processor': ..., 'tokenizer': ...}``.

        Raises:
            Exception: Any loading error, re-raised after being logged.
        """
        try:
            logger.info(f"Loading the models: {self.model_name}")
            model = CLIPModel.from_pretrained(self.model_name)
            processor = AutoProcessor.from_pretrained(self.model_name)
            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
            return {
                'model': model,
                'processor': processor,
                'tokenizer': tokenizer,
            }
        except Exception as e:
            logger.error(f"Unable to load the model {e}")
            raise

    def _components(self):
        """Return the cached bundle matching this instance's configuration."""
        return ClipModel._models[(self.model_name, self.tokenizer_name)]

    @staticmethod
    def _to_list(features):
        """Flatten a feature tensor into a plain Python list.

        ``.cpu()`` is applied on every path so the conversion to NumPy also
        works when the tensor does not live in host memory.
        """
        return features.detach().cpu().numpy().flatten().tolist()

    def _embed_image(self, image):
        """Run the processor and the image tower on an already prepared image."""
        components = self._components()
        inputs = components['processor'](images=image, return_tensors="pt")
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            image_features = components['model'].get_image_features(**inputs)
        return self._to_list(image_features)

    def get_text_embedding(self, text: str):
        """Return the embedding of ``text`` as a flat list.

        Args:
            text: The string to embed. It is tokenized as a single-item batch.

        Returns:
            list: The flattened text features.

        Raises:
            Exception: Any tokenization or model error, re-raised after being
                logged.
        """
        try:
            logger.info(f"Getting embedding for the text: {text}")
            components = self._components()
            # truncation=True asks the tokenizer to cap the sequence at its
            # configured maximum length, so over-long text does not reach the
            # model with more positions than it supports.
            # NOTE(review): over-long text is now truncated rather than raising;
            # confirm this is acceptable for callers.
            inputs = components['tokenizer'](
                [text], padding=True, truncation=True, return_tensors="pt"
            )
            # Inference only: skip building an autograd graph.
            with torch.no_grad():
                text_features = components['model'].get_text_features(**inputs)
            text_embedding = self._to_list(text_features)
            logger.info("Text embedding successfully retrieved.")
            return text_embedding
        except Exception as e:
            logger.error(f"Error while getting embedding for text: {e}")
            raise

    def get_image_embedding(self, image):
        """Return the embedding of an image resolved by ``get_image_url``.

        Args:
            image: Value handed to ``request_images.get_image_url``; its result
                is what the processor receives.
                NOTE(review): presumably a URL or similar reference - confirm
                against ``request_images``.

        Returns:
            list: The flattened image features.

        Raises:
            Exception: Any retrieval, preprocessing or model error, re-raised
                after being logged.
        """
        try:
            logger.info("Getting embedding for the image")
            image = request_images.get_image_url(image)
            embeddings = self._embed_image(image)
            logger.info("Image embedding successfully retrieved.")
            return embeddings
        except Exception as e:
            logger.error(f"Error while getting embedding for image: {e}")
            raise

    def get_uploaded_image_embedding(self, image):
        """Return the embedding of an uploaded image.

        Args:
            image: Value handed to
                ``request_images.convert_image_to_embedding_format``; its
                result is what the processor receives.
                NOTE(review): the accepted upload type is defined by
                ``request_images`` - confirm there.

        Returns:
            list: The flattened image features.

        Raises:
            Exception: Any conversion, preprocessing or model error, re-raised
                after being logged.
        """
        try:
            logger.info("Getting embedding for the image")
            image = request_images.convert_image_to_embedding_format(image)
            embeddings = self._embed_image(image)
            logger.info("Image embedding successfully retrieved.")
            return embeddings
        except Exception as e:
            logger.error(f"Error while getting embedding for image: {e}")
            raise


if __name__ == "__main__":
    # Smoke check: constructing the class triggers loading of the default
    # bundle. A failure is logged and not re-raised.
    try:
        logger.info("Starting the initialization of the ClipModel class...")
        clip_model = ClipModel()
        logger.info("ClipModel class initialized successfully.")
    except Exception as e:
        logger.error(f"Error during ClipModel initialization: {str(e)}")