Refactor encoder.py to remove debug print statements and ensure inputs are moved to the correct device for text and image encoding.
Browse files- src/encoder.py +2 -7
src/encoder.py
CHANGED
@@ -25,11 +25,6 @@ class FashionCLIPEncoder:
|
|
25 |
self.model.to(self.device)
|
26 |
self.model.eval()
|
27 |
|
28 |
-
test = "baggy jeans"
|
29 |
-
print(test)
|
30 |
-
vector = self.encode_text([test])
|
31 |
-
print(vector)
|
32 |
-
|
33 |
def encode_text(self, texts: List[str]) -> List[List[float]]:
|
34 |
kwargs = {
|
35 |
"padding": "max_length",
|
@@ -39,7 +34,7 @@ class FashionCLIPEncoder:
|
|
39 |
inputs = self.processor(text=texts, **kwargs)
|
40 |
|
41 |
with torch.no_grad():
|
42 |
-
batch = {k: v for k, v in inputs.items()}
|
43 |
return self._encode_text(batch)
|
44 |
|
45 |
def encode_images(self, images: List[Image]) -> List[List[float]]:
|
@@ -49,7 +44,7 @@ class FashionCLIPEncoder:
|
|
49 |
inputs = self.processor(images=images, **kwargs)
|
50 |
|
51 |
with torch.no_grad():
|
52 |
-
batch = {k: v for k, v in inputs.items()}
|
53 |
return self._encode_images(batch)
|
54 |
|
55 |
def _encode_text(self, batch: Dict) -> List[List[float]]:
|
|
|
25 |
self.model.to(self.device)
|
26 |
self.model.eval()
|
27 |
|
|
|
|
|
|
|
|
|
|
|
28 |
def encode_text(self, texts: List[str]) -> List[List[float]]:
|
29 |
kwargs = {
|
30 |
"padding": "max_length",
|
|
|
34 |
inputs = self.processor(text=texts, **kwargs)
|
35 |
|
36 |
with torch.no_grad():
|
37 |
+
batch = {k: v.to(self.device) for k, v in inputs.items()}
|
38 |
return self._encode_text(batch)
|
39 |
|
40 |
def encode_images(self, images: List[Image]) -> List[List[float]]:
|
|
|
44 |
inputs = self.processor(images=images, **kwargs)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
+
batch = {k: v.to(self.device) for k, v in inputs.items()}
|
48 |
return self._encode_images(batch)
|
49 |
|
50 |
def _encode_text(self, batch: Dict) -> List[List[float]]:
|