#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import gzip
import html
import io
import math
from functools import lru_cache
from typing import Callable, List, Optional, Tuple

import ftfy
import numpy as np
import regex as re
import torch
import torch.nn as nn
from iopath.common.file_io import g_pathmgr
from timm.models.layers import trunc_normal_

from .helpers import VerboseNNModule, cast_if_src_dtype


def get_sinusoid_encoding_table(n_position, d_hid):
    """Sinusoid position encoding table"""

    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [
            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
    )
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
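

# Illustrative sketch (not part of the original file): the helper above returns a
# (1, n_position, d_hid) table in which even feature indices hold sin terms and odd
# indices hold cos terms; the toy sizes below are arbitrary.
def _example_sinusoid_table():
    table = get_sinusoid_encoding_table(n_position=16, d_hid=8)
    assert table.shape == (1, 16, 8)  # leading dim of 1 is the batch axis
    return table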


def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
    N = pos_embed.shape[1]
    if N == target_spatial_size:
        return pos_embed
    dim = pos_embed.shape[-1]
    # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
    pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
    pos_embed = nn.functional.interpolate(
        pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
            0, 3, 1, 2
        ),
        scale_factor=math.sqrt(target_spatial_size / N),
        mode="bicubic",
    )
    if updated:
        pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
    pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return pos_embed


def interpolate_pos_encoding(
    npatch_per_img,
    pos_embed,
    patches_layout,
    input_shape=None,
    first_patch_idx=1,
):
    assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
    N = pos_embed.shape[1] - first_patch_idx  # since it's 1 if cls_token exists
    if npatch_per_img == N:
        return pos_embed

    assert (
        patches_layout[-1] == patches_layout[-2]
    ), "Interpolation of pos embed not supported for non-square layouts"

    class_emb = pos_embed[:, :first_patch_idx]
    pos_embed = pos_embed[:, first_patch_idx:]

    if input_shape is None or patches_layout[0] == 1:
        # simple 2D pos embedding, no temporal component
        pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
    elif patches_layout[0] > 1:
        # pos embed has a temporal component
        assert len(input_shape) == 4, "temporal interpolation not supported"
        # we only support 2D interpolation in this case
        num_frames = patches_layout[0]
        num_spatial_tokens = patches_layout[1] * patches_layout[2]
        pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
        # interpolate embedding for zeroth frame
        pos_embed = interpolate_pos_encoding_2d(
            npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
        )
    else:
        raise ValueError("This type of interpolation isn't implemented")

    return torch.cat((class_emb, pos_embed), dim=1)
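

# Illustrative sketch (not part of the original file): resize a square positional
# embedding (1 CLS token + a 14 x 14 patch grid) to a 28 x 28 grid, which is what the
# helpers above do when a model is evaluated at a resolution it was not trained at.
# The 768-dim width and the grid sizes are arbitrary choices for this example.
def _example_interpolate_pos_embed():
    pos_embed = torch.randn(1, 1 + 14 * 14, 768)  # CLS + 196 patch embeddings
    resized = interpolate_pos_encoding(
        npatch_per_img=28 * 28,
        pos_embed=pos_embed,
        patches_layout=(1, 14, 14),  # (T, H, W); T == 1 means no temporal axis
        first_patch_idx=1,
    )
    assert resized.shape == (1, 1 + 28 * 28, 768)
    return resized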


def _get_pos_embedding(
    npatch_per_img,
    pos_embed,
    patches_layout,
    input_shape,
    first_patch_idx=1,
):
    pos_embed = interpolate_pos_encoding(
        npatch_per_img,
        pos_embed,
        patches_layout,
        input_shape=input_shape,
        first_patch_idx=first_patch_idx,
    )
    return pos_embed


class PatchEmbedGeneric(nn.Module):
    """
    PatchEmbed from Hydra
    """

    def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
        super().__init__()

        if len(proj_stem) > 1:
            self.proj = nn.Sequential(*proj_stem)
        else:
            # Special case to be able to load pre-trained models that were
            # trained with a standard stem
            self.proj = proj_stem[0]
        self.norm_layer = norm_layer

    def get_patch_layout(self, img_size):
        with torch.no_grad():
            dummy_img = torch.zeros(
                [
                    1,
                ]
                + img_size
            )
            dummy_out = self.proj(dummy_img)
        embed_dim = dummy_out.shape[1]
        patches_layout = tuple(dummy_out.shape[2:])
        num_patches = np.prod(patches_layout)
        return patches_layout, num_patches, embed_dim

    def forward(self, x):
        x = self.proj(x)
        # B C (T) H W -> B (T)HW C
        x = x.flatten(2).transpose(1, 2)
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        return x
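

# Illustrative sketch (not part of the original file): a ViT-B/16 style RGB stem.
# PatchEmbedGeneric only assumes the stem maps B x C x ... to B x D x <grid>, so a
# single Conv2d is enough here; the 16 x 16 / 768-dim choices are just example values.
def _example_patch_embed():
    stem = PatchEmbedGeneric(
        proj_stem=[nn.Conv2d(3, 768, kernel_size=16, stride=16, bias=False)]
    )
    patches_layout, num_patches, embed_dim = stem.get_patch_layout([3, 224, 224])
    # 224 / 16 = 14 patches per side -> a 14 x 14 grid of 768-dim tokens
    assert patches_layout == (14, 14) and num_patches == 196 and embed_dim == 768
    tokens = stem(torch.randn(2, 3, 224, 224))  # B x (H * W) x D
    assert tokens.shape == (2, 196, 768)
    return tokens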


class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
    def __init__(
        self,
        patches_layout: List,
        num_patches: int,
        num_cls_tokens: int,
        embed_dim: int,
        learnable: bool,
    ) -> None:
        super().__init__()
        self.num_cls_tokens = num_cls_tokens
        self.patches_layout = patches_layout
        self.num_patches = num_patches
        self.num_tokens = num_cls_tokens + num_patches
        self.learnable = learnable
        if self.learnable:
            self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
            trunc_normal_(self.pos_embed, std=0.02)
        else:
            self.register_buffer(
                "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
            )

    def get_pos_embedding(self, vision_input, all_vision_tokens):
        input_shape = vision_input.shape
        pos_embed = _get_pos_embedding(
            all_vision_tokens.size(1) - self.num_cls_tokens,
            pos_embed=self.pos_embed,
            patches_layout=self.patches_layout,
            input_shape=input_shape,
            first_patch_idx=self.num_cls_tokens,
        )
        return pos_embed


class RGBDTPreprocessor(VerboseNNModule):
    def __init__(
        self,
        rgbt_stem: PatchEmbedGeneric,
        depth_stem: Optional[PatchEmbedGeneric],
        img_size: Tuple = (3, 224, 224),
        num_cls_tokens: int = 1,
        pos_embed_fn: Optional[Callable] = None,
        use_type_embed: bool = False,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        stem = rgbt_stem if rgbt_stem is not None else depth_stem
        (
            self.patches_layout,
            self.num_patches,
            self.embed_dim,
        ) = stem.get_patch_layout(img_size)
        self.rgbt_stem = rgbt_stem
        self.depth_stem = depth_stem
        self.use_pos_embed = pos_embed_fn is not None
        self.use_type_embed = use_type_embed
        self.num_cls_tokens = num_cls_tokens

        if self.use_pos_embed:
            self.pos_embedding_helper = pos_embed_fn(
                patches_layout=self.patches_layout,
                num_cls_tokens=num_cls_tokens,
                num_patches=self.num_patches,
                embed_dim=self.embed_dim,
            )
        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )
        if self.use_type_embed:
            self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5
            if self.use_pos_embed:
                nn.init.normal_(self.pos_embedding_helper.pos_embed)
                self.pos_embedding_helper.pos_embed *= scale

            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

        if self.use_type_embed:
            nn.init.normal_(self.type_embed)

    def tokenize_input_and_cls_pos(self, input, stem, mask):
        # tokens is of shape B x L x D
        tokens = stem(input)
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
            tokens = tokens + pos_embed
        if self.use_type_embed:
            tokens = tokens + self.type_embed.expand(B, -1, -1)
        return tokens

    def forward(self, vision=None, depth=None, patch_mask=None):
        if patch_mask is not None:
            raise NotImplementedError()

        if vision is not None:
            vision_tokens = self.tokenize_input_and_cls_pos(
                vision, self.rgbt_stem, patch_mask
            )

        if depth is not None:
            depth_tokens = self.tokenize_input_and_cls_pos(
                depth, self.depth_stem, patch_mask
            )

        # aggregate tokens
        if vision is not None and depth is not None:
            final_tokens = vision_tokens + depth_tokens
        else:
            final_tokens = vision_tokens if vision is not None else depth_tokens
        return_dict = {
            "trunk": {
                "tokens": final_tokens,
            },
            "head": {},
        }
        return return_dict
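

# Illustrative sketch (not part of the original file): an RGB image branch assembled
# from the pieces above. pos_embed_fn is passed as a partial so the preprocessor can
# fill in the patch layout it derives from the stem; all sizes are example values.
def _example_rgbt_preprocessor():
    from functools import partial

    rgbt_stem = PatchEmbedGeneric(
        proj_stem=[nn.Conv2d(3, 768, kernel_size=16, stride=16, bias=False)]
    )
    preprocessor = RGBDTPreprocessor(
        rgbt_stem=rgbt_stem,
        depth_stem=None,
        img_size=[3, 224, 224],
        num_cls_tokens=1,
        pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
    )
    out = preprocessor(vision=torch.randn(2, 3, 224, 224))
    # 1 CLS token + 14 * 14 patch tokens per image
    assert out["trunk"]["tokens"].shape == (2, 197, 768)
    return out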


class AudioPreprocessor(RGBDTPreprocessor):
    def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
        super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)

    def forward(self, audio=None):
        return super().forward(vision=audio)


class ThermalPreprocessor(RGBDTPreprocessor):
    def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
        super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)

    def forward(self, thermal=None):
        return super().forward(vision=thermal)


def build_causal_attention_mask(context_length):
    # lazily create causal attention mask, with full attention between the vision tokens
    # pytorch uses additive attention mask; fill with -inf
    mask = torch.empty(context_length, context_length, requires_grad=False)
    mask.fill_(float("-inf"))
    mask.triu_(1)  # zero out the lower diagonal
    return mask
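

# Illustrative sketch (not part of the original file): for context_length == 4 the
# additive mask is 0 on and below the diagonal and -inf above it, so every position
# can attend only to itself and to earlier positions.
def _example_causal_mask():
    mask = build_causal_attention_mask(4)
    assert mask.shape == (4, 4)
    assert mask[0, 1].item() == float("-inf") and mask[1, 0].item() == 0
    return mask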


class TextPreprocessor(VerboseNNModule):
    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        embed_dim: int,
        causal_masking: bool,
        supply_seq_len_to_head: bool = True,
        num_cls_tokens: int = 0,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(
            torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
        )
        self.causal_masking = causal_masking
        if self.causal_masking:
            mask = build_causal_attention_mask(self.context_length)
            # register the mask as a buffer so it can be moved to the right device
            self.register_buffer("mask", mask)

        self.supply_seq_len_to_head = supply_seq_len_to_head
        self.num_cls_tokens = num_cls_tokens
        self.embed_dim = embed_dim
        if num_cls_tokens > 0:
            assert self.causal_masking is False, "Masking + CLS token isn't implemented"
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, embed_dim)
            )

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style="openclip"):
        # OpenCLIP style initialization
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.01)

        if init_param_style == "openclip":
            scale = self.embed_dim**-0.5
            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

    def forward(self, text):
        # text tokens are of shape B x L x D
        text_tokens = self.token_embedding(text)
        # concat CLS tokens if any
        if self.num_cls_tokens > 0:
            B = text_tokens.shape[0]
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
        text_tokens = text_tokens + self.pos_embed
        return_dict = {
            "trunk": {
                "tokens": text_tokens,
            },
            "head": {},
        }
        # The EOT token has the highest id in the CLIP-style vocab, so argmax over the
        # raw token ids recovers its position, i.e. the sequence length.
        if self.supply_seq_len_to_head:
            text_lengths = text.argmax(dim=-1)
            return_dict["head"] = {
                "seq_len": text_lengths,
            }
        if self.causal_masking:
            return_dict["trunk"].update({"attn_mask": self.mask})
        return return_dict
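

# Illustrative sketch (not part of the original file): a small text branch. The sizes
# mirror the CLIP setup (vocab 49408, context 77), while the 512-dim width is an
# arbitrary choice; `seq_len` relies on the EOT token having the largest vocabulary id.
def _example_text_preprocessor():
    preprocessor = TextPreprocessor(
        vocab_size=49408,
        context_length=77,
        embed_dim=512,
        causal_masking=True,
    )
    dummy_ids = torch.zeros(2, 77, dtype=torch.long)
    dummy_ids[:, 0] = 49406  # <|startoftext|>
    dummy_ids[:, 5] = 49407  # <|endoftext|>, the largest id, so argmax finds it
    out = preprocessor(dummy_ids)
    assert out["trunk"]["tokens"].shape == (2, 77, 512)
    assert out["trunk"]["attn_mask"].shape == (77, 77)
    assert out["head"]["seq_len"].tolist() == [5, 5]
    return out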


class Im2Video(nn.Module):
    """Convert an image into a trivial video."""

    def __init__(self, time_dim=2):
        super().__init__()
        self.time_dim = time_dim

    def forward(self, x):
        if x.ndim == 4:
            # B, C, H, W -> B, C, T, H, W
            return x.unsqueeze(self.time_dim)
        elif x.ndim == 5:
            return x
        else:
            raise ValueError(f"Dimension incorrect {x.shape}")


class PadIm2Video(Im2Video):
    def __init__(self, ntimes, pad_type, time_dim=2):
        super().__init__(time_dim=time_dim)
        assert ntimes > 0
        assert pad_type in ["zero", "repeat"]
        self.ntimes = ntimes
        self.pad_type = pad_type

    def forward(self, x):
        x = super().forward(x)
        if x.shape[self.time_dim] == 1:
            if self.pad_type == "repeat":
                new_shape = [1] * len(x.shape)
                new_shape[self.time_dim] = self.ntimes
                x = x.repeat(new_shape)
            elif self.pad_type == "zero":
                padarg = [0, 0] * len(x.shape)
                padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
                x = nn.functional.pad(x, padarg)
        return x
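

# Illustrative sketch (not part of the original file): a batch of single images becomes
# a 2-frame clip, either by repeating the frame or by zero-padding the time axis.
def _example_pad_im2video():
    images = torch.randn(2, 3, 224, 224)
    repeated = PadIm2Video(ntimes=2, pad_type="repeat")(images)
    zero_padded = PadIm2Video(ntimes=2, pad_type="zero")(images)
    assert repeated.shape == (2, 3, 2, 224, 224)
    assert zero_padded.shape == (2, 3, 2, 224, 224)
    assert torch.equal(repeated[:, :, 0], repeated[:, :, 1])  # same frame twice
    assert torch.count_nonzero(zero_padded[:, :, 1]) == 0  # appended frame is all zeros
    return repeated, zero_padded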


# Modified from github.com/openai/CLIP
@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters that the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
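

# Illustrative sketch (not part of the original file): every possible byte value gets a
# printable unicode character; printable ASCII maps to itself, while bytes the BPE
# would choke on (e.g. the space byte 0x20) are remapped above 0xFF.
def _example_bytes_to_unicode():
    table = bytes_to_unicode()
    assert len(table) == 256
    assert table[ord("A")] == "A"  # printable ASCII is unchanged
    assert table[ord(" ")] == chr(288)  # space -> 'Ġ', a printable stand-in
    return table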


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str, context_length=77):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        with g_pathmgr.open(bpe_path, "rb") as fh:
            bpe_bytes = io.BytesIO(fh.read())
        merges: List[str] = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges: List[Tuple[str, ...]] = [tuple(merge.split()) for merge in merges]

        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.context_length = context_length

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur in the rest of the word; keep the tail as-is
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text

    def __call__(self, texts, context_length=None):
        if not context_length:
            context_length = self.context_length

        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<|startoftext|>"]
        eot_token = self.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            tokens = tokens[:context_length]
            result[i, : len(tokens)] = torch.tensor(tokens)

        if len(result) == 1:
            return result[0]
        return result
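

# Illustrative sketch (not part of the original file): the tokenizer needs the CLIP BPE
# merges file (a .txt.gz asset); the default path below is only a placeholder for
# wherever that file lives in your setup.
def _example_tokenizer(bpe_path="bpe/bpe_simple_vocab_16e6.txt.gz"):
    tokenizer = SimpleTokenizer(bpe_path=bpe_path, context_length=77)
    tokens = tokenizer(["a photo of a dog", "a photo of a cat"])
    # one row per caption, zero-padded (or truncated) to context_length
    assert tokens.shape == (2, 77)
    return tokens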


class IMUPreprocessor(VerboseNNModule):
    def __init__(
        self,
        kernel_size: int,
        imu_stem: PatchEmbedGeneric,
        embed_dim: int,
        img_size: Tuple = (6, 2000),
        num_cls_tokens: int = 1,
        pos_embed_fn: Optional[Callable] = None,
        init_param_style: str = "openclip",
    ) -> None:
        super().__init__()
        self.imu_stem = imu_stem
        self.embed_dim = embed_dim
        self.use_pos_embed = pos_embed_fn is not None
        self.num_cls_tokens = num_cls_tokens
        self.kernel_size = kernel_size
        self.pos_embed = nn.Parameter(
            torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
        )

        if self.num_cls_tokens > 0:
            self.cls_token = nn.Parameter(
                torch.zeros(1, self.num_cls_tokens, self.embed_dim)
            )

        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style):
        nn.init.normal_(self.pos_embed, std=0.01)

        if init_param_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim**-0.5
            if self.num_cls_tokens > 0:
                nn.init.normal_(self.cls_token)
                self.cls_token *= scale
        elif init_param_style == "vit":
            self.cls_token.data.fill_(0)
        else:
            raise ValueError(f"Unknown init {init_param_style}")

    def tokenize_input_and_cls_pos(self, input, stem):
        # tokens is of shape B x L x D
        tokens = stem.norm_layer(stem.proj(input))
        assert tokens.ndim == 3
        assert tokens.shape[2] == self.embed_dim
        B = tokens.shape[0]
        if self.num_cls_tokens > 0:
            class_tokens = self.cls_token.expand(
                B, -1, -1
            )  # stole class_tokens impl from Phil Wang, thanks
            tokens = torch.cat((class_tokens, tokens), dim=1)
        if self.use_pos_embed:
            tokens = tokens + self.pos_embed
        return tokens

    def forward(self, imu):
        # Patchify
        imu = imu.unfold(
            -1,
            self.kernel_size,
            self.kernel_size,
        ).permute(0, 2, 1, 3)
        imu = imu.reshape(imu.size(0), imu.size(1), -1)

        imu_tokens = self.tokenize_input_and_cls_pos(
            imu,
            self.imu_stem,
        )

        return_dict = {
            "trunk": {
                "tokens": imu_tokens,
            },
            "head": {},
        }
        return return_dict
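

# Illustrative sketch (not part of the original file): an IMU branch for 6-channel,
# 2000-sample clips. Each group of kernel_size samples across all 6 channels is
# flattened into one 48-dim patch and projected by a Linear stem; note that
# pos_embed_fn only toggles whether the preprocessor's own pos_embed parameter is added.
def _example_imu_preprocessor():
    from functools import partial

    kernel_size = 8
    embed_dim = 512
    imu_stem = PatchEmbedGeneric(
        proj_stem=[nn.Linear(6 * kernel_size, embed_dim, bias=False)],
        norm_layer=nn.LayerNorm(embed_dim),
    )
    preprocessor = IMUPreprocessor(
        kernel_size=kernel_size,
        imu_stem=imu_stem,
        embed_dim=embed_dim,
        img_size=[6, 2000],
        num_cls_tokens=1,
        pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
    )
    out = preprocessor(torch.randn(2, 6, 2000))
    # 1 CLS token + 2000 / 8 = 250 IMU patches
    assert out["trunk"]["tokens"].shape == (2, 251, embed_dim)
    return out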