# fclip/src/encoder.py
import os
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import AutoModel, AutoProcessor

MODEL_NAME = "Marqo/marqo-fashionCLIP"
# Optional Hugging Face access token, read from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")


class FashionCLIPEncoder:
    """Wraps Marqo/marqo-fashionCLIP for CPU-only text and image embedding."""

    def __init__(self):
        self.device = torch.device("cpu")
        self.processor = AutoProcessor.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            token=HF_TOKEN,
        )
        try:
            self.model = AutoModel.from_pretrained(
                MODEL_NAME,
                trust_remote_code=True,
                device_map=None,  # load on a single device; moved to CPU below
                token=HF_TOKEN,
            )
            self.model = self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            print(f"Error initializing model: {e}")
            raise
    def encode_text(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of strings; returns one embedding per input text."""
        inputs = self.processor(
            text=texts,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            batch = {k: v.to(self.device) for k, v in inputs.items()}
            return self._encode_text(batch)
    def encode_images(self, images: List[Image]) -> List[List[float]]:
        """Embed a batch of PIL images; returns one embedding per image."""
        inputs = self.processor(images=images, return_tensors="pt")
        with torch.no_grad():
            batch = {k: v.to(self.device) for k, v in inputs.items()}
            return self._encode_images(batch)
    def _encode_text(self, batch: Dict) -> List[List[float]]:
        # Forward pass through the text tower; convert to plain Python lists.
        return self.model.get_text_features(**batch).detach().cpu().numpy().tolist()

    def _encode_images(self, batch: Dict) -> List[List[float]]:
        # Forward pass through the vision tower; convert to plain Python lists.
        return self.model.get_image_features(**batch).detach().cpu().numpy().tolist()