import json
import logging
from typing import Literal, Optional

import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip import create_model_from_pretrained
from torchvision.transforms import Normalize
from transformers import AutoModel, AutoProcessor, AutoTokenizer, T5EncoderModel

from data_utils.ext.synchformer import Synchformer
from think_sound.models.factory import create_model_from_config
from think_sound.models.utils import load_ckpt_state_dict
from think_sound.training.utils import copy_state_dict
log = logging.getLogger()
def patch_clip(clip_model):
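    """Bind a get_text_features variant that also returns the final hidden states.

    The patched method keeps the original projection of the pooled output but
    additionally exposes the per-token last_hidden_state for sequence conditioning.
    """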
# a hack to make it output last hidden states
# https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = text_outputs[0]
pooled_output = text_outputs[1]
text_features = self.text_projection(pooled_output)
return text_features, last_hidden_state
clip_model.get_text_features = new_get_text_features.__get__(clip_model)
return clip_model
class FeaturesUtils(nn.Module):
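    """Bundles the pretrained feature extractors used for conditioning:

    - MetaCLIP-H14 (HF CLIPModel): per-frame visual features and text features
    - T5-v1.1-XL encoder: token-level text features
    - Synchformer: audio-visual synchronization features from video frames
    - Autoencoder built from a think_sound model config: encodes audio into latents
    """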
def __init__(
self,
*,
vae_ckpt: Optional[str] = None,
vae_config: Optional[str] = None,
synchformer_ckpt: Optional[str] = None,
enable_conditions: bool = True,
need_vae_encoder: bool = True,
):
super().__init__()
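        # Conditioning encoders: MetaCLIP-H14 (visual + text), T5-v1.1-XL (text),
        # and Synchformer (audio-visual sync features).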
if enable_conditions:
self.clip_model = AutoModel.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
self.clip_model = patch_clip(self.clip_model)
self.t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl")
self.t5_model = T5EncoderModel.from_pretrained("google/t5-v1_1-xl")
self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
# self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
# std=[0.26862954, 0.26130258, 0.27577711])
self.synchformer = Synchformer()
self.synchformer.load_state_dict(
torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))
# self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14'
        else:
            self.clip_model = None
            self.clip_processor = None
            self.t5_tokenizer = None
            self.t5_model = None
            self.synchformer = None
            self.tokenizer = None
if vae_ckpt is not None:
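            # Build the autoencoder from its JSON config, then load the weights
            # stored under the 'autoencoder.' prefix in the checkpoint.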
with open(vae_config) as f:
vae_config = json.load(f)
self.vae = create_model_from_config(vae_config)
print(f"Loading model checkpoint from {vae_ckpt}")
# Load checkpoint
            copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.'))
        else:
            self.vae = None
    def compile(self):
        if self.clip_model is not None:
            # The HF CLIPModel exposes get_image_features/get_text_features
            # (there is no encode_image/encode_text as in open_clip).
            self.clip_model.get_image_features = torch.compile(self.clip_model.get_image_features)
            self.clip_model.get_text_features = torch.compile(self.clip_model.get_text_features)
        if self.synchformer is not None:
            self.synchformer = torch.compile(self.synchformer)
    def train(self, mode: bool = True) -> 'FeaturesUtils':
        # The feature extractors are frozen; stay in eval mode regardless of `mode`.
        return super().train(False)
@torch.inference_mode()
def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
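        """Extract per-frame CLIP image features.

        Frames are flattened to (B*T, C, H, W), pushed through the CLIP image tower
        in chunks of `batch_size` (all frames at once if batch_size < 0), and
        reshaped back to (B, T, D).
        """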
assert self.clip_model is not None, 'CLIP is not loaded'
        # x: (B, T, C, H, W), H/W expected to be 224
b, t, c, h, w = x.shape
assert c == 3 and h == 224 and w == 224
# x = self.clip_preprocess(x)
x = rearrange(x, 'b t c h w -> (b t) c h w')
outputs = []
if batch_size < 0:
batch_size = b * t
for i in range(0, b * t, batch_size):
outputs.append(self.clip_model.get_image_features(x[i:i + batch_size]))
x = torch.cat(outputs, dim=0)
# x = self.clip_model.encode_image(x, normalize=True)
x = rearrange(x, '(b t) d -> b t d', b=b)
return x
@torch.inference_mode()
def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
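        """Extract Synchformer features over a sliding window of frames.

        The clip is split into 16-frame segments with a stride of 8 frames; each
        segment is encoded independently and the per-segment token features are
        concatenated along the time axis, giving (B, S*T_seg, D).
        """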
assert self.synchformer is not None, 'Synchformer is not loaded'
        # x: (B, T, C, H, W), H/W expected to be 224
b, t, c, h, w = x.shape
assert c == 3 and h == 224 and w == 224
# partition the video
segment_size = 16
step_size = 8
num_segments = (t - segment_size) // step_size + 1
segments = []
for i in range(num_segments):
segments.append(x[:, i * step_size:i * step_size + segment_size])
x = torch.stack(segments, dim=1) # (B, S, T, C, H, W)
outputs = []
if batch_size < 0:
batch_size = b
x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
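        # Each 16-frame window becomes its own batch element with a singleton
        # segment dimension before being fed to Synchformer in chunks below.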
for i in range(0, b * num_segments, batch_size):
outputs.append(self.synchformer(x[i:i + batch_size]))
x = torch.cat(outputs, dim=0)
x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
return x
@torch.inference_mode()
def encode_text(self, text: list[str]) -> torch.Tensor:
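        """Encode text with the patched CLIP text tower.

        Returns the tuple produced by the patched get_text_features:
        (projected pooled features, per-token last hidden state).
        """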
assert self.clip_model is not None, 'CLIP is not loaded'
# assert self.tokenizer is not None, 'Tokenizer is not loaded'
# x: (B, L)
        tokens = self.clip_processor(text=text,
                                     truncation=True,
                                     max_length=77,
                                     padding="max_length",
                                     return_tensors="pt").to(self.device)
return self.clip_model.get_text_features(**tokens)
@torch.inference_mode()
def encode_t5_text(self, text: list[str]) -> torch.Tensor:
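        """Encode text with the T5 encoder; returns the last hidden state of shape (B, 77, D)."""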
assert self.t5_model is not None, 'T5 model is not loaded'
assert self.t5_tokenizer is not None, 'T5 Tokenizer is not loaded'
# x: (B, L)
inputs = self.t5_tokenizer(text,
truncation=True,
max_length=77,
padding="max_length",
return_tensors="pt").to(self.device)
return self.t5_model(**inputs).last_hidden_state
@torch.inference_mode()
    def encode_audio(self, x) -> torch.Tensor:
        assert self.vae is not None, 'VAE is not loaded'
        # Encode the raw waveform into the VAE latent sequence.
        return self.vae.encode(x)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
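

if __name__ == '__main__':
    # Illustrative usage sketch only: the checkpoint/config paths below are
    # placeholders, not files shipped with this repo, and must be replaced
    # with real paths before running.
    utils = FeaturesUtils(
        vae_ckpt='ckpts/vae.ckpt',                             # placeholder path
        vae_config='ckpts/vae_config.json',                    # placeholder path
        synchformer_ckpt='ckpts/synchformer_state_dict.pth',   # placeholder path
        enable_conditions=True,
    ).eval().to('cuda' if torch.cuda.is_available() else 'cpu')

    # Dummy inputs with the shapes the encoders expect.
    frames = torch.randn(1, 64, 3, 224, 224, device=utils.device)  # (B, T, C, H, W)
    clip_feat = utils.encode_video_with_clip(frames)               # (B, T, D)
    sync_feat = utils.encode_video_with_sync(frames)               # (B, S*T_seg, D)
    text_feat, text_hidden = utils.encode_text(['a dog barking in the rain'])
    log.info('clip %s, sync %s', clip_feat.shape, sync_feat.shape)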