import data
import torch
import gradio as gr
from models import imagebind_model
from models.imagebind_model import ModalityType
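
# Pick a device and load the pretrained ImageBind-Huge checkpoint once at startup,
# keeping the model in eval mode for inference only.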
device = "cuda:0" if torch.cuda.is_available() else "cpu" | |
model = imagebind_model.imagebind_huge(pretrained=True) | |
model.eval() | |
model.to(device) | |
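
# Zero-shot image classification: embed one image and the "|"-separated candidate
# texts with ImageBind, then softmax the image-text similarities into per-label scores.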
def image_text_zeroshot(image, text_list):
    image_paths = [image]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    scores = (
        torch.softmax(
            embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
        )
        .squeeze(0)
        .tolist()
    )

    score_dict = {label: score for label, score in zip(labels, scores)}

    return score_dict
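
# Zero-shot audio classification: same scheme as above, but the input is an audio file
# and the AUDIO modality embeddings are compared against the candidate texts.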
def audio_text_zeroshot(audio, text_list):
    audio_paths = [audio]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    scores = (
        torch.softmax(
            embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1
        )
        .squeeze(0)
        .tolist()
    )

    score_dict = {label: score for label, score in zip(labels, scores)}

    return score_dict
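
# Currently identical to the image path (the input is loaded with
# load_and_transform_vision_data); kept here but not wired into inference() below.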
def video_text_zeroshot(image, text_list):
    image_paths = [image]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    scores = (
        torch.softmax(
            embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
        )
        .squeeze(0)
        .tolist()
    )

    score_dict = {label: score for label, score in zip(labels, scores)}

    return score_dict
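
# "embeddings" task: scores two images against the candidate texts. Note that with two
# images the similarity matrix is 2 x N, so zip(labels, scores) pairs each of the first
# two labels with a whole row of scores rather than the single confidence value the
# "label" output component expects.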
def doubleimage_text_zeroshot(image, image2, text_list):
    image_paths = [image, image2]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    scores = (
        torch.softmax(
            embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
        )
        .squeeze(0)
        .tolist()
    )

    score_dict = {label: score for label, score in zip(labels, scores)}

    return score_dict
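
# Older variant kept for reference: returns the raw 2 x N softmax matrix as a string
# instead of a per-label dictionary. Not used by inference().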
def doubleimage_text_zeroshotOLD(image, image2, text_list):
    image_paths = [image, image2]
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    return str(
        torch.softmax(
            embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
        )
    )
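
# Dispatch the selected task to the matching zero-shot handler.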
def inference(
    task,
    text_list=None,
    image=None,
    audio=None,
    image2=None,
):
    if task == "image-text":
        result = image_text_zeroshot(image, text_list)
    elif task == "audio-text":
        result = audio_text_zeroshot(audio, text_list)
    elif task == "embeddings":
        result = doubleimage_text_zeroshot(image, image2, text_list)
    else:
        raise NotImplementedError
    return result
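
# Build the Gradio interface. Note: this uses the legacy gr.inputs components and the
# "label" output shortcut, which require an older Gradio release (the gr.inputs
# namespace was removed in Gradio 4).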
def main():
    inputs = [
        gr.inputs.Radio(
            choices=[
                "image-text",
                "audio-text",
                "embeddings",
            ],
            type="value",
            default="embeddings",
            label="Task",
        ),
        gr.inputs.Textbox(lines=1, label="Candidate texts"),
        gr.inputs.Image(type="filepath", label="Input image"),
        gr.inputs.Audio(type="filepath", label="Input audio"),
        gr.inputs.Image(type="filepath", label="Input image2"),
    ]

    iface = gr.Interface(
        inference,
        inputs,
        "label",
        title="Multimodal AI assistive agents for Learning Disorders: Demo with embeddings of ImageBind",
    )

    iface.launch()

if __name__ == "__main__":
    main()