File size: 1,721 Bytes
88e0bae
 
 
 
 
 
5cc3519
 
88e0bae
 
 
 
 
5cc3519
 
 
c54b3d1
fdf1598
88e0bae
fdf1598
c54b3d1
88e0bae
fdf1598
168270d
 
 
 
e5ee3a4
 
168270d
2eb56eb
88e0bae
 
 
 
 
 
5cc3519
88e0bae
 
 
ab3f0c3
5cc3519
 
 
88e0bae
 
5cc3519
88e0bae
 
ab3f0c3
5cc3519
 
 
88e0bae
5cc3519
 
 
88e0bae
5cc3519
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from typing import Dict, List, Union

import torch
from PIL.Image import Image
from transformers import AutoModel, AutoProcessor

from .utils import normalize_vectors


# Hugging Face Hub repository id that FashionCLIPEncoder loads both its
# processor and its model from (with trust_remote_code=True).
MODEL_NAME: str = "Marqo/marqo-fashionCLIP"


class FashionCLIPEncoder:
    """Embed text and images with the Marqo FashionCLIP model.

    The processor and model are both loaded from the Hugging Face Hub under
    ``MODEL_NAME`` with ``trust_remote_code=True``. The model is moved to
    ``device`` and switched to evaluation mode. Every ``encode_*`` method
    returns plain nested Python lists, so results need no further tensor
    handling by callers.

    Args:
        normalize: When True, embeddings are passed through
            ``normalize_vectors`` (from ``.utils``) before being returned.
            Presumably this is L2 normalization; confirm in ``.utils``.
        device: Torch device to run the model on, given either as a string
            (e.g. ``"cuda:0"``) or as a ``torch.device``. Defaults to
            ``"cpu"``, which was the previously hard-coded behavior.
    """

    def __init__(
        self,
        normalize: bool = False,
        *,
        device: Union[str, torch.device] = "cpu",
    ):
        self.normalize = normalize

        # torch.device accepts a device string or an existing torch.device.
        self.device = torch.device(device)

        self.processor = AutoProcessor.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )

        self.model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )

        self.model.to(self.device)
        # Inference only: evaluation mode affects layers such as dropout.
        self.model.eval()

    def encode_text(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of strings in a single forward pass.

        Args:
            texts: Strings to embed.

        Returns:
            The text features produced by the model for the batch, converted
            to nested lists of floats (presumably one row per input string).
            An empty ``texts`` returns ``[]`` without calling the processor
            or the model.
        """
        if not texts:
            # Nothing to embed; avoid handing an empty batch to the
            # processor/model.
            return []

        # Pad to the processor's maximum length and truncate longer inputs,
        # so the whole batch forms one rectangular tensor.
        inputs = self.processor(
            text=texts,
            padding="max_length",
            return_tensors="pt",
            truncation=True,
        )

        with torch.no_grad():
            batch = {key: tensor.to(self.device) for key, tensor in inputs.items()}
            vectors = self.model.get_text_features(**batch)
            return self._postprocess_vectors(vectors)

    def encode_images(self, images: List[Image]) -> List[List[float]]:
        """Embed a batch of PIL images in a single forward pass.

        Args:
            images: PIL images to embed.

        Returns:
            The image features produced by the model for the batch, converted
            to nested lists of floats (presumably one row per input image).
            An empty ``images`` returns ``[]`` without calling the processor
            or the model.
        """
        if not images:
            # Nothing to embed; avoid handing an empty batch to the
            # processor/model.
            return []

        inputs = self.processor(images=images, return_tensors="pt")

        with torch.no_grad():
            batch = {key: tensor.to(self.device) for key, tensor in inputs.items()}
            vectors = self.model.get_image_features(**batch)
            return self._postprocess_vectors(vectors)

    def _postprocess_vectors(self, vectors: torch.Tensor) -> List[List[float]]:
        """Optionally normalize ``vectors``, then convert them to nested lists.

        Args:
            vectors: Feature tensor returned by the model.

        Returns:
            ``vectors`` (normalized first when ``self.normalize`` is set) as
            nested Python lists of floats.
        """
        if self.normalize:
            vectors = normalize_vectors(vectors)

        # Tensor.tolist() produces Python floats directly. Skipping the NumPy
        # round-trip also works for dtypes NumPy lacks, such as bfloat16.
        return vectors.detach().cpu().tolist()