Vela
added torch
38f7446
import os
import sys
from transformers import AutoProcessor, CLIPModel, AutoTokenizer
# Put the directory "<this file's dir>/../../src" on the import path; the
# `data` and `utils` imports below presumably resolve from there (confirm
# against the repository layout). This must run before those imports.
src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
sys.path.append(src_directory)
from data import request_images
from utils import logger
import torch
# NOTE: this rebinds the name `logger` from the imported module to the object
# returned by its get_logger(); the module itself is no longer reachable under
# that name after this line.
logger = logger.get_logger()
class ClipModel:
    """Compute CLIP text and image embeddings, caching loaded models per process.

    The model, processor and tokenizer loaded for a given pair of names are
    stored in the class-level ``_models`` cache, so constructing several
    ``ClipModel`` instances with the same names loads the weights only once.

    NOTE(review): the default tokenizer checkpoint differs from the default
    model checkpoint -- confirm this is intentional.
    """

    # Cache shared by all instances. Keyed by (model_name, tokenizer_name) so
    # that two instances that share a model but request different tokenizers
    # never silently reuse each other's tokenizer.
    _models = {}

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", tokenizer_name: str = "openai/clip-vit-large-patch14"):
        """Record the requested names and load the components if not cached.

        Args:
            model_name: Checkpoint passed to ``CLIPModel.from_pretrained`` and
                ``AutoProcessor.from_pretrained``.
            tokenizer_name: Checkpoint passed to
                ``AutoTokenizer.from_pretrained``.

        Raises:
            Exception: Whatever ``load_models`` raises when loading fails.
        """
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        cache_key = self._cache_key()
        if cache_key not in ClipModel._models:
            ClipModel._models[cache_key] = self.load_models()

    def _cache_key(self):
        """Return the key under which this instance's components are cached."""
        return (self.model_name, self.tokenizer_name)

    def _components(self):
        """Return the cached ``{'model', 'processor', 'tokenizer'}`` dict."""
        return self._models[self._cache_key()]

    def load_models(self):
        """Load the model, processor and tokenizer named on this instance.

        Returns:
            A dict with the keys ``'model'``, ``'processor'`` and
            ``'tokenizer'``.

        Raises:
            Exception: Any loading error is logged and re-raised unchanged.
        """
        try:
            logger.info(f"Loading the models: {self.model_name}")
            model = CLIPModel.from_pretrained(self.model_name)
            processor = AutoProcessor.from_pretrained(self.model_name)
            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
            return {
                'model': model,
                'processor': processor,
                'tokenizer': tokenizer,
            }
        except Exception as e:
            logger.error(f"Unable to load the model {e}")
            raise

    def _embed_image(self, image):
        """Run an already-loaded image through the processor and image encoder.

        Shared by both public image methods so they cannot drift apart.

        Returns:
            The image features flattened into a plain Python list.
        """
        components = self._components()
        inputs = components['processor'](images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            image_features = components['model'].get_image_features(**inputs)
        return image_features.detach().cpu().numpy().flatten().tolist()

    def get_text_embedding(self, text: str):
        """Return the CLIP text embedding of *text* as a flat list.

        Args:
            text: The string to embed.

        Returns:
            The text features flattened into a plain Python list.

        Raises:
            Exception: Any tokenization or model error is logged and re-raised.
        """
        try:
            logger.info(f"Getting embedding for the text: {text}")
            components = self._components()
            # truncation=True: the text encoder has a fixed context length, so
            # clip over-long input to the tokenizer's configured maximum
            # instead of failing inside the model.
            inputs = components['tokenizer'](
                [text], padding=True, truncation=True, return_tensors="pt"
            )
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                text_features = components['model'].get_text_features(**inputs)
            # .cpu() keeps this consistent with the image path and is required
            # before .numpy() whenever the tensor is not on the CPU.
            text_embedding = text_features.detach().cpu().numpy().flatten().tolist()
            logger.info("Text embedding successfully retrieved.")
            return text_embedding
        except Exception as e:
            logger.error(f"Error while getting embedding for text: {e}")
            raise

    def get_image_embedding(self, image):
        """Return the CLIP image embedding for *image* as a flat list.

        Args:
            image: Value handed to ``request_images.get_image_url``
                (presumably an image URL -- confirm against that helper).

        Returns:
            The image features flattened into a plain Python list.

        Raises:
            Exception: Any loading or model error is logged and re-raised.
        """
        try:
            logger.info("Getting embedding for the image")
            image = request_images.get_image_url(image)
            embeddings = self._embed_image(image)
            logger.info("Image embedding successfully retrieved.")
            return embeddings
        except Exception as e:
            logger.error(f"Error while getting embedding for image: {e}")
            raise

    def get_uploaded_image_embedding(self, image):
        """Return the CLIP image embedding for an uploaded image as a flat list.

        Args:
            image: Value handed to
                ``request_images.convert_image_to_embedding_format``
                (presumably raw uploaded image data -- confirm against that
                helper).

        Returns:
            The image features flattened into a plain Python list.

        Raises:
            Exception: Any conversion or model error is logged and re-raised.
        """
        try:
            logger.info("Getting embedding for the image")
            image = request_images.convert_image_to_embedding_format(image)
            embeddings = self._embed_image(image)
            logger.info("Image embedding successfully retrieved.")
            return embeddings
        except Exception as e:
            logger.error(f"Error while getting embedding for image: {e}")
            raise
if __name__ == "__main__":
    # Smoke check when run as a script: build a default ClipModel and log the
    # outcome. A failure is logged rather than propagated.
    try:
        logger.info("Starting the initialization of the ClipModel class...")
        clip_model = ClipModel()
        logger.info("ClipModel class initialized successfully.")
    except Exception as init_error:
        failure_message = f"Error during ClipModel initialization: {str(init_error)}"
        logger.error(failure_message)