import torch
import torch.nn as nn

from transformers import CLIPImageProcessor

# ImageBind is an optional dependency: if it is missing, the import error is
# swallowed here and only surfaces (as a NameError) when load_model() is called.
try:
    from imagebind.models import imagebind_model
    from imagebind.models.imagebind_model import ModalityType
    from imagebind.data import load_and_transform_audio_data
except ImportError:
    pass


class ImageBindWrapper(nn.Module):
    """Wraps a frozen ImageBind encoder as a vision/audio feature tower."""

    def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = select_layer
        self.select_feature = select_feature

        if not delay_load:
            self.load_model()

    def load_model(self):
        # The CLIP processor is only used to resize/normalize input images
        # into the format the tower expects.
        self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_tower = imagebind_model.imagebind_huge(pretrained=True)
        # Freeze the encoder: it serves purely as a fixed feature extractor.
        for p in self.vision_tower.parameters():
            p.requires_grad = False
        self.vision_tower.eval()
        self.is_loaded = True

    def train(self, mode=True):
        # Override nn.Module.train() so the frozen tower stays in eval mode
        # even when the surrounding model is switched to training.
        self.training = mode
        if self.is_loaded:
            self.vision_tower.eval()

    def forward(self, x):
        if isinstance(x, dict):
            # Audio path: the dict carries raw audio inputs under "audios".
            if x["audios"] is not None:
                inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()}
                embeddings = self.vision_tower(inputs)
                audio_embedding = embeddings[ModalityType.AUDIO]
                # (batch, dim) -> (batch, 1, dim): expose audio as a one-token sequence.
                return audio_embedding.unsqueeze(1)
        else:
            # Vision path: x is a batch of preprocessed image tensors.
            inputs = {ModalityType.VISION: x.to(dtype=self.dtype)}
            embeddings = self.vision_tower(inputs)
            vision_embedding = embeddings[ModalityType.VISION]
            if vision_embedding.ndim == 2:
                # Pooled (batch, dim) features: treat as a one-token sequence.
                return vision_embedding.unsqueeze(1)
            if vision_embedding.shape[1] == 257:
                # 257 tokens = 1 CLS + 256 patches; drop the CLS token.
                return vision_embedding[:, 1:]
            raise ValueError(f"Unexpected shape: {vision_embedding.shape}")

    # Exposed as properties so forward() can read self.device / self.dtype
    # as attributes rather than bound methods.
    @property
    def dummy_feature(self):
        return torch.zeros(1, 1024, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Borrow dtype/device from a parameter that always exists on the tower.
        return self.vision_tower.modality_preprocessors.vision.cls_token.dtype

    @property
    def device(self):
        return self.vision_tower.modality_preprocessors.vision.cls_token.device

    @property
    def hidden_size(self):
        return 1024
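

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving the vision path, assuming the `imagebind`
# package is installed and its pretrained weights can be downloaded. The
# blank 224x224 image and the printed shape are illustrative assumptions:
# with stock ImageBind, embeddings[ModalityType.VISION] is a pooled
# (batch, 1024) tensor, so the wrapper returns (batch, 1, 1024).
if __name__ == "__main__":
    from PIL import Image

    tower = ImageBindWrapper(vision_tower="imagebind_huge", select_layer=-1)

    image = Image.new("RGB", (224, 224))  # stand-in for a real image
    pixel_values = tower.image_processor(images=image, return_tensors="pt")["pixel_values"]

    # forward() casts to the tower's dtype itself; only the device matters here.
    features = tower(pixel_values.to(device=tower.device))
    print(features.shape)  # e.g. torch.Size([1, 1, 1024])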