from typing import List, Dict
from PIL.Image import Image
import torch
from transformers import AutoModel, AutoProcessor
from .utils import normalize_vectors
MODEL_NAME = "Marqo/marqo-fashionCLIP"
class FashionCLIPEncoder:
    """Text and image encoder built on the Marqo/marqo-fashionCLIP checkpoint."""

    def __init__(self, normalize: bool = False):
        self.normalize = normalize
        self.device = torch.device("cpu")
        self.processor = AutoProcessor.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )
        self.model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )
        self.model.to(self.device)
        self.model.eval()

    def encode_text(self, texts: List[str]) -> List[List[float]]:
        kwargs = {
            "padding": "max_length",
            "return_tensors": "pt",
            "truncation": True,
        }
        inputs = self.processor(text=texts, **kwargs)
        with torch.no_grad():
            batch = {k: v.to(self.device) for k, v in inputs.items()}
            vectors = self.model.get_text_features(**batch)
        return self._postprocess_vectors(vectors)

    def encode_images(self, images: List[Image]) -> List[List[float]]:
        inputs = self.processor(images=images, return_tensors="pt")
        with torch.no_grad():
            batch = {k: v.to(self.device) for k, v in inputs.items()}
            vectors = self.model.get_image_features(**batch)
        return self._postprocess_vectors(vectors)

    def _postprocess_vectors(self, vectors: torch.Tensor) -> List[List[float]]:
        # Optionally L2-normalize, then convert to plain Python lists.
        if self.normalize:
            vectors = normalize_vectors(vectors)
        return vectors.detach().cpu().numpy().tolist()
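

# Minimal usage sketch, not part of the original module: it assumes Pillow is
# installed and that a local image file named "shirt.jpg" exists (both are
# illustrative assumptions).
if __name__ == "__main__":
    from PIL import Image as PILImage

    encoder = FashionCLIPEncoder(normalize=True)
    text_vectors = encoder.encode_text(["red floral summer dress"])
    image_vectors = encoder.encode_images([PILImage.open("shirt.jpg")])
    # Both calls return one embedding (a list of floats) per input.
    print(len(text_vectors[0]), len(image_vectors[0]))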