# NOTE(review): the three lines "Spaces:" / "Sleeping" / "Sleeping" were
# Hugging Face Spaces page-header residue from a web scrape, not Python code;
# replaced with this comment so the module parses.
import os
import sys

import torch
from transformers import AutoProcessor, AutoTokenizer, CLIPModel

# Make the project's src/ directory importable; this must run before the
# local `data` / `utils` imports below, which live under src/.
src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
sys.path.append(src_directory)

from data import request_images
from utils import logger

# Module-level logger shared by everything in this file.
logger = logger.get_logger()
class ClipModel:
    """Wraps a CLIP model, processor and tokenizer and exposes embedding helpers.

    Loaded components are cached in the class-level ``_models`` dict keyed by
    ``model_name``, so multiple ``ClipModel`` instances that request the same
    model share a single set of weights.
    """

    # Class-level cache: model_name -> {'model', 'processor', 'tokenizer'}.
    _models = {}

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", tokenizer_name: str = "openai/clip-vit-large-patch14"):
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        # Only load (and download, on first use) when not already cached.
        if model_name not in ClipModel._models:
            ClipModel._models[model_name] = self.load_models()

    def load_models(self):
        """Load the CLIP model, its processor, and the tokenizer.

        Returns:
            dict: with keys ``'model'``, ``'processor'`` and ``'tokenizer'``.

        Raises:
            Exception: re-raised after logging if any component fails to load.
        """
        try:
            logger.info(f"Loading the models: {self.model_name}")
            model = CLIPModel.from_pretrained(self.model_name)
            processor = AutoProcessor.from_pretrained(self.model_name)
            # NOTE(review): the tokenizer comes from a *different* checkpoint
            # (large-patch14) than the model (base-patch32) by default — this
            # mirrors the original defaults; confirm it is intentional.
            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
            return {
                'model': model,
                'processor': processor,
                'tokenizer': tokenizer
            }
        except Exception as e:
            logger.error(f"Unable to load the model {e}")
            raise

    def _embed_image(self, image):
        """Run a decoded image through the processor + model.

        Shared by :meth:`get_image_embedding` and
        :meth:`get_uploaded_image_embedding`; returns a flat list of floats.
        """
        bundle = self._models[self.model_name]
        inputs = bundle['processor'](images=image, return_tensors="pt")
        # Inference only — skip autograd bookkeeping to save memory.
        with torch.no_grad():
            image_features = bundle['model'].get_image_features(**inputs)
        return image_features.detach().cpu().numpy().flatten().tolist()

    def get_text_embedding(self, text: str):
        """Return the CLIP text embedding for ``text`` as a flat list of floats.

        Raises:
            Exception: re-raised after logging if tokenization/inference fails.
        """
        try:
            logger.info(f"Getting embedding for the text: {text}")
            bundle = self._models[self.model_name]
            inputs = bundle['tokenizer']([text], padding=True, return_tensors="pt")
            with torch.no_grad():
                text_features = bundle['model'].get_text_features(**inputs)
            # .cpu() added for consistency with the image paths (no-op on CPU,
            # required if the model is ever moved to a GPU).
            text_embedding = text_features.detach().cpu().numpy().flatten().tolist()
            logger.info("Text embedding successfully retrieved.")
            return text_embedding
        except Exception as e:
            logger.error(f"Error while getting embedding for text: {e}")
            raise

    def get_image_embedding(self, image):
        """Return the CLIP embedding for an image fetched via ``request_images.get_image_url``.

        Args:
            image: value accepted by ``request_images.get_image_url``
                (presumably a URL — TODO confirm against caller).

        Raises:
            Exception: re-raised after logging on fetch or inference failure.
        """
        try:
            logger.info("Getting embedding for the image")
            image = request_images.get_image_url(image)
            embeddings = self._embed_image(image)
            logger.info("Image embedding successfully retrieved.")
            return embeddings
        except Exception as e:
            logger.error(f"Error while getting embedding for image: {e}")
            raise

    def get_uploaded_image_embedding(self, image):
        """Return the CLIP embedding for an uploaded image.

        Args:
            image: value accepted by
                ``request_images.convert_image_to_embedding_format``.

        Raises:
            Exception: re-raised after logging on conversion or inference failure.
        """
        try:
            logger.info("Getting embedding for the image")
            image = request_images.convert_image_to_embedding_format(image)
            embeddings = self._embed_image(image)
            logger.info("Image embedding successfully retrieved.")
            return embeddings
        except Exception as e:
            logger.error(f"Error while getting embedding for image: {e}")
            raise
if __name__ == "__main__":
    # Smoke test: construct the model (downloads weights on first run) and
    # log the outcome rather than crashing with a bare traceback.
    try:
        logger.info("Starting the initialization of the ClipModel class...")
        clip_model = ClipModel()
        logger.info("ClipModel class initialized successfully.")
    except Exception as e:
        # str() inside the f-string was redundant; {e} produces the same text.
        logger.error(f"Error during ClipModel initialization: {e}")