pediot committed
Commit ab3f0c3 · 1 Parent(s): 5376b4a

Refactor encoder.py to remove debug print statements and ensure inputs are moved to the correct device for text and image encoding.

Files changed (1)
  1. src/encoder.py +2 -7
src/encoder.py CHANGED
@@ -25,11 +25,6 @@ class FashionCLIPEncoder:
         self.model.to(self.device)
         self.model.eval()
 
-        test = "baggy jeans"
-        print(test)
-        vector = self.encode_text([test])
-        print(vector)
-
     def encode_text(self, texts: List[str]) -> List[List[float]]:
         kwargs = {
             "padding": "max_length",
@@ -39,7 +34,7 @@ class FashionCLIPEncoder:
         inputs = self.processor(text=texts, **kwargs)
 
         with torch.no_grad():
-            batch = {k: v for k, v in inputs.items()}
+            batch = {k: v.to(self.device) for k, v in inputs.items()}
             return self._encode_text(batch)
 
     def encode_images(self, images: List[Image]) -> List[List[float]]:
@@ -49,7 +44,7 @@ class FashionCLIPEncoder:
         inputs = self.processor(images=images, **kwargs)
 
         with torch.no_grad():
-            batch = {k: v for k, v in inputs.items()}
+            batch = {k: v.to(self.device) for k, v in inputs.items()}
             return self._encode_images(batch)
 
     def _encode_text(self, batch: Dict) -> List[List[float]]:
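
For reference, a minimal sketch of how the encoder reads after this commit, assuming a Hugging Face CLIP-style checkpoint. The `patrickjohncyh/fashion-clip` model id, the `AutoModel`/`AutoProcessor` loading, the image/text processor kwargs beyond `padding`, and the `_encode_text`/`_encode_images` bodies are assumptions for illustration; only the class name, the method signatures, and the `.to(self.device)` batch move inside `torch.no_grad()` come from the diff.

from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import AutoModel, AutoProcessor


class FashionCLIPEncoder:
    def __init__(self, model_name: str = "patrickjohncyh/fashion-clip"):
        # Assumed constructor: the diff only shows the tail of __init__.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    def encode_text(self, texts: List[str]) -> List[List[float]]:
        kwargs = {
            "padding": "max_length",
            "truncation": True,       # assumed; only "padding" appears in the diff
            "return_tensors": "pt",   # assumed; needed for the .to(...) call below
        }
        inputs = self.processor(text=texts, **kwargs)

        with torch.no_grad():
            # Move every tensor in the processor output to the model's device.
            batch = {k: v.to(self.device) for k, v in inputs.items()}
            return self._encode_text(batch)

    def encode_images(self, images: List[Image]) -> List[List[float]]:
        kwargs = {"return_tensors": "pt"}  # assumed
        inputs = self.processor(images=images, **kwargs)

        with torch.no_grad():
            batch = {k: v.to(self.device) for k, v in inputs.items()}
            return self._encode_images(batch)

    def _encode_text(self, batch: Dict) -> List[List[float]]:
        # Assumed body: the standard CLIP text head in transformers.
        return self.model.get_text_features(**batch).cpu().tolist()

    def _encode_images(self, batch: Dict) -> List[List[float]]:
        # Assumed body: the standard CLIP image head in transformers.
        return self.model.get_image_features(**batch).cpu().tolist()

The motivation for the change: the processor returns CPU tensors while the model may sit on `cuda`, so building the batch with `.to(self.device)` keeps inputs and weights on the same device and avoids the device-mismatch error the old dict comprehension allowed.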