Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
import torch.nn as nn | |
from torch import Tensor | |
import logging | |
from huggingface_hub import hf_hub_download | |
import functools | |
from typing import Callable, Optional | |
def process_video_gt_masks(gt_masks, num_frames, num_objs):
    """Reorder flat ground-truth masks from object-major to frame-major.

    The input is indexed as ``gt_masks[obj * num_frames + frame]``, i.e. all
    frames of object 0 come first, then all frames of object 1, and so on.
    The result lists, for each frame in turn, the entries of every object.

    Args:
        gt_masks: Indexable collection holding at least
            ``num_frames * num_objs`` entries.
        num_frames: Number of frames per object.
        num_objs: Number of objects.

    Returns:
        A new list of ``num_frames * num_objs`` entries in frame-major order.
    """
    return [
        gt_masks[obj_idx * num_frames + frame_idx]
        for frame_idx in range(num_frames)
        for obj_idx in range(num_objs)
    ]
def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
    """Load a checkpoint and optionally keep only the weights under a prefix.

    Args:
        filename: Path handed to ``torch.load``. When it starts with
            ``'hf-hub:'`` the remainder is passed to ``hf_hub_download`` as
            the repo id and the file ``pytorch_model.bin`` is fetched first.
        prefix: Key prefix to select. A trailing ``'.'`` is appended when
            absent. When empty or ``None`` the whole state dict is returned.
        map_location: Forwarded to ``torch.load``.
        logger: Accepted for interface compatibility; not used here.

    Returns:
        The state dict, or a new dict holding only the entries whose key
        starts with ``prefix``, with that prefix stripped from the keys.

    Raises:
        AssertionError: If ``prefix`` is given but matches no key.
    """
    hub_scheme = 'hf-hub:'
    if filename.startswith(hub_scheme):
        repo_id = filename[len(hub_scheme):]
        filename = hf_hub_download(repo_id, 'pytorch_model.bin')

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version's defaults; only load trusted checkpoints.
    checkpoint = torch.load(filename, map_location=map_location)

    # Checkpoints may nest the weights under 'state_dict' or 'model';
    # otherwise the loaded object itself is treated as the state dict.
    for wrapper_key in ('state_dict', 'model'):
        if wrapper_key in checkpoint:
            weights = checkpoint[wrapper_key]
            break
    else:
        weights = checkpoint

    if not prefix:
        return weights

    if not prefix.endswith('.'):
        prefix = prefix + '.'
    offset = len(prefix)

    selected = {}
    for name, value in weights.items():
        if name.startswith(prefix):
            selected[name[offset:]] = value

    assert selected, f'{prefix} is not in the pretrained model'
    return selected
def load_state_dict_to_model(model, state_dict, logger='current'):
    """Load ``state_dict`` into ``model`` and fail loudly on any key mismatch.

    Args:
        model: Object exposing ``load_state_dict(state_dict)`` that returns a
            ``(missing_keys, unexpected_keys)`` pair.
        state_dict: Mapping of parameter names to values to load.
        logger: Accepted for interface compatibility; not used here.

    Raises:
        RuntimeError: If the model reports any missing or unexpected keys.
            The message names the offending keys so the mismatch can be
            diagnosed without re-running under a debugger.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict)

    # Collect both categories so a single failure reports the full picture.
    problems = []
    if missing_keys:
        problems.append(f'missing keys: {list(missing_keys)}')
    if unexpected_keys:
        problems.append(f'unexpected keys: {list(unexpected_keys)}')
    if problems:
        raise RuntimeError('Failed to load state dict; ' + '; '.join(problems))
def genetate_video_pred_embeddings(pred_embeddings_list, frames_per_batch):
    """Repeat each batch's embedding once per frame of that batch.

    Args:
        pred_embeddings_list: Sequence with one embedding entry per batch.
        frames_per_batch: Sequence of frame counts, aligned with
            ``pred_embeddings_list``.

    Returns:
        A flat list in which the i-th embedding appears
        ``frames_per_batch[i]`` times in a row. The same object is repeated
        by reference, not copied.

    Raises:
        AssertionError: If the two inputs differ in length.
    """
    assert len(pred_embeddings_list) == len(frames_per_batch), \
        f"Lengths do not match: len(pred_embeddings_list)={len(pred_embeddings_list)}, len(frames_per_batch)={len(frames_per_batch)}"
    return [
        embedding
        for embedding, repeat_count in zip(pred_embeddings_list, frames_per_batch)
        for _ in range(repeat_count)
    ]