from typing import List

import torch
from PIL.Image import Image
from transformers import AutoModel, AutoProcessor

from .utils import normalize_vectors

MODEL_NAME = "Marqo/marqo-fashionCLIP"


class FashionCLIPEncoder:
    """Batch text/image embedding with Marqo/marqo-fashionCLIP on CPU."""

    def __init__(self, normalize: bool = False):
        self.normalize = normalize
        self.device = torch.device("cpu")
        self.processor = AutoProcessor.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )
        self.model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )
        self.model.to(self.device)
        self.model.eval()  # inference mode: disables dropout and similar layers

    def encode_text(self, texts: List[str]) -> List[List[float]]:
        # Pad every sequence to the model's max length so the batch is rectangular.
        kwargs = {
            "padding": "max_length",
            "return_tensors": "pt",
            "truncation": True,
        }
        inputs = self.processor(text=texts, **kwargs)
        with torch.no_grad():
            batch = {k: v.to(self.device) for k, v in inputs.items()}
            vectors = self.model.get_text_features(**batch)
        return self._postprocess_vectors(vectors)

    def encode_images(self, images: List[Image]) -> List[List[float]]:
        inputs = self.processor(images=images, return_tensors="pt")
        with torch.no_grad():
            batch = {k: v.to(self.device) for k, v in inputs.items()}
            vectors = self.model.get_image_features(**batch)
        return self._postprocess_vectors(vectors)

    def _postprocess_vectors(self, vectors: torch.Tensor) -> List[List[float]]:
        # Optionally L2-normalize, then convert to plain Python lists of floats.
        if self.normalize:
            vectors = normalize_vectors(vectors)
        return vectors.detach().cpu().numpy().tolist()
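

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the module path `encoder` and the file
# "dress.jpg" are assumptions, not part of this repo):
#
#     from PIL import Image
#     from encoder import FashionCLIPEncoder
#
#     enc = FashionCLIPEncoder(normalize=True)
#     text_vecs = enc.encode_text(["red floral summer dress"])
#     image_vecs = enc.encode_images([Image.open("dress.jpg")])
#
#     # With normalize=True the vectors are unit-length, so cosine similarity
#     # reduces to a plain dot product.
#     score = sum(t * i for t, i in zip(text_vecs[0], image_vecs[0]))
# ---------------------------------------------------------------------------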