File size: 2,092 Bytes
372785b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch import Tensor
import logging
from huggingface_hub import hf_hub_download
import functools
from typing import Callable, Optional

def process_video_gt_masks(gt_masks, num_frames, num_objs):
    """Reorder ground-truth masks from object-major to frame-major order.

    The input is laid out object-major: all frames of object 0, then all
    frames of object 1, etc., i.e. ``gt_masks[j * num_frames + i]`` is the
    mask of object ``j`` at frame ``i``. The output groups the masks of every
    object together per frame.

    Args:
        gt_masks: flat sequence of length ``num_frames * num_objs``,
            indexed object-major.
        num_frames: number of frames per object.
        num_objs: number of objects per frame.

    Returns:
        list: masks reordered frame-major, such that
        ``out[i * num_objs + j] == gt_masks[j * num_frames + i]``.
    """
    # Single comprehension replaces the manual append loop; iteration order
    # (frame outer, object inner) is identical to the original.
    return [
        gt_masks[j * num_frames + i]
        for i in range(num_frames)
        for j in range(num_objs)
    ]

def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
    """Load a checkpoint file and optionally keep only keys under ``prefix``.

    Args:
        filename: path to a checkpoint, or ``'hf-hub:<repo_id>'`` to download
            ``pytorch_model.bin`` from the Hugging Face Hub.
        prefix: if given, only keys starting with ``prefix + '.'`` are kept,
            with the prefix stripped from each key.
        map_location: passed through to ``torch.load``.
        logger: unused; kept for interface compatibility.

    Returns:
        dict: the (possibly filtered) state dict.

    Raises:
        AssertionError: if ``prefix`` matches no key in the checkpoint.
    """
    HF_HUB_PREFIX = 'hf-hub:'
    if filename.startswith(HF_HUB_PREFIX):
        # Resolve a Hub reference to a locally cached weights file.
        repo_id = filename[len(HF_HUB_PREFIX):]
        filename = hf_hub_download(repo_id, 'pytorch_model.bin')

    # NOTE(review): torch.load without an explicit weights_only flag — its
    # default depends on the installed torch version; confirm for untrusted files.
    checkpoint = torch.load(filename, map_location=map_location)

    # Unwrap the common container layouts; fall back to the raw mapping.
    for container_key in ('state_dict', 'model'):
        if container_key in checkpoint:
            state_dict = checkpoint[container_key]
            break
    else:
        state_dict = checkpoint

    if not prefix:
        return state_dict

    if not prefix.endswith('.'):
        prefix += '.'
    strip_len = len(prefix)

    state_dict = {
        name[strip_len:]: value
        for name, value in state_dict.items()
        if name.startswith(prefix)
    }

    assert state_dict, f'{prefix} is not in the pretrained model'
    return state_dict


def load_state_dict_to_model(model, state_dict,  logger='current'):
    """Strictly load ``state_dict`` into ``model``, failing loudly on mismatch.

    Args:
        model: target ``nn.Module``.
        state_dict: parameter mapping to load.
        logger: unused; kept for interface compatibility.

    Raises:
        RuntimeError: if any keys are missing from or unexpected in
            ``state_dict``.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict)
    # Fix: the original raised bare RuntimeError() with no message, making
    # load failures undiagnosable. Name the offending keys instead.
    if missing_keys:
        raise RuntimeError(f'Missing keys when loading state_dict: {missing_keys}')
    if unexpected_keys:
        raise RuntimeError(f'Unexpected keys when loading state_dict: {unexpected_keys}')

def genetate_video_pred_embeddings(pred_embeddings_list, frames_per_batch):
    """Expand per-batch embeddings to one entry per frame.

    Each embedding in ``pred_embeddings_list`` is repeated (by reference)
    once for every frame in the corresponding batch.

    Args:
        pred_embeddings_list: one embedding per batch.
        frames_per_batch: frame count of each batch; must be the same length
            as ``pred_embeddings_list``.

    Returns:
        list: embeddings repeated per frame, ``sum(frames_per_batch)`` long.
    """
    assert len(pred_embeddings_list) == len(frames_per_batch), \
    f"Lengths do not match: len(pred_embeddings_list)={len(pred_embeddings_list)}, len(frames_per_batch)={len(frames_per_batch)}"

    # Flattening comprehension in place of the original +=-accumulation loop;
    # entries are repeated references, not copies, exactly as before.
    return [
        embedding
        for embedding, n_frames in zip(pred_embeddings_list, frames_per_batch)
        for _ in range(n_frames)
    ]