Rajagopal commited on
Commit
a48bf14
·
1 Parent(s): dfb5932

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -1
app.py CHANGED
@@ -82,6 +82,29 @@ def video_text_zeroshot(image, text_list):
82
 
83
  return score_dict
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
  def inference(
@@ -96,7 +119,7 @@ def inference(
96
  elif task == "audio-text":
97
  result = audio_text_zeroshot(audio, text_list)
98
  elif task == "video-text":
99
- result = image_text_zeroshot(image2, text_list)
100
  else:
101
  raise NotImplementedError
102
  return result
 
82
 
83
  return score_dict
84
 
85
+ def doubleimage_text_zeroshot(image, image2, text_list):
86
+ image_paths = [image]
87
+ labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
88
+ inputs = {
89
+ ModalityType.TEXT: data.load_and_transform_text(labels, device),
90
+ ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
91
+ }
92
+
93
+ with torch.no_grad():
94
+ embeddings = model(inputs)
95
+
96
+ scores = (
97
+ torch.softmax(
98
+ embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
99
+ )
100
+ .squeeze(0)
101
+ .tolist()
102
+ )
103
+
104
+ score_dict = {label: score for label, score in zip(labels, scores)}
105
+
106
+ return score_dict
107
+
108
 
109
 
110
  def inference(
 
119
  elif task == "audio-text":
120
  result = audio_text_zeroshot(audio, text_list)
121
  elif task == "video-text":
122
+ result = doubleimage_text_zeroshot(image, image2, text_list)
123
  else:
124
  raise NotImplementedError
125
  return result