Spaces:
Running
on
Zero
Running
on
Zero
Delete sonique
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- sonique/Video_LLaMA/apply_delta.py +0 -49
- sonique/Video_LLaMA/eval_configs/video_llama_eval_only_vl.yaml +0 -36
- sonique/Video_LLaMA/eval_configs/video_llama_eval_withaudio.yaml +0 -38
- sonique/Video_LLaMA/inference.py +0 -94
- sonique/Video_LLaMA/video_llama/__init__.py +0 -31
- sonique/Video_LLaMA/video_llama/common/__init__.py +0 -0
- sonique/Video_LLaMA/video_llama/common/config.py +0 -468
- sonique/Video_LLaMA/video_llama/common/dist_utils.py +0 -137
- sonique/Video_LLaMA/video_llama/common/gradcam.py +0 -24
- sonique/Video_LLaMA/video_llama/common/logger.py +0 -195
- sonique/Video_LLaMA/video_llama/common/optims.py +0 -119
- sonique/Video_LLaMA/video_llama/common/registry.py +0 -329
- sonique/Video_LLaMA/video_llama/common/utils.py +0 -424
- sonique/Video_LLaMA/video_llama/configs/datasets/cc_sbu/align.yaml +0 -5
- sonique/Video_LLaMA/video_llama/configs/datasets/cc_sbu/defaults.yaml +0 -5
- sonique/Video_LLaMA/video_llama/configs/datasets/instruct/llava_instruct.yaml +0 -6
- sonique/Video_LLaMA/video_llama/configs/datasets/instruct/webvid_instruct.yaml +0 -6
- sonique/Video_LLaMA/video_llama/configs/datasets/laion/defaults.yaml +0 -5
- sonique/Video_LLaMA/video_llama/configs/datasets/webvid/defaults.yaml +0 -6
- sonique/Video_LLaMA/video_llama/configs/default.yaml +0 -5
- sonique/Video_LLaMA/video_llama/configs/models/minigpt4.yaml +0 -33
- sonique/Video_LLaMA/video_llama/configs/models/video_llama.yaml +0 -36
- sonique/Video_LLaMA/video_llama/conversation/__init__.py +0 -0
- sonique/Video_LLaMA/video_llama/conversation/conversation_video.py +0 -348
- sonique/Video_LLaMA/video_llama/datasets/__init__.py +0 -0
- sonique/Video_LLaMA/video_llama/datasets/builders/__init__.py +0 -77
- sonique/Video_LLaMA/video_llama/datasets/builders/base_dataset_builder.py +0 -236
- sonique/Video_LLaMA/video_llama/datasets/builders/image_text_pair_builder.py +0 -106
- sonique/Video_LLaMA/video_llama/datasets/builders/instruct_builder.py +0 -79
- sonique/Video_LLaMA/video_llama/datasets/builders/video_caption_builder.py +0 -34
- sonique/Video_LLaMA/video_llama/datasets/data_utils.py +0 -196
- sonique/Video_LLaMA/video_llama/datasets/datasets/__init__.py +0 -0
- sonique/Video_LLaMA/video_llama/datasets/datasets/base_dataset.py +0 -68
- sonique/Video_LLaMA/video_llama/datasets/datasets/caption_datasets.py +0 -85
- sonique/Video_LLaMA/video_llama/datasets/datasets/cc_sbu_dataset.py +0 -49
- sonique/Video_LLaMA/video_llama/datasets/datasets/dataloader_utils.py +0 -162
- sonique/Video_LLaMA/video_llama/datasets/datasets/laion_dataset.py +0 -31
- sonique/Video_LLaMA/video_llama/datasets/datasets/llava_instruct_dataset.py +0 -312
- sonique/Video_LLaMA/video_llama/datasets/datasets/video_instruct_dataset.py +0 -335
- sonique/Video_LLaMA/video_llama/datasets/datasets/webvid_datasets.py +0 -122
- sonique/Video_LLaMA/video_llama/models/ImageBind/.assets/bird_image.jpg +0 -0
- sonique/Video_LLaMA/video_llama/models/ImageBind/.assets/car_image.jpg +0 -0
- sonique/Video_LLaMA/video_llama/models/ImageBind/.assets/dog_image.jpg +0 -0
- sonique/Video_LLaMA/video_llama/models/ImageBind/CODE_OF_CONDUCT.md +0 -80
- sonique/Video_LLaMA/video_llama/models/ImageBind/CONTRIBUTING.md +0 -31
- sonique/Video_LLaMA/video_llama/models/ImageBind/LICENSE +0 -437
- sonique/Video_LLaMA/video_llama/models/ImageBind/README.md +0 -155
- sonique/Video_LLaMA/video_llama/models/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz +0 -3
- sonique/Video_LLaMA/video_llama/models/ImageBind/data.py +0 -338
- sonique/Video_LLaMA/video_llama/models/ImageBind/model_card.md +0 -94
sonique/Video_LLaMA/apply_delta.py
DELETED
@@ -1,49 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Apply the delta weights on top of a base model.
|
3 |
-
Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/model/apply_delta.py.
|
4 |
-
"""
|
5 |
-
import argparse
|
6 |
-
|
7 |
-
import torch
|
8 |
-
from tqdm import tqdm
|
9 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
10 |
-
|
11 |
-
|
12 |
-
def apply_delta(base_model_path, target_model_path, delta_path):
|
13 |
-
print(f"Loading the base model from {base_model_path}")
|
14 |
-
base = AutoModelForCausalLM.from_pretrained(
|
15 |
-
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
16 |
-
|
17 |
-
print(f"Loading the delta from {delta_path}")
|
18 |
-
delta = AutoModelForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
19 |
-
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
|
20 |
-
|
21 |
-
DEFAULT_PAD_TOKEN = "[PAD]"
|
22 |
-
base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False)
|
23 |
-
num_new_tokens = base_tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))
|
24 |
-
|
25 |
-
base.resize_token_embeddings(len(base_tokenizer))
|
26 |
-
input_embeddings = base.get_input_embeddings().weight.data
|
27 |
-
output_embeddings = base.get_output_embeddings().weight.data
|
28 |
-
input_embeddings[-num_new_tokens:] = 0
|
29 |
-
output_embeddings[-num_new_tokens:] = 0
|
30 |
-
|
31 |
-
print("Applying the delta")
|
32 |
-
for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
|
33 |
-
assert name in delta.state_dict()
|
34 |
-
param.data += delta.state_dict()[name]
|
35 |
-
|
36 |
-
print(f"Saving the target model to {target_model_path}")
|
37 |
-
base.save_pretrained(target_model_path)
|
38 |
-
delta_tokenizer.save_pretrained(target_model_path)
|
39 |
-
|
40 |
-
|
41 |
-
if __name__ == "__main__":
|
42 |
-
parser = argparse.ArgumentParser()
|
43 |
-
parser.add_argument("--base-model-path", type=str, required=True)
|
44 |
-
parser.add_argument("--target-model-path", type=str, required=True)
|
45 |
-
parser.add_argument("--delta-path", type=str, required=True)
|
46 |
-
|
47 |
-
args = parser.parse_args()
|
48 |
-
|
49 |
-
apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/eval_configs/video_llama_eval_only_vl.yaml
DELETED
@@ -1,36 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
arch: video_llama
|
3 |
-
model_type: pretrain_vicuna
|
4 |
-
freeze_vit: True
|
5 |
-
freeze_qformer: True
|
6 |
-
max_txt_len: 512
|
7 |
-
end_sym: "###"
|
8 |
-
low_resource: False
|
9 |
-
|
10 |
-
frozen_llama_proj: False
|
11 |
-
|
12 |
-
# If you want use LLaMA-2-chat,
|
13 |
-
# some ckpts could be download from our provided huggingface repo
|
14 |
-
# i.e. https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Finetuned
|
15 |
-
llama_model: "./ckpts/video-llama/llama-2-7b-chat-hf" # "ckpt/vicuna-13b/" or "ckpt/vicuna-7b/" or "ckpt/llama-2-7b-chat-hf" or "ckpt/llama-2-13b-chat-hf"
|
16 |
-
ckpt: './ckpts/video-llama/VL_LLaMA_2_7B_Finetuned.pth' # you can use our pretrained ckpt from https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Pretrained/
|
17 |
-
equip_audio_branch: False
|
18 |
-
|
19 |
-
fusion_head_layers: 2
|
20 |
-
max_frame_pos: 32
|
21 |
-
fusion_header_type: "seqTransf"
|
22 |
-
|
23 |
-
|
24 |
-
datasets:
|
25 |
-
webvid:
|
26 |
-
vis_processor:
|
27 |
-
train:
|
28 |
-
name: "alpro_video_eval"
|
29 |
-
n_frms: 8
|
30 |
-
image_size: 224
|
31 |
-
text_processor:
|
32 |
-
train:
|
33 |
-
name: "blip_caption"
|
34 |
-
|
35 |
-
run:
|
36 |
-
task: video_text_pretrain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/eval_configs/video_llama_eval_withaudio.yaml
DELETED
@@ -1,38 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
arch: video_llama
|
3 |
-
model_type: pretrain_vicuna
|
4 |
-
freeze_vit: True
|
5 |
-
freeze_qformer: True
|
6 |
-
max_txt_len: 512
|
7 |
-
end_sym: "###"
|
8 |
-
low_resource: False
|
9 |
-
|
10 |
-
frozen_llama_proj: False
|
11 |
-
|
12 |
-
# If you want use LLaMA-2-chat,
|
13 |
-
# some ckpts could be download from our provided huggingface repo
|
14 |
-
# i.e. https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Finetuned
|
15 |
-
llama_model: "./ckpt/llama-2-13b-chat-hf" # "ckpt/vicuna-13b/" or "ckpt/vicuna-7b/" or "ckpt/llama-2-7b-chat-hf" or "ckpt/llama-2-13b-chat-hf"
|
16 |
-
imagebind_ckpt_path: "./ckpt/"
|
17 |
-
ckpt: './ckpt/VL_LLaMA_2_13B_Finetuned.pth' # you can use our pretrained ckpt from https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Pretrained/
|
18 |
-
ckpt_2: './ckpt/AL_LLaMA_2_13B_Finetuned.pth'
|
19 |
-
|
20 |
-
equip_audio_branch: True # whether equips the audio branch
|
21 |
-
fusion_head_layers: 2
|
22 |
-
max_frame_pos: 32
|
23 |
-
fusion_header_type: "seqTransf"
|
24 |
-
|
25 |
-
|
26 |
-
datasets:
|
27 |
-
webvid:
|
28 |
-
vis_processor:
|
29 |
-
train:
|
30 |
-
name: "alpro_video_eval"
|
31 |
-
n_frms: 8
|
32 |
-
image_size: 224
|
33 |
-
text_processor:
|
34 |
-
train:
|
35 |
-
name: "blip_caption"
|
36 |
-
|
37 |
-
run:
|
38 |
-
task: video_text_pretrain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/inference.py
DELETED
@@ -1,94 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import os
|
3 |
-
import random
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import torch
|
7 |
-
import torch.backends.cudnn as cudnn
|
8 |
-
import gradio as gr
|
9 |
-
from torch.cuda.amp import autocast
|
10 |
-
|
11 |
-
from sonique.Video_LLaMA.video_llama.common.config import Config
|
12 |
-
from sonique.Video_LLaMA.video_llama.common.dist_utils import get_rank
|
13 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
14 |
-
from sonique.Video_LLaMA.video_llama.conversation.conversation_video import Chat, Conversation, default_conversation,SeparatorStyle,conv_llava_llama_2
|
15 |
-
import decord
|
16 |
-
import gc
|
17 |
-
|
18 |
-
decord.bridge.set_bridge('torch')
|
19 |
-
|
20 |
-
from sonique.Video_LLaMA.video_llama.datasets.builders import *
|
21 |
-
from sonique.Video_LLaMA.video_llama.models import *
|
22 |
-
from sonique.Video_LLaMA.video_llama.processors import *
|
23 |
-
from sonique.Video_LLaMA.video_llama.runners import *
|
24 |
-
from sonique.Video_LLaMA.video_llama.tasks import *
|
25 |
-
|
26 |
-
decord.bridge.set_bridge('torch')
|
27 |
-
|
28 |
-
|
29 |
-
def generate_prompt_from_video_description(cfg_path, gpu_id, model_type, input_file, num_beams=1, temperature=1.0, low_resource=False):
|
30 |
-
# initialize model
|
31 |
-
args = argparse.Namespace(cfg_path=cfg_path, gpu_id=gpu_id, model_type=model_type, options=[])
|
32 |
-
cfg = Config(args)
|
33 |
-
|
34 |
-
model_config = cfg.model_cfg
|
35 |
-
model_config.device_8bit = args.gpu_id
|
36 |
-
model_config.low_resource = low_resource
|
37 |
-
model_cls = registry.get_model_class(model_config.arch)
|
38 |
-
|
39 |
-
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
|
40 |
-
model.eval()
|
41 |
-
vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
|
42 |
-
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
|
43 |
-
if args.model_type == 'vicuna':
|
44 |
-
chat_state = default_conversation.copy()
|
45 |
-
else:
|
46 |
-
chat_state = conv_llava_llama_2.copy()
|
47 |
-
chat = Chat(model, vis_processor, device=f'cuda:{args.gpu_id}')
|
48 |
-
|
49 |
-
# process input
|
50 |
-
if input_file.endswith('.jpg') or input_file.endswith('.png'):
|
51 |
-
print(input_file)
|
52 |
-
# chatbot = chatbot + [((input_file,), None)]
|
53 |
-
chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
|
54 |
-
img_list = []
|
55 |
-
llm_message = chat.upload_img(input_file, chat_state, img_list)
|
56 |
-
elif input_file.endswith('.mp4'):
|
57 |
-
print(input_file)
|
58 |
-
# chatbot = chatbot + [((input_file,), None)]
|
59 |
-
chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
|
60 |
-
img_list = []
|
61 |
-
llm_message = chat.upload_video_without_audio(input_file, chat_state, img_list)
|
62 |
-
|
63 |
-
else:
|
64 |
-
print("Unsupported file type")
|
65 |
-
return
|
66 |
-
|
67 |
-
question = "Describe the scene in detail"
|
68 |
-
# question = """
|
69 |
-
# As a music composer fluent in English, you're tasked with creating background music for a video.
|
70 |
-
# Based on the scene described, provide a set of tags in English that describe this background music for the video.
|
71 |
-
# Do not use the tags from the example.
|
72 |
-
# Please only return the set of tags that describe this background music for the input video without any explanation.
|
73 |
-
# Return the tags in the following format:
|
74 |
-
# Tags: [Tags1, Tags2, ..., Tempo (BPM)]
|
75 |
-
# Example format:
|
76 |
-
# Tags: [Piano, Synths, Strings, Violin, Flute, Reflective, Slow tempo, 96 BPM]
|
77 |
-
# """
|
78 |
-
with autocast():
|
79 |
-
chat.ask(question, chat_state)
|
80 |
-
|
81 |
-
llm_response = chat.answer(conv=chat_state,
|
82 |
-
img_list=img_list,
|
83 |
-
num_beams=num_beams,
|
84 |
-
temperature=temperature,
|
85 |
-
max_new_tokens=512,
|
86 |
-
max_length=2000)[0]
|
87 |
-
print("Chatbot response:", llm_response)
|
88 |
-
|
89 |
-
# clean up cache
|
90 |
-
del model
|
91 |
-
gc.collect()
|
92 |
-
torch.cuda.empty_cache()
|
93 |
-
return llm_response
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/__init__.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import os
|
9 |
-
import sys
|
10 |
-
|
11 |
-
from omegaconf import OmegaConf
|
12 |
-
|
13 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
14 |
-
|
15 |
-
from sonique.Video_LLaMA.video_llama.datasets.builders import *
|
16 |
-
from sonique.Video_LLaMA.video_llama.models import *
|
17 |
-
from sonique.Video_LLaMA.video_llama.processors import *
|
18 |
-
from sonique.Video_LLaMA.video_llama.tasks import *
|
19 |
-
|
20 |
-
|
21 |
-
root_dir = os.path.dirname(os.path.abspath(__file__))
|
22 |
-
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
|
23 |
-
|
24 |
-
registry.register_path("library_root", root_dir)
|
25 |
-
repo_root = os.path.join(root_dir, "..")
|
26 |
-
registry.register_path("repo_root", repo_root)
|
27 |
-
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
|
28 |
-
registry.register_path("cache_root", cache_root)
|
29 |
-
|
30 |
-
registry.register("MAX_INT", sys.maxsize)
|
31 |
-
registry.register("SPLIT_NAMES", ["train", "val", "test"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/common/__init__.py
DELETED
File without changes
|
sonique/Video_LLaMA/video_llama/common/config.py
DELETED
@@ -1,468 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import logging
|
9 |
-
import json
|
10 |
-
from typing import Dict
|
11 |
-
|
12 |
-
from omegaconf import OmegaConf
|
13 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
14 |
-
|
15 |
-
|
16 |
-
class Config:
|
17 |
-
def __init__(self, args):
|
18 |
-
self.config = {}
|
19 |
-
|
20 |
-
self.args = args
|
21 |
-
|
22 |
-
# Register the config and configuration for setup
|
23 |
-
registry.register("configuration", self)
|
24 |
-
|
25 |
-
user_config = self._build_opt_list(self.args.options)
|
26 |
-
|
27 |
-
config = OmegaConf.load(self.args.cfg_path)
|
28 |
-
|
29 |
-
runner_config = self.build_runner_config(config)
|
30 |
-
model_config = self.build_model_config(config, **user_config)
|
31 |
-
dataset_config = self.build_dataset_config(config)
|
32 |
-
|
33 |
-
# Validate the user-provided runner configuration
|
34 |
-
# model and dataset configuration are supposed to be validated by the respective classes
|
35 |
-
# [TODO] validate the model/dataset configuration
|
36 |
-
# self._validate_runner_config(runner_config)
|
37 |
-
|
38 |
-
# Override the default configuration with user options.
|
39 |
-
self.config = OmegaConf.merge(
|
40 |
-
runner_config, model_config, dataset_config, user_config
|
41 |
-
)
|
42 |
-
|
43 |
-
def _validate_runner_config(self, runner_config):
|
44 |
-
"""
|
45 |
-
This method validates the configuration, such that
|
46 |
-
1) all the user specified options are valid;
|
47 |
-
2) no type mismatches between the user specified options and the config.
|
48 |
-
"""
|
49 |
-
runner_config_validator = create_runner_config_validator()
|
50 |
-
runner_config_validator.validate(runner_config)
|
51 |
-
|
52 |
-
def _build_opt_list(self, opts):
|
53 |
-
opts_dot_list = self._convert_to_dot_list(opts)
|
54 |
-
return OmegaConf.from_dotlist(opts_dot_list)
|
55 |
-
|
56 |
-
@staticmethod
|
57 |
-
def build_model_config(config, **kwargs):
|
58 |
-
model = config.get("model", None)
|
59 |
-
assert model is not None, "Missing model configuration file."
|
60 |
-
|
61 |
-
model_cls = registry.get_model_class(model.arch)
|
62 |
-
assert model_cls is not None, f"Model '{model.arch}' has not been registered."
|
63 |
-
|
64 |
-
model_type = kwargs.get("model.model_type", None)
|
65 |
-
if not model_type:
|
66 |
-
model_type = model.get("model_type", None)
|
67 |
-
# else use the model type selected by user.
|
68 |
-
|
69 |
-
assert model_type is not None, "Missing model_type."
|
70 |
-
|
71 |
-
model_config_path = model_cls.default_config_path(model_type=model_type)
|
72 |
-
|
73 |
-
model_config = OmegaConf.create()
|
74 |
-
# hierarchy override, customized config > default config
|
75 |
-
model_config = OmegaConf.merge(
|
76 |
-
model_config,
|
77 |
-
OmegaConf.load(model_config_path),
|
78 |
-
{"model": config["model"]},
|
79 |
-
)
|
80 |
-
|
81 |
-
return model_config
|
82 |
-
|
83 |
-
@staticmethod
|
84 |
-
def build_runner_config(config):
|
85 |
-
return {"run": config.run}
|
86 |
-
|
87 |
-
@staticmethod
|
88 |
-
def build_dataset_config(config):
|
89 |
-
datasets = config.get("datasets", None)
|
90 |
-
if datasets is None:
|
91 |
-
raise KeyError(
|
92 |
-
"Expecting 'datasets' as the root key for dataset configuration."
|
93 |
-
)
|
94 |
-
|
95 |
-
dataset_config = OmegaConf.create()
|
96 |
-
|
97 |
-
for dataset_name in datasets:
|
98 |
-
builder_cls = registry.get_builder_class(dataset_name)
|
99 |
-
|
100 |
-
dataset_config_type = datasets[dataset_name].get("type", "default")
|
101 |
-
dataset_config_path = builder_cls.default_config_path(
|
102 |
-
type=dataset_config_type
|
103 |
-
)
|
104 |
-
|
105 |
-
# hierarchy override, customized config > default config
|
106 |
-
dataset_config = OmegaConf.merge(
|
107 |
-
dataset_config,
|
108 |
-
OmegaConf.load(dataset_config_path),
|
109 |
-
{"datasets": {dataset_name: config["datasets"][dataset_name]}},
|
110 |
-
)
|
111 |
-
|
112 |
-
return dataset_config
|
113 |
-
|
114 |
-
def _convert_to_dot_list(self, opts):
|
115 |
-
if opts is None:
|
116 |
-
opts = []
|
117 |
-
|
118 |
-
if len(opts) == 0:
|
119 |
-
return opts
|
120 |
-
|
121 |
-
has_equal = opts[0].find("=") != -1
|
122 |
-
|
123 |
-
if has_equal:
|
124 |
-
return opts
|
125 |
-
|
126 |
-
return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
|
127 |
-
|
128 |
-
def get_config(self):
|
129 |
-
return self.config
|
130 |
-
|
131 |
-
@property
|
132 |
-
def run_cfg(self):
|
133 |
-
return self.config.run
|
134 |
-
|
135 |
-
@property
|
136 |
-
def datasets_cfg(self):
|
137 |
-
return self.config.datasets
|
138 |
-
|
139 |
-
@property
|
140 |
-
def model_cfg(self):
|
141 |
-
return self.config.model
|
142 |
-
|
143 |
-
def pretty_print(self):
|
144 |
-
logging.info("\n===== Running Parameters =====")
|
145 |
-
logging.info(self._convert_node_to_json(self.config.run))
|
146 |
-
|
147 |
-
logging.info("\n====== Dataset Attributes ======")
|
148 |
-
datasets = self.config.datasets
|
149 |
-
|
150 |
-
for dataset in datasets:
|
151 |
-
if dataset in self.config.datasets:
|
152 |
-
logging.info(f"\n======== {dataset} =======")
|
153 |
-
dataset_config = self.config.datasets[dataset]
|
154 |
-
logging.info(self._convert_node_to_json(dataset_config))
|
155 |
-
else:
|
156 |
-
logging.warning(f"No dataset named '{dataset}' in config. Skipping")
|
157 |
-
|
158 |
-
logging.info(f"\n====== Model Attributes ======")
|
159 |
-
logging.info(self._convert_node_to_json(self.config.model))
|
160 |
-
|
161 |
-
def _convert_node_to_json(self, node):
|
162 |
-
container = OmegaConf.to_container(node, resolve=True)
|
163 |
-
return json.dumps(container, indent=4, sort_keys=True)
|
164 |
-
|
165 |
-
def to_dict(self):
|
166 |
-
return OmegaConf.to_container(self.config)
|
167 |
-
|
168 |
-
|
169 |
-
def node_to_dict(node):
|
170 |
-
return OmegaConf.to_container(node)
|
171 |
-
|
172 |
-
|
173 |
-
class ConfigValidator:
|
174 |
-
"""
|
175 |
-
This is a preliminary implementation to centralize and validate the configuration.
|
176 |
-
May be altered in the future.
|
177 |
-
|
178 |
-
A helper class to validate configurations from yaml file.
|
179 |
-
|
180 |
-
This serves the following purposes:
|
181 |
-
1. Ensure all the options in the yaml are defined, raise error if not.
|
182 |
-
2. when type mismatches are found, the validator will raise an error.
|
183 |
-
3. a central place to store and display helpful messages for supported configurations.
|
184 |
-
|
185 |
-
"""
|
186 |
-
|
187 |
-
class _Argument:
|
188 |
-
def __init__(self, name, choices=None, type=None, help=None):
|
189 |
-
self.name = name
|
190 |
-
self.val = None
|
191 |
-
self.choices = choices
|
192 |
-
self.type = type
|
193 |
-
self.help = help
|
194 |
-
|
195 |
-
def __str__(self):
|
196 |
-
s = f"{self.name}={self.val}"
|
197 |
-
if self.type is not None:
|
198 |
-
s += f", ({self.type})"
|
199 |
-
if self.choices is not None:
|
200 |
-
s += f", choices: {self.choices}"
|
201 |
-
if self.help is not None:
|
202 |
-
s += f", ({self.help})"
|
203 |
-
return s
|
204 |
-
|
205 |
-
def __init__(self, description):
|
206 |
-
self.description = description
|
207 |
-
|
208 |
-
self.arguments = dict()
|
209 |
-
|
210 |
-
self.parsed_args = None
|
211 |
-
|
212 |
-
def __getitem__(self, key):
|
213 |
-
assert self.parsed_args is not None, "No arguments parsed yet."
|
214 |
-
|
215 |
-
return self.parsed_args[key]
|
216 |
-
|
217 |
-
def __str__(self) -> str:
|
218 |
-
return self.format_help()
|
219 |
-
|
220 |
-
def add_argument(self, *args, **kwargs):
|
221 |
-
"""
|
222 |
-
Assume the first argument is the name of the argument.
|
223 |
-
"""
|
224 |
-
self.arguments[args[0]] = self._Argument(*args, **kwargs)
|
225 |
-
|
226 |
-
def validate(self, config=None):
|
227 |
-
"""
|
228 |
-
Convert yaml config (dict-like) to list, required by argparse.
|
229 |
-
"""
|
230 |
-
for k, v in config.items():
|
231 |
-
assert (
|
232 |
-
k in self.arguments
|
233 |
-
), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
|
234 |
-
|
235 |
-
if self.arguments[k].type is not None:
|
236 |
-
try:
|
237 |
-
self.arguments[k].val = self.arguments[k].type(v)
|
238 |
-
except ValueError:
|
239 |
-
raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
|
240 |
-
|
241 |
-
if self.arguments[k].choices is not None:
|
242 |
-
assert (
|
243 |
-
v in self.arguments[k].choices
|
244 |
-
), f"""{k} must be one of {self.arguments[k].choices}."""
|
245 |
-
|
246 |
-
return config
|
247 |
-
|
248 |
-
def format_arguments(self):
|
249 |
-
return str([f"{k}" for k in sorted(self.arguments.keys())])
|
250 |
-
|
251 |
-
def format_help(self):
|
252 |
-
# description + key-value pair string for each argument
|
253 |
-
help_msg = str(self.description)
|
254 |
-
return help_msg + ", available arguments: " + self.format_arguments()
|
255 |
-
|
256 |
-
def print_help(self):
|
257 |
-
# display help message
|
258 |
-
print(self.format_help())
|
259 |
-
|
260 |
-
|
261 |
-
def create_runner_config_validator():
|
262 |
-
validator = ConfigValidator(description="Runner configurations")
|
263 |
-
|
264 |
-
validator.add_argument(
|
265 |
-
"runner",
|
266 |
-
type=str,
|
267 |
-
choices=["runner_base", "runner_iter"],
|
268 |
-
help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
|
269 |
-
runner runs based on iters. Default: runner_base""",
|
270 |
-
)
|
271 |
-
# add argumetns for training dataset ratios
|
272 |
-
validator.add_argument(
|
273 |
-
"train_dataset_ratios",
|
274 |
-
type=Dict[str, float],
|
275 |
-
help="""Ratios of training dataset. This is used in iteration-based runner.
|
276 |
-
Do not support for epoch-based runner because how to define an epoch becomes tricky.
|
277 |
-
Default: None""",
|
278 |
-
)
|
279 |
-
validator.add_argument(
|
280 |
-
"max_iters",
|
281 |
-
type=float,
|
282 |
-
help="Maximum number of iterations to run.",
|
283 |
-
)
|
284 |
-
validator.add_argument(
|
285 |
-
"max_epoch",
|
286 |
-
type=int,
|
287 |
-
help="Maximum number of epochs to run.",
|
288 |
-
)
|
289 |
-
# add arguments for iters_per_inner_epoch
|
290 |
-
validator.add_argument(
|
291 |
-
"iters_per_inner_epoch",
|
292 |
-
type=float,
|
293 |
-
help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
|
294 |
-
)
|
295 |
-
lr_scheds_choices = registry.list_lr_schedulers()
|
296 |
-
validator.add_argument(
|
297 |
-
"lr_sched",
|
298 |
-
type=str,
|
299 |
-
choices=lr_scheds_choices,
|
300 |
-
help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
|
301 |
-
)
|
302 |
-
task_choices = registry.list_tasks()
|
303 |
-
validator.add_argument(
|
304 |
-
"task",
|
305 |
-
type=str,
|
306 |
-
choices=task_choices,
|
307 |
-
help="Task to use, from {}".format(task_choices),
|
308 |
-
)
|
309 |
-
# add arguments for init_lr
|
310 |
-
validator.add_argument(
|
311 |
-
"init_lr",
|
312 |
-
type=float,
|
313 |
-
help="Initial learning rate. This will be the learning rate after warmup and before decay.",
|
314 |
-
)
|
315 |
-
# add arguments for min_lr
|
316 |
-
validator.add_argument(
|
317 |
-
"min_lr",
|
318 |
-
type=float,
|
319 |
-
help="Minimum learning rate (after decay).",
|
320 |
-
)
|
321 |
-
# add arguments for warmup_lr
|
322 |
-
validator.add_argument(
|
323 |
-
"warmup_lr",
|
324 |
-
type=float,
|
325 |
-
help="Starting learning rate for warmup.",
|
326 |
-
)
|
327 |
-
# add arguments for learning rate decay rate
|
328 |
-
validator.add_argument(
|
329 |
-
"lr_decay_rate",
|
330 |
-
type=float,
|
331 |
-
help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
|
332 |
-
)
|
333 |
-
# add arguments for weight decay
|
334 |
-
validator.add_argument(
|
335 |
-
"weight_decay",
|
336 |
-
type=float,
|
337 |
-
help="Weight decay rate.",
|
338 |
-
)
|
339 |
-
# add arguments for training batch size
|
340 |
-
validator.add_argument(
|
341 |
-
"batch_size_train",
|
342 |
-
type=int,
|
343 |
-
help="Training batch size.",
|
344 |
-
)
|
345 |
-
# add arguments for evaluation batch size
|
346 |
-
validator.add_argument(
|
347 |
-
"batch_size_eval",
|
348 |
-
type=int,
|
349 |
-
help="Evaluation batch size, including validation and testing.",
|
350 |
-
)
|
351 |
-
# add arguments for number of workers for data loading
|
352 |
-
validator.add_argument(
|
353 |
-
"num_workers",
|
354 |
-
help="Number of workers for data loading.",
|
355 |
-
)
|
356 |
-
# add arguments for warm up steps
|
357 |
-
validator.add_argument(
|
358 |
-
"warmup_steps",
|
359 |
-
type=int,
|
360 |
-
help="Number of warmup steps. Required if a warmup schedule is used.",
|
361 |
-
)
|
362 |
-
# add arguments for random seed
|
363 |
-
validator.add_argument(
|
364 |
-
"seed",
|
365 |
-
type=int,
|
366 |
-
help="Random seed.",
|
367 |
-
)
|
368 |
-
# add arguments for output directory
|
369 |
-
validator.add_argument(
|
370 |
-
"output_dir",
|
371 |
-
type=str,
|
372 |
-
help="Output directory to save checkpoints and logs.",
|
373 |
-
)
|
374 |
-
# add arguments for whether only use evaluation
|
375 |
-
validator.add_argument(
|
376 |
-
"evaluate",
|
377 |
-
help="Whether to only evaluate the model. If true, training will not be performed.",
|
378 |
-
)
|
379 |
-
# add arguments for splits used for training, e.g. ["train", "val"]
|
380 |
-
validator.add_argument(
|
381 |
-
"train_splits",
|
382 |
-
type=list,
|
383 |
-
help="Splits to use for training.",
|
384 |
-
)
|
385 |
-
# add arguments for splits used for validation, e.g. ["val"]
|
386 |
-
validator.add_argument(
|
387 |
-
"valid_splits",
|
388 |
-
type=list,
|
389 |
-
help="Splits to use for validation. If not provided, will skip the validation.",
|
390 |
-
)
|
391 |
-
# add arguments for splits used for testing, e.g. ["test"]
|
392 |
-
validator.add_argument(
|
393 |
-
"test_splits",
|
394 |
-
type=list,
|
395 |
-
help="Splits to use for testing. If not provided, will skip the testing.",
|
396 |
-
)
|
397 |
-
# add arguments for accumulating gradient for iterations
|
398 |
-
validator.add_argument(
|
399 |
-
"accum_grad_iters",
|
400 |
-
type=int,
|
401 |
-
help="Number of iterations to accumulate gradient for.",
|
402 |
-
)
|
403 |
-
|
404 |
-
# ====== distributed training ======
|
405 |
-
validator.add_argument(
|
406 |
-
"device",
|
407 |
-
type=str,
|
408 |
-
choices=["cpu", "cuda"],
|
409 |
-
help="Device to use. Support 'cuda' or 'cpu' as for now.",
|
410 |
-
)
|
411 |
-
validator.add_argument(
|
412 |
-
"world_size",
|
413 |
-
type=int,
|
414 |
-
help="Number of processes participating in the job.",
|
415 |
-
)
|
416 |
-
validator.add_argument("dist_url", type=str)
|
417 |
-
validator.add_argument("distributed", type=bool)
|
418 |
-
# add arguments to opt using distributed sampler during evaluation or not
|
419 |
-
validator.add_argument(
|
420 |
-
"use_dist_eval_sampler",
|
421 |
-
type=bool,
|
422 |
-
help="Whether to use distributed sampler during evaluation or not.",
|
423 |
-
)
|
424 |
-
|
425 |
-
# ====== task specific ======
|
426 |
-
# generation task specific arguments
|
427 |
-
# add arguments for maximal length of text output
|
428 |
-
validator.add_argument(
|
429 |
-
"max_len",
|
430 |
-
type=int,
|
431 |
-
help="Maximal length of text output.",
|
432 |
-
)
|
433 |
-
# add arguments for minimal length of text output
|
434 |
-
validator.add_argument(
|
435 |
-
"min_len",
|
436 |
-
type=int,
|
437 |
-
help="Minimal length of text output.",
|
438 |
-
)
|
439 |
-
# add arguments number of beams
|
440 |
-
validator.add_argument(
|
441 |
-
"num_beams",
|
442 |
-
type=int,
|
443 |
-
help="Number of beams used for beam search.",
|
444 |
-
)
|
445 |
-
|
446 |
-
# vqa task specific arguments
|
447 |
-
# add arguments for number of answer candidates
|
448 |
-
validator.add_argument(
|
449 |
-
"num_ans_candidates",
|
450 |
-
type=int,
|
451 |
-
help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
|
452 |
-
)
|
453 |
-
# add arguments for inference method
|
454 |
-
validator.add_argument(
|
455 |
-
"inference_method",
|
456 |
-
type=str,
|
457 |
-
choices=["genearte", "rank"],
|
458 |
-
help="""Inference method to use for question answering. If rank, requires a answer list.""",
|
459 |
-
)
|
460 |
-
|
461 |
-
# ====== model specific ======
|
462 |
-
validator.add_argument(
|
463 |
-
"k_test",
|
464 |
-
type=int,
|
465 |
-
help="Number of top k most similar samples from ITC/VTC selection to be tested.",
|
466 |
-
)
|
467 |
-
|
468 |
-
return validator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/common/dist_utils.py
DELETED
@@ -1,137 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import datetime
|
9 |
-
import functools
|
10 |
-
import os
|
11 |
-
|
12 |
-
import torch
|
13 |
-
import torch.distributed as dist
|
14 |
-
import timm.models.hub as timm_hub
|
15 |
-
|
16 |
-
|
17 |
-
def setup_for_distributed(is_master):
|
18 |
-
"""
|
19 |
-
This function disables printing when not in master process
|
20 |
-
"""
|
21 |
-
import builtins as __builtin__
|
22 |
-
|
23 |
-
builtin_print = __builtin__.print
|
24 |
-
|
25 |
-
def print(*args, **kwargs):
|
26 |
-
force = kwargs.pop("force", False)
|
27 |
-
if is_master or force:
|
28 |
-
builtin_print(*args, **kwargs)
|
29 |
-
|
30 |
-
__builtin__.print = print
|
31 |
-
|
32 |
-
|
33 |
-
def is_dist_avail_and_initialized():
|
34 |
-
if not dist.is_available():
|
35 |
-
return False
|
36 |
-
if not dist.is_initialized():
|
37 |
-
return False
|
38 |
-
return True
|
39 |
-
|
40 |
-
|
41 |
-
def get_world_size():
|
42 |
-
if not is_dist_avail_and_initialized():
|
43 |
-
return 1
|
44 |
-
return dist.get_world_size()
|
45 |
-
|
46 |
-
|
47 |
-
def get_rank():
|
48 |
-
if not is_dist_avail_and_initialized():
|
49 |
-
return 0
|
50 |
-
return dist.get_rank()
|
51 |
-
|
52 |
-
|
53 |
-
def is_main_process():
|
54 |
-
return get_rank() == 0
|
55 |
-
|
56 |
-
|
57 |
-
def init_distributed_mode(args):
|
58 |
-
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
59 |
-
args.rank = int(os.environ["RANK"])
|
60 |
-
args.world_size = int(os.environ["WORLD_SIZE"])
|
61 |
-
args.gpu = int(os.environ["LOCAL_RANK"])
|
62 |
-
elif "SLURM_PROCID" in os.environ:
|
63 |
-
args.rank = int(os.environ["SLURM_PROCID"])
|
64 |
-
args.gpu = args.rank % torch.cuda.device_count()
|
65 |
-
else:
|
66 |
-
print("Not using distributed mode")
|
67 |
-
args.distributed = False
|
68 |
-
return
|
69 |
-
|
70 |
-
args.distributed = True
|
71 |
-
|
72 |
-
torch.cuda.set_device(args.gpu)
|
73 |
-
args.dist_backend = "nccl"
|
74 |
-
print(
|
75 |
-
"| distributed init (rank {}, world {}): {}".format(
|
76 |
-
args.rank, args.world_size, args.dist_url
|
77 |
-
),
|
78 |
-
flush=True,
|
79 |
-
)
|
80 |
-
torch.distributed.init_process_group(
|
81 |
-
backend=args.dist_backend,
|
82 |
-
init_method=args.dist_url,
|
83 |
-
world_size=args.world_size,
|
84 |
-
rank=args.rank,
|
85 |
-
timeout=datetime.timedelta(
|
86 |
-
days=365
|
87 |
-
), # allow auto-downloading and de-compressing
|
88 |
-
)
|
89 |
-
torch.distributed.barrier()
|
90 |
-
setup_for_distributed(args.rank == 0)
|
91 |
-
|
92 |
-
|
93 |
-
def get_dist_info():
|
94 |
-
if torch.__version__ < "1.0":
|
95 |
-
initialized = dist._initialized
|
96 |
-
else:
|
97 |
-
initialized = dist.is_initialized()
|
98 |
-
if initialized:
|
99 |
-
rank = dist.get_rank()
|
100 |
-
world_size = dist.get_world_size()
|
101 |
-
else: # non-distributed training
|
102 |
-
rank = 0
|
103 |
-
world_size = 1
|
104 |
-
return rank, world_size
|
105 |
-
|
106 |
-
|
107 |
-
def main_process(func):
|
108 |
-
@functools.wraps(func)
|
109 |
-
def wrapper(*args, **kwargs):
|
110 |
-
rank, _ = get_dist_info()
|
111 |
-
if rank == 0:
|
112 |
-
return func(*args, **kwargs)
|
113 |
-
|
114 |
-
return wrapper
|
115 |
-
|
116 |
-
|
117 |
-
def download_cached_file(url, check_hash=True, progress=False):
|
118 |
-
"""
|
119 |
-
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
|
120 |
-
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
|
121 |
-
"""
|
122 |
-
|
123 |
-
def get_cached_file_path():
|
124 |
-
# a hack to sync the file path across processes
|
125 |
-
parts = torch.hub.urlparse(url)
|
126 |
-
filename = os.path.basename(parts.path)
|
127 |
-
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
|
128 |
-
|
129 |
-
return cached_file
|
130 |
-
|
131 |
-
if is_main_process():
|
132 |
-
timm_hub.download_cached_file(url, check_hash, progress)
|
133 |
-
|
134 |
-
if is_dist_avail_and_initialized():
|
135 |
-
dist.barrier()
|
136 |
-
|
137 |
-
return get_cached_file_path()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/common/gradcam.py
DELETED
@@ -1,24 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
from matplotlib import pyplot as plt
|
3 |
-
from scipy.ndimage import filters
|
4 |
-
from skimage import transform as skimage_transform
|
5 |
-
|
6 |
-
|
7 |
-
def getAttMap(img, attMap, blur=True, overlap=True):
|
8 |
-
attMap -= attMap.min()
|
9 |
-
if attMap.max() > 0:
|
10 |
-
attMap /= attMap.max()
|
11 |
-
attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
|
12 |
-
if blur:
|
13 |
-
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
|
14 |
-
attMap -= attMap.min()
|
15 |
-
attMap /= attMap.max()
|
16 |
-
cmap = plt.get_cmap("jet")
|
17 |
-
attMapV = cmap(attMap)
|
18 |
-
attMapV = np.delete(attMapV, 3, 2)
|
19 |
-
if overlap:
|
20 |
-
attMap = (
|
21 |
-
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
|
22 |
-
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
|
23 |
-
)
|
24 |
-
return attMap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/common/logger.py
DELETED
@@ -1,195 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import datetime
|
9 |
-
import logging
|
10 |
-
import time
|
11 |
-
from collections import defaultdict, deque
|
12 |
-
|
13 |
-
import torch
|
14 |
-
import torch.distributed as dist
|
15 |
-
|
16 |
-
from sonique.Video_LLaMA.video_llama.common import dist_utils
|
17 |
-
|
18 |
-
|
19 |
-
class SmoothedValue(object):
|
20 |
-
"""Track a series of values and provide access to smoothed values over a
|
21 |
-
window or the global series average.
|
22 |
-
"""
|
23 |
-
|
24 |
-
def __init__(self, window_size=20, fmt=None):
|
25 |
-
if fmt is None:
|
26 |
-
fmt = "{median:.4f} ({global_avg:.4f})"
|
27 |
-
self.deque = deque(maxlen=window_size)
|
28 |
-
self.total = 0.0
|
29 |
-
self.count = 0
|
30 |
-
self.fmt = fmt
|
31 |
-
|
32 |
-
def update(self, value, n=1):
|
33 |
-
self.deque.append(value)
|
34 |
-
self.count += n
|
35 |
-
self.total += value * n
|
36 |
-
|
37 |
-
def synchronize_between_processes(self):
|
38 |
-
"""
|
39 |
-
Warning: does not synchronize the deque!
|
40 |
-
"""
|
41 |
-
if not dist_utils.is_dist_avail_and_initialized():
|
42 |
-
return
|
43 |
-
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
44 |
-
dist.barrier()
|
45 |
-
dist.all_reduce(t)
|
46 |
-
t = t.tolist()
|
47 |
-
self.count = int(t[0])
|
48 |
-
self.total = t[1]
|
49 |
-
|
50 |
-
@property
|
51 |
-
def median(self):
|
52 |
-
d = torch.tensor(list(self.deque))
|
53 |
-
return d.median().item()
|
54 |
-
|
55 |
-
@property
|
56 |
-
def avg(self):
|
57 |
-
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
58 |
-
return d.mean().item()
|
59 |
-
|
60 |
-
@property
|
61 |
-
def global_avg(self):
|
62 |
-
return self.total / self.count
|
63 |
-
|
64 |
-
@property
|
65 |
-
def max(self):
|
66 |
-
return max(self.deque)
|
67 |
-
|
68 |
-
@property
|
69 |
-
def value(self):
|
70 |
-
return self.deque[-1]
|
71 |
-
|
72 |
-
def __str__(self):
|
73 |
-
return self.fmt.format(
|
74 |
-
median=self.median,
|
75 |
-
avg=self.avg,
|
76 |
-
global_avg=self.global_avg,
|
77 |
-
max=self.max,
|
78 |
-
value=self.value,
|
79 |
-
)
|
80 |
-
|
81 |
-
|
82 |
-
class MetricLogger(object):
|
83 |
-
def __init__(self, delimiter="\t"):
|
84 |
-
self.meters = defaultdict(SmoothedValue)
|
85 |
-
self.delimiter = delimiter
|
86 |
-
|
87 |
-
def update(self, **kwargs):
|
88 |
-
for k, v in kwargs.items():
|
89 |
-
if isinstance(v, torch.Tensor):
|
90 |
-
v = v.item()
|
91 |
-
assert isinstance(v, (float, int))
|
92 |
-
self.meters[k].update(v)
|
93 |
-
|
94 |
-
def __getattr__(self, attr):
|
95 |
-
if attr in self.meters:
|
96 |
-
return self.meters[attr]
|
97 |
-
if attr in self.__dict__:
|
98 |
-
return self.__dict__[attr]
|
99 |
-
raise AttributeError(
|
100 |
-
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
|
101 |
-
)
|
102 |
-
|
103 |
-
def __str__(self):
|
104 |
-
loss_str = []
|
105 |
-
for name, meter in self.meters.items():
|
106 |
-
loss_str.append("{}: {}".format(name, str(meter)))
|
107 |
-
return self.delimiter.join(loss_str)
|
108 |
-
|
109 |
-
def global_avg(self):
|
110 |
-
loss_str = []
|
111 |
-
for name, meter in self.meters.items():
|
112 |
-
loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
|
113 |
-
return self.delimiter.join(loss_str)
|
114 |
-
|
115 |
-
def synchronize_between_processes(self):
|
116 |
-
for meter in self.meters.values():
|
117 |
-
meter.synchronize_between_processes()
|
118 |
-
|
119 |
-
def add_meter(self, name, meter):
|
120 |
-
self.meters[name] = meter
|
121 |
-
|
122 |
-
def log_every(self, iterable, print_freq, header=None):
|
123 |
-
i = 0
|
124 |
-
if not header:
|
125 |
-
header = ""
|
126 |
-
start_time = time.time()
|
127 |
-
end = time.time()
|
128 |
-
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
129 |
-
data_time = SmoothedValue(fmt="{avg:.4f}")
|
130 |
-
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
131 |
-
log_msg = [
|
132 |
-
header,
|
133 |
-
"[{0" + space_fmt + "}/{1}]",
|
134 |
-
"eta: {eta}",
|
135 |
-
"{meters}",
|
136 |
-
"time: {time}",
|
137 |
-
"data: {data}",
|
138 |
-
]
|
139 |
-
if torch.cuda.is_available():
|
140 |
-
log_msg.append("max mem: {memory:.0f}")
|
141 |
-
log_msg = self.delimiter.join(log_msg)
|
142 |
-
MB = 1024.0 * 1024.0
|
143 |
-
for obj in iterable:
|
144 |
-
data_time.update(time.time() - end)
|
145 |
-
yield obj
|
146 |
-
iter_time.update(time.time() - end)
|
147 |
-
if i % print_freq == 0 or i == len(iterable) - 1:
|
148 |
-
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
149 |
-
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
150 |
-
if torch.cuda.is_available():
|
151 |
-
print(
|
152 |
-
log_msg.format(
|
153 |
-
i,
|
154 |
-
len(iterable),
|
155 |
-
eta=eta_string,
|
156 |
-
meters=str(self),
|
157 |
-
time=str(iter_time),
|
158 |
-
data=str(data_time),
|
159 |
-
memory=torch.cuda.max_memory_allocated() / MB,
|
160 |
-
)
|
161 |
-
)
|
162 |
-
else:
|
163 |
-
print(
|
164 |
-
log_msg.format(
|
165 |
-
i,
|
166 |
-
len(iterable),
|
167 |
-
eta=eta_string,
|
168 |
-
meters=str(self),
|
169 |
-
time=str(iter_time),
|
170 |
-
data=str(data_time),
|
171 |
-
)
|
172 |
-
)
|
173 |
-
i += 1
|
174 |
-
end = time.time()
|
175 |
-
total_time = time.time() - start_time
|
176 |
-
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
177 |
-
print(
|
178 |
-
"{} Total time: {} ({:.4f} s / it)".format(
|
179 |
-
header, total_time_str, total_time / len(iterable)
|
180 |
-
)
|
181 |
-
)
|
182 |
-
|
183 |
-
|
184 |
-
class AttrDict(dict):
|
185 |
-
def __init__(self, *args, **kwargs):
|
186 |
-
super(AttrDict, self).__init__(*args, **kwargs)
|
187 |
-
self.__dict__ = self
|
188 |
-
|
189 |
-
|
190 |
-
def setup_logger():
|
191 |
-
logging.basicConfig(
|
192 |
-
level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
|
193 |
-
format="%(asctime)s [%(levelname)s] %(message)s",
|
194 |
-
handlers=[logging.StreamHandler()],
|
195 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/common/optims.py
DELETED
@@ -1,119 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import math
|
9 |
-
|
10 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
11 |
-
|
12 |
-
|
13 |
-
@registry.register_lr_scheduler("linear_warmup_step_lr")
|
14 |
-
class LinearWarmupStepLRScheduler:
|
15 |
-
def __init__(
|
16 |
-
self,
|
17 |
-
optimizer,
|
18 |
-
max_epoch,
|
19 |
-
min_lr,
|
20 |
-
init_lr,
|
21 |
-
decay_rate=1,
|
22 |
-
warmup_start_lr=-1,
|
23 |
-
warmup_steps=0,
|
24 |
-
**kwargs
|
25 |
-
):
|
26 |
-
self.optimizer = optimizer
|
27 |
-
|
28 |
-
self.max_epoch = max_epoch
|
29 |
-
self.min_lr = min_lr
|
30 |
-
|
31 |
-
self.decay_rate = decay_rate
|
32 |
-
|
33 |
-
self.init_lr = init_lr
|
34 |
-
self.warmup_steps = warmup_steps
|
35 |
-
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
36 |
-
|
37 |
-
def step(self, cur_epoch, cur_step):
|
38 |
-
if cur_epoch == 0:
|
39 |
-
warmup_lr_schedule(
|
40 |
-
step=cur_step,
|
41 |
-
optimizer=self.optimizer,
|
42 |
-
max_step=self.warmup_steps,
|
43 |
-
init_lr=self.warmup_start_lr,
|
44 |
-
max_lr=self.init_lr,
|
45 |
-
)
|
46 |
-
else:
|
47 |
-
step_lr_schedule(
|
48 |
-
epoch=cur_epoch,
|
49 |
-
optimizer=self.optimizer,
|
50 |
-
init_lr=self.init_lr,
|
51 |
-
min_lr=self.min_lr,
|
52 |
-
decay_rate=self.decay_rate,
|
53 |
-
)
|
54 |
-
|
55 |
-
|
56 |
-
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
|
57 |
-
class LinearWarmupCosineLRScheduler:
|
58 |
-
def __init__(
|
59 |
-
self,
|
60 |
-
optimizer,
|
61 |
-
max_epoch,
|
62 |
-
iters_per_epoch,
|
63 |
-
min_lr,
|
64 |
-
init_lr,
|
65 |
-
warmup_steps=0,
|
66 |
-
warmup_start_lr=-1,
|
67 |
-
**kwargs
|
68 |
-
):
|
69 |
-
self.optimizer = optimizer
|
70 |
-
|
71 |
-
self.max_epoch = max_epoch
|
72 |
-
self.iters_per_epoch = iters_per_epoch
|
73 |
-
self.min_lr = min_lr
|
74 |
-
|
75 |
-
self.init_lr = init_lr
|
76 |
-
self.warmup_steps = warmup_steps
|
77 |
-
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
78 |
-
|
79 |
-
def step(self, cur_epoch, cur_step):
|
80 |
-
total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
|
81 |
-
if total_cur_step < self.warmup_steps:
|
82 |
-
warmup_lr_schedule(
|
83 |
-
step=cur_step,
|
84 |
-
optimizer=self.optimizer,
|
85 |
-
max_step=self.warmup_steps,
|
86 |
-
init_lr=self.warmup_start_lr,
|
87 |
-
max_lr=self.init_lr,
|
88 |
-
)
|
89 |
-
else:
|
90 |
-
cosine_lr_schedule(
|
91 |
-
epoch=total_cur_step,
|
92 |
-
optimizer=self.optimizer,
|
93 |
-
max_epoch=self.max_epoch * self.iters_per_epoch,
|
94 |
-
init_lr=self.init_lr,
|
95 |
-
min_lr=self.min_lr,
|
96 |
-
)
|
97 |
-
|
98 |
-
|
99 |
-
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
|
100 |
-
"""Decay the learning rate"""
|
101 |
-
lr = (init_lr - min_lr) * 0.5 * (
|
102 |
-
1.0 + math.cos(math.pi * epoch / max_epoch)
|
103 |
-
) + min_lr
|
104 |
-
for param_group in optimizer.param_groups:
|
105 |
-
param_group["lr"] = lr
|
106 |
-
|
107 |
-
|
108 |
-
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
|
109 |
-
"""Warmup the learning rate"""
|
110 |
-
lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
|
111 |
-
for param_group in optimizer.param_groups:
|
112 |
-
param_group["lr"] = lr
|
113 |
-
|
114 |
-
|
115 |
-
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
|
116 |
-
"""Decay the learning rate"""
|
117 |
-
lr = max(min_lr, init_lr * (decay_rate**epoch))
|
118 |
-
for param_group in optimizer.param_groups:
|
119 |
-
param_group["lr"] = lr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/common/registry.py
DELETED
@@ -1,329 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
|
9 |
-
class Registry:
|
10 |
-
mapping = {
|
11 |
-
"builder_name_mapping": {},
|
12 |
-
"task_name_mapping": {},
|
13 |
-
"processor_name_mapping": {},
|
14 |
-
"model_name_mapping": {},
|
15 |
-
"lr_scheduler_name_mapping": {},
|
16 |
-
"runner_name_mapping": {},
|
17 |
-
"state": {},
|
18 |
-
"paths": {},
|
19 |
-
}
|
20 |
-
|
21 |
-
@classmethod
|
22 |
-
def register_builder(cls, name):
|
23 |
-
r"""Register a dataset builder to registry with key 'name'
|
24 |
-
|
25 |
-
Args:
|
26 |
-
name: Key with which the builder will be registered.
|
27 |
-
|
28 |
-
Usage:
|
29 |
-
|
30 |
-
from video_llama.common.registry import registry
|
31 |
-
from video_llama.datasets.base_dataset_builder import BaseDatasetBuilder
|
32 |
-
"""
|
33 |
-
|
34 |
-
def wrap(builder_cls):
|
35 |
-
from sonique.Video_LLaMA.video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
|
36 |
-
|
37 |
-
assert issubclass(
|
38 |
-
builder_cls, BaseDatasetBuilder
|
39 |
-
), "All builders must inherit BaseDatasetBuilder class, found {}".format(
|
40 |
-
builder_cls
|
41 |
-
)
|
42 |
-
if name in cls.mapping["builder_name_mapping"]:
|
43 |
-
raise KeyError(
|
44 |
-
"Name '{}' already registered for {}.".format(
|
45 |
-
name, cls.mapping["builder_name_mapping"][name]
|
46 |
-
)
|
47 |
-
)
|
48 |
-
cls.mapping["builder_name_mapping"][name] = builder_cls
|
49 |
-
return builder_cls
|
50 |
-
|
51 |
-
return wrap
|
52 |
-
|
53 |
-
@classmethod
|
54 |
-
def register_task(cls, name):
|
55 |
-
r"""Register a task to registry with key 'name'
|
56 |
-
|
57 |
-
Args:
|
58 |
-
name: Key with which the task will be registered.
|
59 |
-
|
60 |
-
Usage:
|
61 |
-
|
62 |
-
from video_llama.common.registry import registry
|
63 |
-
"""
|
64 |
-
|
65 |
-
def wrap(task_cls):
|
66 |
-
from sonique.Video_LLaMA.video_llama.tasks.base_task import BaseTask
|
67 |
-
|
68 |
-
assert issubclass(
|
69 |
-
task_cls, BaseTask
|
70 |
-
), "All tasks must inherit BaseTask class"
|
71 |
-
if name in cls.mapping["task_name_mapping"]:
|
72 |
-
raise KeyError(
|
73 |
-
"Name '{}' already registered for {}.".format(
|
74 |
-
name, cls.mapping["task_name_mapping"][name]
|
75 |
-
)
|
76 |
-
)
|
77 |
-
cls.mapping["task_name_mapping"][name] = task_cls
|
78 |
-
return task_cls
|
79 |
-
|
80 |
-
return wrap
|
81 |
-
|
82 |
-
@classmethod
|
83 |
-
def register_model(cls, name):
|
84 |
-
r"""Register a task to registry with key 'name'
|
85 |
-
|
86 |
-
Args:
|
87 |
-
name: Key with which the task will be registered.
|
88 |
-
|
89 |
-
Usage:
|
90 |
-
|
91 |
-
from video_llama.common.registry import registry
|
92 |
-
"""
|
93 |
-
|
94 |
-
def wrap(model_cls):
|
95 |
-
from sonique.Video_LLaMA.video_llama.models import BaseModel
|
96 |
-
|
97 |
-
assert issubclass(
|
98 |
-
model_cls, BaseModel
|
99 |
-
), "All models must inherit BaseModel class"
|
100 |
-
if name in cls.mapping["model_name_mapping"]:
|
101 |
-
raise KeyError(
|
102 |
-
"Name '{}' already registered for {}.".format(
|
103 |
-
name, cls.mapping["model_name_mapping"][name]
|
104 |
-
)
|
105 |
-
)
|
106 |
-
cls.mapping["model_name_mapping"][name] = model_cls
|
107 |
-
return model_cls
|
108 |
-
|
109 |
-
return wrap
|
110 |
-
|
111 |
-
@classmethod
|
112 |
-
def register_processor(cls, name):
|
113 |
-
r"""Register a processor to registry with key 'name'
|
114 |
-
|
115 |
-
Args:
|
116 |
-
name: Key with which the task will be registered.
|
117 |
-
|
118 |
-
Usage:
|
119 |
-
|
120 |
-
from video_llama.common.registry import registry
|
121 |
-
"""
|
122 |
-
|
123 |
-
def wrap(processor_cls):
|
124 |
-
from sonique.Video_LLaMA.video_llama.processors import BaseProcessor
|
125 |
-
|
126 |
-
assert issubclass(
|
127 |
-
processor_cls, BaseProcessor
|
128 |
-
), "All processors must inherit BaseProcessor class"
|
129 |
-
if name in cls.mapping["processor_name_mapping"]:
|
130 |
-
raise KeyError(
|
131 |
-
"Name '{}' already registered for {}.".format(
|
132 |
-
name, cls.mapping["processor_name_mapping"][name]
|
133 |
-
)
|
134 |
-
)
|
135 |
-
cls.mapping["processor_name_mapping"][name] = processor_cls
|
136 |
-
return processor_cls
|
137 |
-
|
138 |
-
return wrap
|
139 |
-
|
140 |
-
@classmethod
|
141 |
-
def register_lr_scheduler(cls, name):
|
142 |
-
r"""Register a model to registry with key 'name'
|
143 |
-
|
144 |
-
Args:
|
145 |
-
name: Key with which the task will be registered.
|
146 |
-
|
147 |
-
Usage:
|
148 |
-
|
149 |
-
from video_llama.common.registry import registry
|
150 |
-
"""
|
151 |
-
|
152 |
-
def wrap(lr_sched_cls):
|
153 |
-
if name in cls.mapping["lr_scheduler_name_mapping"]:
|
154 |
-
raise KeyError(
|
155 |
-
"Name '{}' already registered for {}.".format(
|
156 |
-
name, cls.mapping["lr_scheduler_name_mapping"][name]
|
157 |
-
)
|
158 |
-
)
|
159 |
-
cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
|
160 |
-
return lr_sched_cls
|
161 |
-
|
162 |
-
return wrap
|
163 |
-
|
164 |
-
@classmethod
|
165 |
-
def register_runner(cls, name):
|
166 |
-
r"""Register a model to registry with key 'name'
|
167 |
-
|
168 |
-
Args:
|
169 |
-
name: Key with which the task will be registered.
|
170 |
-
|
171 |
-
Usage:
|
172 |
-
|
173 |
-
from video_llama.common.registry import registry
|
174 |
-
"""
|
175 |
-
|
176 |
-
def wrap(runner_cls):
|
177 |
-
if name in cls.mapping["runner_name_mapping"]:
|
178 |
-
raise KeyError(
|
179 |
-
"Name '{}' already registered for {}.".format(
|
180 |
-
name, cls.mapping["runner_name_mapping"][name]
|
181 |
-
)
|
182 |
-
)
|
183 |
-
cls.mapping["runner_name_mapping"][name] = runner_cls
|
184 |
-
return runner_cls
|
185 |
-
|
186 |
-
return wrap
|
187 |
-
|
188 |
-
@classmethod
|
189 |
-
def register_path(cls, name, path):
|
190 |
-
r"""Register a path to registry with key 'name'
|
191 |
-
|
192 |
-
Args:
|
193 |
-
name: Key with which the path will be registered.
|
194 |
-
|
195 |
-
Usage:
|
196 |
-
|
197 |
-
from video_llama.common.registry import registry
|
198 |
-
"""
|
199 |
-
assert isinstance(path, str), "All path must be str."
|
200 |
-
if name in cls.mapping["paths"]:
|
201 |
-
raise KeyError("Name '{}' already registered.".format(name))
|
202 |
-
cls.mapping["paths"][name] = path
|
203 |
-
|
204 |
-
@classmethod
|
205 |
-
def register(cls, name, obj):
|
206 |
-
r"""Register an item to registry with key 'name'
|
207 |
-
|
208 |
-
Args:
|
209 |
-
name: Key with which the item will be registered.
|
210 |
-
|
211 |
-
Usage::
|
212 |
-
|
213 |
-
from video_llama.common.registry import registry
|
214 |
-
|
215 |
-
registry.register("config", {})
|
216 |
-
"""
|
217 |
-
path = name.split(".")
|
218 |
-
current = cls.mapping["state"]
|
219 |
-
|
220 |
-
for part in path[:-1]:
|
221 |
-
if part not in current:
|
222 |
-
current[part] = {}
|
223 |
-
current = current[part]
|
224 |
-
|
225 |
-
current[path[-1]] = obj
|
226 |
-
|
227 |
-
# @classmethod
|
228 |
-
# def get_trainer_class(cls, name):
|
229 |
-
# return cls.mapping["trainer_name_mapping"].get(name, None)
|
230 |
-
|
231 |
-
@classmethod
|
232 |
-
def get_builder_class(cls, name):
|
233 |
-
return cls.mapping["builder_name_mapping"].get(name, None)
|
234 |
-
|
235 |
-
@classmethod
|
236 |
-
def get_model_class(cls, name):
|
237 |
-
return cls.mapping["model_name_mapping"].get(name, None)
|
238 |
-
|
239 |
-
@classmethod
|
240 |
-
def get_task_class(cls, name):
|
241 |
-
return cls.mapping["task_name_mapping"].get(name, None)
|
242 |
-
|
243 |
-
@classmethod
|
244 |
-
def get_processor_class(cls, name):
|
245 |
-
return cls.mapping["processor_name_mapping"].get(name, None)
|
246 |
-
|
247 |
-
@classmethod
|
248 |
-
def get_lr_scheduler_class(cls, name):
|
249 |
-
return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
|
250 |
-
|
251 |
-
@classmethod
|
252 |
-
def get_runner_class(cls, name):
|
253 |
-
return cls.mapping["runner_name_mapping"].get(name, None)
|
254 |
-
|
255 |
-
@classmethod
|
256 |
-
def list_runners(cls):
|
257 |
-
return sorted(cls.mapping["runner_name_mapping"].keys())
|
258 |
-
|
259 |
-
@classmethod
|
260 |
-
def list_models(cls):
|
261 |
-
return sorted(cls.mapping["model_name_mapping"].keys())
|
262 |
-
|
263 |
-
@classmethod
|
264 |
-
def list_tasks(cls):
|
265 |
-
return sorted(cls.mapping["task_name_mapping"].keys())
|
266 |
-
|
267 |
-
@classmethod
|
268 |
-
def list_processors(cls):
|
269 |
-
return sorted(cls.mapping["processor_name_mapping"].keys())
|
270 |
-
|
271 |
-
@classmethod
|
272 |
-
def list_lr_schedulers(cls):
|
273 |
-
return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
|
274 |
-
|
275 |
-
@classmethod
|
276 |
-
def list_datasets(cls):
|
277 |
-
return sorted(cls.mapping["builder_name_mapping"].keys())
|
278 |
-
|
279 |
-
@classmethod
|
280 |
-
def get_path(cls, name):
|
281 |
-
return cls.mapping["paths"].get(name, None)
|
282 |
-
|
283 |
-
@classmethod
|
284 |
-
def get(cls, name, default=None, no_warning=False):
|
285 |
-
r"""Get an item from registry with key 'name'
|
286 |
-
|
287 |
-
Args:
|
288 |
-
name (string): Key whose value needs to be retrieved.
|
289 |
-
default: If passed and key is not in registry, default value will
|
290 |
-
be returned with a warning. Default: None
|
291 |
-
no_warning (bool): If passed as True, warning when key doesn't exist
|
292 |
-
will not be generated. Useful for MMF's
|
293 |
-
internal operations. Default: False
|
294 |
-
"""
|
295 |
-
original_name = name
|
296 |
-
name = name.split(".")
|
297 |
-
value = cls.mapping["state"]
|
298 |
-
for subname in name:
|
299 |
-
value = value.get(subname, default)
|
300 |
-
if value is default:
|
301 |
-
break
|
302 |
-
|
303 |
-
if (
|
304 |
-
"writer" in cls.mapping["state"]
|
305 |
-
and value == default
|
306 |
-
and no_warning is False
|
307 |
-
):
|
308 |
-
cls.mapping["state"]["writer"].warning(
|
309 |
-
"Key {} is not present in registry, returning default value "
|
310 |
-
"of {}".format(original_name, default)
|
311 |
-
)
|
312 |
-
return value
|
313 |
-
|
314 |
-
@classmethod
|
315 |
-
def unregister(cls, name):
|
316 |
-
r"""Remove an item from registry with key 'name'
|
317 |
-
|
318 |
-
Args:
|
319 |
-
name: Key which needs to be removed.
|
320 |
-
Usage::
|
321 |
-
|
322 |
-
from mmf.common.registry import registry
|
323 |
-
|
324 |
-
config = registry.unregister("config")
|
325 |
-
"""
|
326 |
-
return cls.mapping["state"].pop(name, None)
|
327 |
-
|
328 |
-
|
329 |
-
registry = Registry()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/common/utils.py
DELETED
@@ -1,424 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import io
|
9 |
-
import json
|
10 |
-
import logging
|
11 |
-
import os
|
12 |
-
import pickle
|
13 |
-
import re
|
14 |
-
import shutil
|
15 |
-
import urllib
|
16 |
-
import urllib.error
|
17 |
-
import urllib.request
|
18 |
-
from typing import Optional
|
19 |
-
from urllib.parse import urlparse
|
20 |
-
|
21 |
-
import numpy as np
|
22 |
-
import pandas as pd
|
23 |
-
import yaml
|
24 |
-
from iopath.common.download import download
|
25 |
-
from iopath.common.file_io import file_lock, g_pathmgr
|
26 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
27 |
-
from torch.utils.model_zoo import tqdm
|
28 |
-
from torchvision.datasets.utils import (
|
29 |
-
check_integrity,
|
30 |
-
download_file_from_google_drive,
|
31 |
-
extract_archive,
|
32 |
-
)
|
33 |
-
|
34 |
-
|
35 |
-
def now():
|
36 |
-
from datetime import datetime
|
37 |
-
|
38 |
-
return datetime.now().strftime("%Y%m%d%H%M")[:-1]
|
39 |
-
|
40 |
-
|
41 |
-
def is_url(url_or_filename):
|
42 |
-
parsed = urlparse(url_or_filename)
|
43 |
-
return parsed.scheme in ("http", "https")
|
44 |
-
|
45 |
-
|
46 |
-
def get_cache_path(rel_path):
|
47 |
-
return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
|
48 |
-
|
49 |
-
|
50 |
-
def get_abs_path(rel_path):
|
51 |
-
return os.path.join(registry.get_path("library_root"), rel_path)
|
52 |
-
|
53 |
-
|
54 |
-
def load_json(filename):
|
55 |
-
with open(filename, "r") as f:
|
56 |
-
return json.load(f)
|
57 |
-
|
58 |
-
|
59 |
-
# The following are adapted from torchvision and vissl
|
60 |
-
# torchvision: https://github.com/pytorch/vision
|
61 |
-
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
|
62 |
-
|
63 |
-
|
64 |
-
def makedir(dir_path):
|
65 |
-
"""
|
66 |
-
Create the directory if it does not exist.
|
67 |
-
"""
|
68 |
-
is_success = False
|
69 |
-
try:
|
70 |
-
if not g_pathmgr.exists(dir_path):
|
71 |
-
g_pathmgr.mkdirs(dir_path)
|
72 |
-
is_success = True
|
73 |
-
except BaseException:
|
74 |
-
print(f"Error creating directory: {dir_path}")
|
75 |
-
return is_success
|
76 |
-
|
77 |
-
|
78 |
-
def get_redirected_url(url: str):
|
79 |
-
"""
|
80 |
-
Given a URL, returns the URL it redirects to or the
|
81 |
-
original URL in case of no indirection
|
82 |
-
"""
|
83 |
-
import requests
|
84 |
-
|
85 |
-
with requests.Session() as session:
|
86 |
-
with session.get(url, stream=True, allow_redirects=True) as response:
|
87 |
-
if response.history:
|
88 |
-
return response.url
|
89 |
-
else:
|
90 |
-
return url
|
91 |
-
|
92 |
-
|
93 |
-
def to_google_drive_download_url(view_url: str) -> str:
|
94 |
-
"""
|
95 |
-
Utility function to transform a view URL of google drive
|
96 |
-
to a download URL for google drive
|
97 |
-
Example input:
|
98 |
-
https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
|
99 |
-
Example output:
|
100 |
-
https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
|
101 |
-
"""
|
102 |
-
splits = view_url.split("/")
|
103 |
-
assert splits[-1] == "view"
|
104 |
-
file_id = splits[-2]
|
105 |
-
return f"https://drive.google.com/uc?export=download&id={file_id}"
|
106 |
-
|
107 |
-
|
108 |
-
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
|
109 |
-
"""
|
110 |
-
Download a file from google drive
|
111 |
-
Downloading an URL from google drive requires confirmation when
|
112 |
-
the file of the size is too big (google drive notifies that
|
113 |
-
anti-viral checks cannot be performed on such files)
|
114 |
-
"""
|
115 |
-
import requests
|
116 |
-
|
117 |
-
with requests.Session() as session:
|
118 |
-
|
119 |
-
# First get the confirmation token and append it to the URL
|
120 |
-
with session.get(url, stream=True, allow_redirects=True) as response:
|
121 |
-
for k, v in response.cookies.items():
|
122 |
-
if k.startswith("download_warning"):
|
123 |
-
url = url + "&confirm=" + v
|
124 |
-
|
125 |
-
# Then download the content of the file
|
126 |
-
with session.get(url, stream=True, verify=True) as response:
|
127 |
-
makedir(output_path)
|
128 |
-
path = os.path.join(output_path, output_file_name)
|
129 |
-
total_size = int(response.headers.get("Content-length", 0))
|
130 |
-
with open(path, "wb") as file:
|
131 |
-
from tqdm import tqdm
|
132 |
-
|
133 |
-
with tqdm(total=total_size) as progress_bar:
|
134 |
-
for block in response.iter_content(
|
135 |
-
chunk_size=io.DEFAULT_BUFFER_SIZE
|
136 |
-
):
|
137 |
-
file.write(block)
|
138 |
-
progress_bar.update(len(block))
|
139 |
-
|
140 |
-
|
141 |
-
def _get_google_drive_file_id(url: str) -> Optional[str]:
|
142 |
-
parts = urlparse(url)
|
143 |
-
|
144 |
-
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
|
145 |
-
return None
|
146 |
-
|
147 |
-
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
|
148 |
-
if match is None:
|
149 |
-
return None
|
150 |
-
|
151 |
-
return match.group("id")
|
152 |
-
|
153 |
-
|
154 |
-
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
|
155 |
-
with open(filename, "wb") as fh:
|
156 |
-
with urllib.request.urlopen(
|
157 |
-
urllib.request.Request(url, headers={"User-Agent": "vissl"})
|
158 |
-
) as response:
|
159 |
-
with tqdm(total=response.length) as pbar:
|
160 |
-
for chunk in iter(lambda: response.read(chunk_size), ""):
|
161 |
-
if not chunk:
|
162 |
-
break
|
163 |
-
pbar.update(chunk_size)
|
164 |
-
fh.write(chunk)
|
165 |
-
|
166 |
-
|
167 |
-
def download_url(
|
168 |
-
url: str,
|
169 |
-
root: str,
|
170 |
-
filename: Optional[str] = None,
|
171 |
-
md5: Optional[str] = None,
|
172 |
-
) -> None:
|
173 |
-
"""Download a file from a url and place it in root.
|
174 |
-
Args:
|
175 |
-
url (str): URL to download file from
|
176 |
-
root (str): Directory to place downloaded file in
|
177 |
-
filename (str, optional): Name to save the file under.
|
178 |
-
If None, use the basename of the URL.
|
179 |
-
md5 (str, optional): MD5 checksum of the download. If None, do not check
|
180 |
-
"""
|
181 |
-
root = os.path.expanduser(root)
|
182 |
-
if not filename:
|
183 |
-
filename = os.path.basename(url)
|
184 |
-
fpath = os.path.join(root, filename)
|
185 |
-
|
186 |
-
makedir(root)
|
187 |
-
|
188 |
-
# check if file is already present locally
|
189 |
-
if check_integrity(fpath, md5):
|
190 |
-
print("Using downloaded and verified file: " + fpath)
|
191 |
-
return
|
192 |
-
|
193 |
-
# expand redirect chain if needed
|
194 |
-
url = get_redirected_url(url)
|
195 |
-
|
196 |
-
# check if file is located on Google Drive
|
197 |
-
file_id = _get_google_drive_file_id(url)
|
198 |
-
if file_id is not None:
|
199 |
-
return download_file_from_google_drive(file_id, root, filename, md5)
|
200 |
-
|
201 |
-
# download the file
|
202 |
-
try:
|
203 |
-
print("Downloading " + url + " to " + fpath)
|
204 |
-
_urlretrieve(url, fpath)
|
205 |
-
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
|
206 |
-
if url[:5] == "https":
|
207 |
-
url = url.replace("https:", "http:")
|
208 |
-
print(
|
209 |
-
"Failed download. Trying https -> http instead."
|
210 |
-
" Downloading " + url + " to " + fpath
|
211 |
-
)
|
212 |
-
_urlretrieve(url, fpath)
|
213 |
-
else:
|
214 |
-
raise e
|
215 |
-
|
216 |
-
# check integrity of downloaded file
|
217 |
-
if not check_integrity(fpath, md5):
|
218 |
-
raise RuntimeError("File not found or corrupted.")
|
219 |
-
|
220 |
-
|
221 |
-
def download_and_extract_archive(
|
222 |
-
url: str,
|
223 |
-
download_root: str,
|
224 |
-
extract_root: Optional[str] = None,
|
225 |
-
filename: Optional[str] = None,
|
226 |
-
md5: Optional[str] = None,
|
227 |
-
remove_finished: bool = False,
|
228 |
-
) -> None:
|
229 |
-
download_root = os.path.expanduser(download_root)
|
230 |
-
if extract_root is None:
|
231 |
-
extract_root = download_root
|
232 |
-
if not filename:
|
233 |
-
filename = os.path.basename(url)
|
234 |
-
|
235 |
-
download_url(url, download_root, filename, md5)
|
236 |
-
|
237 |
-
archive = os.path.join(download_root, filename)
|
238 |
-
print("Extracting {} to {}".format(archive, extract_root))
|
239 |
-
extract_archive(archive, extract_root, remove_finished)
|
240 |
-
|
241 |
-
|
242 |
-
def cache_url(url: str, cache_dir: str) -> str:
|
243 |
-
"""
|
244 |
-
This implementation downloads the remote resource and caches it locally.
|
245 |
-
The resource will only be downloaded if not previously requested.
|
246 |
-
"""
|
247 |
-
parsed_url = urlparse(url)
|
248 |
-
dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
|
249 |
-
makedir(dirname)
|
250 |
-
filename = url.split("/")[-1]
|
251 |
-
cached = os.path.join(dirname, filename)
|
252 |
-
with file_lock(cached):
|
253 |
-
if not os.path.isfile(cached):
|
254 |
-
logging.info(f"Downloading {url} to {cached} ...")
|
255 |
-
cached = download(url, dirname, filename=filename)
|
256 |
-
logging.info(f"URL {url} cached in {cached}")
|
257 |
-
return cached
|
258 |
-
|
259 |
-
|
260 |
-
# TODO (prigoyal): convert this into RAII-style API
|
261 |
-
def create_file_symlink(file1, file2):
|
262 |
-
"""
|
263 |
-
Simply create the symlinks for a given file1 to file2.
|
264 |
-
Useful during model checkpointing to symlinks to the
|
265 |
-
latest successful checkpoint.
|
266 |
-
"""
|
267 |
-
try:
|
268 |
-
if g_pathmgr.exists(file2):
|
269 |
-
g_pathmgr.rm(file2)
|
270 |
-
g_pathmgr.symlink(file1, file2)
|
271 |
-
except Exception as e:
|
272 |
-
logging.info(f"Could NOT create symlink. Error: {e}")
|
273 |
-
|
274 |
-
|
275 |
-
def save_file(data, filename, append_to_json=True, verbose=True):
|
276 |
-
"""
|
277 |
-
Common i/o utility to handle saving data to various file formats.
|
278 |
-
Supported:
|
279 |
-
.pkl, .pickle, .npy, .json
|
280 |
-
Specifically for .json, users have the option to either append (default)
|
281 |
-
or rewrite by passing in Boolean value to append_to_json.
|
282 |
-
"""
|
283 |
-
if verbose:
|
284 |
-
logging.info(f"Saving data to file: {filename}")
|
285 |
-
file_ext = os.path.splitext(filename)[1]
|
286 |
-
if file_ext in [".pkl", ".pickle"]:
|
287 |
-
with g_pathmgr.open(filename, "wb") as fopen:
|
288 |
-
pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
|
289 |
-
elif file_ext == ".npy":
|
290 |
-
with g_pathmgr.open(filename, "wb") as fopen:
|
291 |
-
np.save(fopen, data)
|
292 |
-
elif file_ext == ".json":
|
293 |
-
if append_to_json:
|
294 |
-
with g_pathmgr.open(filename, "a") as fopen:
|
295 |
-
fopen.write(json.dumps(data, sort_keys=True) + "\n")
|
296 |
-
fopen.flush()
|
297 |
-
else:
|
298 |
-
with g_pathmgr.open(filename, "w") as fopen:
|
299 |
-
fopen.write(json.dumps(data, sort_keys=True) + "\n")
|
300 |
-
fopen.flush()
|
301 |
-
elif file_ext == ".yaml":
|
302 |
-
with g_pathmgr.open(filename, "w") as fopen:
|
303 |
-
dump = yaml.dump(data)
|
304 |
-
fopen.write(dump)
|
305 |
-
fopen.flush()
|
306 |
-
else:
|
307 |
-
raise Exception(f"Saving {file_ext} is not supported yet")
|
308 |
-
|
309 |
-
if verbose:
|
310 |
-
logging.info(f"Saved data to file: {filename}")
|
311 |
-
|
312 |
-
|
313 |
-
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
|
314 |
-
"""
|
315 |
-
Common i/o utility to handle loading data from various file formats.
|
316 |
-
Supported:
|
317 |
-
.pkl, .pickle, .npy, .json
|
318 |
-
For the npy files, we support reading the files in mmap_mode.
|
319 |
-
If the mmap_mode of reading is not successful, we load data without the
|
320 |
-
mmap_mode.
|
321 |
-
"""
|
322 |
-
if verbose:
|
323 |
-
logging.info(f"Loading data from file: {filename}")
|
324 |
-
|
325 |
-
file_ext = os.path.splitext(filename)[1]
|
326 |
-
if file_ext == ".txt":
|
327 |
-
with g_pathmgr.open(filename, "r") as fopen:
|
328 |
-
data = fopen.readlines()
|
329 |
-
elif file_ext in [".pkl", ".pickle"]:
|
330 |
-
with g_pathmgr.open(filename, "rb") as fopen:
|
331 |
-
data = pickle.load(fopen, encoding="latin1")
|
332 |
-
elif file_ext == ".npy":
|
333 |
-
if mmap_mode:
|
334 |
-
try:
|
335 |
-
with g_pathmgr.open(filename, "rb") as fopen:
|
336 |
-
data = np.load(
|
337 |
-
fopen,
|
338 |
-
allow_pickle=allow_pickle,
|
339 |
-
encoding="latin1",
|
340 |
-
mmap_mode=mmap_mode,
|
341 |
-
)
|
342 |
-
except ValueError as e:
|
343 |
-
logging.info(
|
344 |
-
f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
|
345 |
-
)
|
346 |
-
data = np.load(
|
347 |
-
filename,
|
348 |
-
allow_pickle=allow_pickle,
|
349 |
-
encoding="latin1",
|
350 |
-
mmap_mode=mmap_mode,
|
351 |
-
)
|
352 |
-
logging.info("Successfully loaded without g_pathmgr")
|
353 |
-
except Exception:
|
354 |
-
logging.info("Could not mmap without g_pathmgr. Trying without mmap")
|
355 |
-
with g_pathmgr.open(filename, "rb") as fopen:
|
356 |
-
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
|
357 |
-
else:
|
358 |
-
with g_pathmgr.open(filename, "rb") as fopen:
|
359 |
-
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
|
360 |
-
elif file_ext == ".json":
|
361 |
-
with g_pathmgr.open(filename, "r") as fopen:
|
362 |
-
data = json.load(fopen)
|
363 |
-
elif file_ext == ".yaml":
|
364 |
-
with g_pathmgr.open(filename, "r") as fopen:
|
365 |
-
data = yaml.load(fopen, Loader=yaml.FullLoader)
|
366 |
-
elif file_ext == ".csv":
|
367 |
-
with g_pathmgr.open(filename, "r") as fopen:
|
368 |
-
data = pd.read_csv(fopen)
|
369 |
-
else:
|
370 |
-
raise Exception(f"Reading from {file_ext} is not supported yet")
|
371 |
-
return data
|
372 |
-
|
373 |
-
|
374 |
-
def abspath(resource_path: str):
|
375 |
-
"""
|
376 |
-
Make a path absolute, but take into account prefixes like
|
377 |
-
"http://" or "manifold://"
|
378 |
-
"""
|
379 |
-
regex = re.compile(r"^\w+://")
|
380 |
-
if regex.match(resource_path) is None:
|
381 |
-
return os.path.abspath(resource_path)
|
382 |
-
else:
|
383 |
-
return resource_path
|
384 |
-
|
385 |
-
|
386 |
-
def makedir(dir_path):
|
387 |
-
"""
|
388 |
-
Create the directory if it does not exist.
|
389 |
-
"""
|
390 |
-
is_success = False
|
391 |
-
try:
|
392 |
-
if not g_pathmgr.exists(dir_path):
|
393 |
-
g_pathmgr.mkdirs(dir_path)
|
394 |
-
is_success = True
|
395 |
-
except BaseException:
|
396 |
-
logging.info(f"Error creating directory: {dir_path}")
|
397 |
-
return is_success
|
398 |
-
|
399 |
-
|
400 |
-
def is_url(input_url):
|
401 |
-
"""
|
402 |
-
Check if an input string is a url. look for http(s):// and ignoring the case
|
403 |
-
"""
|
404 |
-
is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
|
405 |
-
return is_url
|
406 |
-
|
407 |
-
|
408 |
-
def cleanup_dir(dir):
|
409 |
-
"""
|
410 |
-
Utility for deleting a directory. Useful for cleaning the storage space
|
411 |
-
that contains various training artifacts like checkpoints, data etc.
|
412 |
-
"""
|
413 |
-
if os.path.exists(dir):
|
414 |
-
logging.info(f"Deleting directory: {dir}")
|
415 |
-
shutil.rmtree(dir)
|
416 |
-
logging.info(f"Deleted contents of directory: {dir}")
|
417 |
-
|
418 |
-
|
419 |
-
def get_file_size(filename):
|
420 |
-
"""
|
421 |
-
Given a file, get the size of file in MB
|
422 |
-
"""
|
423 |
-
size_in_mb = os.path.getsize(filename) / float(1024**2)
|
424 |
-
return size_in_mb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/configs/datasets/cc_sbu/align.yaml
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
datasets:
|
2 |
-
cc_sbu_align:
|
3 |
-
data_type: images
|
4 |
-
build_info:
|
5 |
-
storage: /path/to/cc_sbu_align_dataset
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/configs/datasets/cc_sbu/defaults.yaml
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
datasets:
|
2 |
-
cc_sbu:
|
3 |
-
data_type: images
|
4 |
-
build_info:
|
5 |
-
storage: /path/to/cc_sbu_dataset/{00000..00001}.tar
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/configs/datasets/instruct/llava_instruct.yaml
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
datasets:
|
2 |
-
llava_instruct:
|
3 |
-
data_type: image
|
4 |
-
build_info:
|
5 |
-
anno_dir: /path/llava_instruct_150k.json
|
6 |
-
videos_dir: /path/train2014/train2014/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/configs/datasets/instruct/webvid_instruct.yaml
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
datasets:
|
2 |
-
webvid_instruct:
|
3 |
-
data_type: image
|
4 |
-
build_info:
|
5 |
-
anno_dir: /path/webvid_align/videochat_instruct_11k.json
|
6 |
-
videos_dir: /path/webvid_align/videos/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/configs/datasets/laion/defaults.yaml
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
datasets:
|
2 |
-
laion:
|
3 |
-
data_type: images
|
4 |
-
build_info:
|
5 |
-
storage: path/laion/laion_dataset/{00000..00001}.tar
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/configs/datasets/webvid/defaults.yaml
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
datasets:
|
2 |
-
webvid:
|
3 |
-
data_type: video
|
4 |
-
build_info:
|
5 |
-
anno_dir: path/webvid/webvid_tain_data/annotations/
|
6 |
-
videos_dir: path//webvid/webvid_tain_data/videos/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/configs/default.yaml
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
env:
|
2 |
-
# For default users
|
3 |
-
# cache_root: "cache"
|
4 |
-
# For internal use with persistent storage
|
5 |
-
cache_root: "/export/home/.cache/minigpt4"
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/configs/models/minigpt4.yaml
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
arch: mini_gpt4
|
3 |
-
|
4 |
-
# vit encoder
|
5 |
-
image_size: 224
|
6 |
-
drop_path_rate: 0
|
7 |
-
use_grad_checkpoint: False
|
8 |
-
vit_precision: "fp16"
|
9 |
-
freeze_vit: True
|
10 |
-
freeze_qformer: True
|
11 |
-
|
12 |
-
# Q-Former
|
13 |
-
num_query_token: 32
|
14 |
-
|
15 |
-
# Vicuna
|
16 |
-
llama_model: "ckpt/vicuna-13b/"
|
17 |
-
|
18 |
-
# generation configs
|
19 |
-
prompt: ""
|
20 |
-
|
21 |
-
preprocess:
|
22 |
-
vis_processor:
|
23 |
-
train:
|
24 |
-
name: "blip2_image_train"
|
25 |
-
image_size: 224
|
26 |
-
eval:
|
27 |
-
name: "blip2_image_eval"
|
28 |
-
image_size: 224
|
29 |
-
text_processor:
|
30 |
-
train:
|
31 |
-
name: "blip_caption"
|
32 |
-
eval:
|
33 |
-
name: "blip_caption"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/configs/models/video_llama.yaml
DELETED
@@ -1,36 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
arch: video_llama
|
3 |
-
|
4 |
-
# vit encoder
|
5 |
-
image_size: 224
|
6 |
-
drop_path_rate: 0
|
7 |
-
use_grad_checkpoint: False
|
8 |
-
vit_precision: "fp16"
|
9 |
-
freeze_vit: True
|
10 |
-
freeze_qformer: True
|
11 |
-
|
12 |
-
# Q-Former
|
13 |
-
num_query_token: 32
|
14 |
-
|
15 |
-
# Vicuna
|
16 |
-
llama_model: "ckpt/vicuna-7b/"
|
17 |
-
|
18 |
-
# generation configs
|
19 |
-
prompt: ""
|
20 |
-
|
21 |
-
preprocess:
|
22 |
-
vis_processor:
|
23 |
-
train:
|
24 |
-
name: "alpro_video_train"
|
25 |
-
image_size: 224
|
26 |
-
n_frms: 8
|
27 |
-
eval:
|
28 |
-
name: "alpro_video_eval"
|
29 |
-
image_size: 224
|
30 |
-
n_frms: 8
|
31 |
-
text_processor:
|
32 |
-
train:
|
33 |
-
name: "blip_caption"
|
34 |
-
eval:
|
35 |
-
name: "blip_caption"
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/conversation/__init__.py
DELETED
File without changes
|
sonique/Video_LLaMA/video_llama/conversation/conversation_video.py
DELETED
@@ -1,348 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Conversation prompt template of Video-LLaMA.
|
3 |
-
Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
|
4 |
-
"""
|
5 |
-
import argparse
|
6 |
-
import time
|
7 |
-
from PIL import Image
|
8 |
-
import sys
|
9 |
-
import os
|
10 |
-
import torch
|
11 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
|
12 |
-
from transformers import StoppingCriteria, StoppingCriteriaList
|
13 |
-
|
14 |
-
import dataclasses
|
15 |
-
from enum import auto, Enum
|
16 |
-
from typing import List, Tuple, Any
|
17 |
-
import os
|
18 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
19 |
-
from sonique.Video_LLaMA.video_llama.processors.video_processor import ToTHWC,ToUint8,load_video
|
20 |
-
from sonique.Video_LLaMA.video_llama.processors import Blip2ImageEvalProcessor
|
21 |
-
|
22 |
-
from sonique.Video_LLaMA.video_llama.models.ImageBind.data import load_and_transform_audio_data
|
23 |
-
class SeparatorStyle(Enum):
|
24 |
-
"""Different separator style."""
|
25 |
-
SINGLE = auto()
|
26 |
-
TWO = auto()
|
27 |
-
LLAMA_2 = auto()
|
28 |
-
|
29 |
-
|
30 |
-
@dataclasses.dataclass
|
31 |
-
class Conversation:
|
32 |
-
"""A class that keeps all conversation history."""
|
33 |
-
system: str
|
34 |
-
roles: List[str]
|
35 |
-
messages: List[List[str]]
|
36 |
-
offset: int
|
37 |
-
# system_img: List[Image.Image] = []
|
38 |
-
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
39 |
-
sep: str = "###"
|
40 |
-
sep2: str = None
|
41 |
-
|
42 |
-
skip_next: bool = False
|
43 |
-
conv_id: Any = None
|
44 |
-
|
45 |
-
def get_prompt(self):
|
46 |
-
if self.sep_style == SeparatorStyle.SINGLE:
|
47 |
-
ret = self.system + self.sep
|
48 |
-
for role, message in self.messages:
|
49 |
-
if message:
|
50 |
-
ret += role + ": " + message + self.sep
|
51 |
-
else:
|
52 |
-
ret += role + ":"
|
53 |
-
return ret
|
54 |
-
elif self.sep_style == SeparatorStyle.TWO:
|
55 |
-
seps = [self.sep, self.sep2]
|
56 |
-
ret = self.system + seps[0]
|
57 |
-
for i, (role, message) in enumerate(self.messages):
|
58 |
-
if message:
|
59 |
-
ret += role + ": " + message + seps[i % 2]
|
60 |
-
else:
|
61 |
-
ret += role + ":"
|
62 |
-
return ret
|
63 |
-
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
64 |
-
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
65 |
-
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
66 |
-
ret = ""
|
67 |
-
|
68 |
-
for i, (role, message) in enumerate(self.messages):
|
69 |
-
if i == 0:
|
70 |
-
assert message, "first message should not be none"
|
71 |
-
assert role == self.roles[0], "first message should come from user"
|
72 |
-
if message:
|
73 |
-
if type(message) is tuple:
|
74 |
-
message, _, _ = message
|
75 |
-
if i == 0: message = wrap_sys(self.system) + message
|
76 |
-
if i % 2 == 0:
|
77 |
-
message = wrap_inst(message)
|
78 |
-
ret += self.sep + message
|
79 |
-
else:
|
80 |
-
ret += " " + message + " " + self.sep2
|
81 |
-
else:
|
82 |
-
ret += ""
|
83 |
-
ret = ret.lstrip(self.sep)
|
84 |
-
return ret
|
85 |
-
else:
|
86 |
-
raise ValueError(f"Invalid style: {self.sep_style}")
|
87 |
-
|
88 |
-
def append_message(self, role, message):
|
89 |
-
self.messages.append([role, message])
|
90 |
-
|
91 |
-
def to_gradio_chatbot(self):
|
92 |
-
ret = []
|
93 |
-
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
94 |
-
if i % 2 == 0:
|
95 |
-
ret.append([msg, None])
|
96 |
-
else:
|
97 |
-
ret[-1][-1] = msg
|
98 |
-
return ret
|
99 |
-
|
100 |
-
def copy(self):
|
101 |
-
return Conversation(
|
102 |
-
system=self.system,
|
103 |
-
# system_img=self.system_img,
|
104 |
-
roles=self.roles,
|
105 |
-
messages=[[x, y] for x, y in self.messages],
|
106 |
-
offset=self.offset,
|
107 |
-
sep_style=self.sep_style,
|
108 |
-
sep=self.sep,
|
109 |
-
sep2=self.sep2,
|
110 |
-
conv_id=self.conv_id)
|
111 |
-
|
112 |
-
def dict(self):
|
113 |
-
return {
|
114 |
-
"system": self.system,
|
115 |
-
# "system_img": self.system_img,
|
116 |
-
"roles": self.roles,
|
117 |
-
"messages": self.messages,
|
118 |
-
"offset": self.offset,
|
119 |
-
"sep": self.sep,
|
120 |
-
"sep2": self.sep2,
|
121 |
-
"conv_id": self.conv_id,
|
122 |
-
}
|
123 |
-
|
124 |
-
|
125 |
-
class StoppingCriteriaSub(StoppingCriteria):
|
126 |
-
|
127 |
-
def __init__(self, stops=[], encounters=1):
|
128 |
-
super().__init__()
|
129 |
-
self.stops = stops
|
130 |
-
|
131 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
132 |
-
for stop in self.stops:
|
133 |
-
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
134 |
-
return True
|
135 |
-
|
136 |
-
return False
|
137 |
-
|
138 |
-
|
139 |
-
CONV_VISION = Conversation(
|
140 |
-
system="Give the following image: <Img>ImageContent</Img>. "
|
141 |
-
"You will be able to see the image once I provide it to you. Please answer my questions.",
|
142 |
-
roles=("Human", "Assistant"),
|
143 |
-
messages=[],
|
144 |
-
offset=0,
|
145 |
-
sep_style=SeparatorStyle.SINGLE,
|
146 |
-
sep="###",
|
147 |
-
)
|
148 |
-
|
149 |
-
default_conversation = Conversation(
|
150 |
-
system="",
|
151 |
-
roles=("Human", "Assistant"),
|
152 |
-
messages=[],
|
153 |
-
offset=0,
|
154 |
-
sep_style=SeparatorStyle.SINGLE,
|
155 |
-
sep="###",
|
156 |
-
)
|
157 |
-
conv_llava_llama_2 = Conversation(
|
158 |
-
system="You are a helpful language and vision assistant. "
|
159 |
-
"You are able to understand the visual content that the user provides, "
|
160 |
-
"and assist the user with a variety of tasks using natural language.",
|
161 |
-
roles=("USER", "ASSISTANT"),
|
162 |
-
messages=(),
|
163 |
-
offset=0,
|
164 |
-
sep_style=SeparatorStyle.LLAMA_2,
|
165 |
-
sep="<s>",
|
166 |
-
sep2="</s>",
|
167 |
-
)
|
168 |
-
class Chat:
|
169 |
-
def __init__(self, model, vis_processor, device='cuda:0'):
|
170 |
-
self.device = device
|
171 |
-
self.model = model
|
172 |
-
self.vis_processor = vis_processor
|
173 |
-
self.image_vis_processor = Blip2ImageEvalProcessor()
|
174 |
-
# stop_words_ids = [torch.tensor([835]).to(self.device),
|
175 |
-
# torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
|
176 |
-
# self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
177 |
-
|
178 |
-
def ask(self, text, conv):
|
179 |
-
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
|
180 |
-
and ('</Video>' in conv.messages[-1][1] or '</Image>' in conv.messages[-1][1]): # last message is image.
|
181 |
-
conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
|
182 |
-
else:
|
183 |
-
conv.append_message(conv.roles[0], text)
|
184 |
-
|
185 |
-
def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
|
186 |
-
repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
|
187 |
-
conv.append_message(conv.roles[1], None)
|
188 |
-
embs = self.get_context_emb(conv, img_list)
|
189 |
-
|
190 |
-
current_max_len = embs.shape[1] + max_new_tokens
|
191 |
-
if current_max_len - max_length > 0:
|
192 |
-
print('Warning: The number of tokens in current conversation exceeds the max length. '
|
193 |
-
'The model will not see the contexts outside the range.')
|
194 |
-
begin_idx = max(0, current_max_len - max_length)
|
195 |
-
|
196 |
-
embs = embs[:, begin_idx:]
|
197 |
-
if conv.sep =="###":
|
198 |
-
stop_words_ids = [torch.tensor([835]).to(self.device),
|
199 |
-
torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
|
200 |
-
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
201 |
-
else:
|
202 |
-
stop_words_ids = [torch.tensor([2]).to(self.device)]
|
203 |
-
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
204 |
-
|
205 |
-
# stopping_criteria
|
206 |
-
outputs = self.model.llama_model.generate(
|
207 |
-
inputs_embeds=embs,
|
208 |
-
max_new_tokens=max_new_tokens,
|
209 |
-
stopping_criteria=stopping_criteria,
|
210 |
-
num_beams=num_beams,
|
211 |
-
do_sample=True,
|
212 |
-
min_length=min_length,
|
213 |
-
top_p=top_p,
|
214 |
-
repetition_penalty=repetition_penalty,
|
215 |
-
length_penalty=length_penalty,
|
216 |
-
temperature=temperature,
|
217 |
-
)
|
218 |
-
output_token = outputs[0]
|
219 |
-
if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
|
220 |
-
output_token = output_token[1:]
|
221 |
-
if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
|
222 |
-
output_token = output_token[1:]
|
223 |
-
output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
|
224 |
-
if conv.sep =="###":
|
225 |
-
output_text = output_text.split('###')[0] # remove the stop sign '###'
|
226 |
-
output_text = output_text.split('Assistant:')[-1].strip()
|
227 |
-
else:
|
228 |
-
output_text = output_text.split(conv.sep2)[0] # remove the stop sign '###'
|
229 |
-
output_text = output_text.split(conv.roles[1]+':')[-1].strip()
|
230 |
-
conv.messages[-1][1] = output_text
|
231 |
-
return output_text, output_token.cpu().numpy()
|
232 |
-
|
233 |
-
def upload_video(self, video_path, conv, img_list):
|
234 |
-
|
235 |
-
msg = ""
|
236 |
-
if isinstance(video_path, str): # is a video path
|
237 |
-
ext = os.path.splitext(video_path)[-1].lower()
|
238 |
-
print(video_path)
|
239 |
-
# image = self.vis_processor(image).unsqueeze(0).to(self.device)
|
240 |
-
video, msg = load_video(
|
241 |
-
video_path=video_path,
|
242 |
-
n_frms=8,
|
243 |
-
height=224,
|
244 |
-
width=224,
|
245 |
-
sampling ="uniform", return_msg = True
|
246 |
-
)
|
247 |
-
video = self.vis_processor.transform(video)
|
248 |
-
video = video.unsqueeze(0).to(self.device)
|
249 |
-
# print(image)
|
250 |
-
else:
|
251 |
-
raise NotImplementedError
|
252 |
-
|
253 |
-
try:
|
254 |
-
audio_flag = 1
|
255 |
-
audio = load_and_transform_audio_data([video_path],"cpu", clips_per_video=8)
|
256 |
-
audio = audio.to(self.device)
|
257 |
-
except :
|
258 |
-
print('no audio is found')
|
259 |
-
audio_flag = 0
|
260 |
-
finally:
|
261 |
-
if audio_flag == 1:
|
262 |
-
# image_emb, _ = self.model.encode_videoQformer_audiovideo(video,audio)
|
263 |
-
image_emb, _ = self.model.encode_videoQformer_visual(video)
|
264 |
-
audio_emb,_ = self.model.encode_audioQformer(audio)
|
265 |
-
img_list.append(audio_emb)
|
266 |
-
img_list.append(image_emb)
|
267 |
-
conv.system = ""
|
268 |
-
# conv.append_message(conv.roles[0], "The audio of this video is <Video><ImageHere></Video> ")
|
269 |
-
conv.append_message(conv.roles[0], "Close your eyes, open your ears and you imagine only based on the sound that: <ImageHere>. \
|
270 |
-
Close your ears, open your eyes and you see that <Video><ImageHere></Video>. \
|
271 |
-
Now answer my question based on what you have just seen and heard.")
|
272 |
-
|
273 |
-
else: # only vison no audio
|
274 |
-
# conv.system = "You can understand the video that the user provides. Follow the instructions carefully and explain your answers in detail."
|
275 |
-
image_emb, _ = self.model.encode_videoQformer_visual(video)
|
276 |
-
img_list.append(image_emb)
|
277 |
-
conv.append_message(conv.roles[0], "<Video><ImageHere></Video> "+ msg)
|
278 |
-
return "Received."
|
279 |
-
|
280 |
-
def upload_video_without_audio(self, video_path, conv, img_list):
|
281 |
-
msg = ""
|
282 |
-
if isinstance(video_path, str): # is a video path
|
283 |
-
ext = os.path.splitext(video_path)[-1].lower()
|
284 |
-
print(video_path)
|
285 |
-
# image = self.vis_processor(image).unsqueeze(0).to(self.device)
|
286 |
-
video, msg = load_video(
|
287 |
-
video_path=video_path,
|
288 |
-
n_frms=8,
|
289 |
-
height=224,
|
290 |
-
width=224,
|
291 |
-
sampling ="uniform", return_msg = True
|
292 |
-
)
|
293 |
-
video = self.vis_processor.transform(video)
|
294 |
-
video = video.unsqueeze(0).to(self.device)
|
295 |
-
# print(image)
|
296 |
-
else:
|
297 |
-
raise NotImplementedError
|
298 |
-
|
299 |
-
|
300 |
-
# conv.system = "You can understand the video that the user provides. Follow the instructions carefully and explain your answers in detail."
|
301 |
-
image_emb, _ = self.model.encode_videoQformer_visual(video)
|
302 |
-
img_list.append(image_emb)
|
303 |
-
conv.append_message(conv.roles[0], "<Video><ImageHere></Video> "+ msg)
|
304 |
-
return "Received."
|
305 |
-
|
306 |
-
def upload_img(self, image, conv, img_list):
|
307 |
-
|
308 |
-
msg = ""
|
309 |
-
if isinstance(image, str): # is a image path
|
310 |
-
raw_image = Image.open(image).convert('RGB') # 增加一个时间维度
|
311 |
-
image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device)
|
312 |
-
elif isinstance(image, Image.Image):
|
313 |
-
raw_image = image
|
314 |
-
image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device)
|
315 |
-
elif isinstance(image, torch.Tensor):
|
316 |
-
if len(image.shape) == 3:
|
317 |
-
image = image.unsqueeze(0)
|
318 |
-
image = image.to(self.device)
|
319 |
-
else:
|
320 |
-
raise NotImplementedError
|
321 |
-
|
322 |
-
image_emb, _ = self.model.encode_videoQformer_visual(image)
|
323 |
-
img_list.append(image_emb)
|
324 |
-
# Todo msg=""
|
325 |
-
conv.append_message(conv.roles[0], "<Image><ImageHere></Image> "+ msg)
|
326 |
-
|
327 |
-
return "Received."
|
328 |
-
|
329 |
-
def get_context_emb(self, conv, img_list):
|
330 |
-
prompt = conv.get_prompt()
|
331 |
-
prompt_segs = prompt.split('<ImageHere>')
|
332 |
-
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
|
333 |
-
seg_tokens = [
|
334 |
-
self.model.llama_tokenizer(
|
335 |
-
seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
|
336 |
-
# only add bos to the first seg
|
337 |
-
for i, seg in enumerate(prompt_segs)
|
338 |
-
]
|
339 |
-
seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
|
340 |
-
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
|
341 |
-
mixed_embs = torch.cat(mixed_embs, dim=1)
|
342 |
-
return mixed_embs
|
343 |
-
|
344 |
-
if __name__ =='__main__':
|
345 |
-
video_path = '/mnt/workspace/videoGPT/Video-LLaMA/examples/applausing.mp4'
|
346 |
-
# import torch.classes.torchaudio.ffmpeg_StreamReader
|
347 |
-
# ffmpeg_StreamReader(video_path)
|
348 |
-
load_and_transform_audio_data([video_path],"cpu", clips_per_video=8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/__init__.py
DELETED
File without changes
|
sonique/Video_LLaMA/video_llama/datasets/builders/__init__.py
DELETED
@@ -1,77 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
from sonique.Video_LLaMA.video_llama.datasets.builders.base_dataset_builder import load_dataset_config
|
9 |
-
from sonique.Video_LLaMA.video_llama.datasets.builders.image_text_pair_builder import (
|
10 |
-
CCSBUBuilder,
|
11 |
-
LaionBuilder,
|
12 |
-
CCSBUAlignBuilder
|
13 |
-
)
|
14 |
-
from sonique.Video_LLaMA.video_llama.datasets.builders.video_caption_builder import WebvidBuilder
|
15 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
16 |
-
from sonique.Video_LLaMA.video_llama.datasets.builders.instruct_builder import WebvidInstruct_Builder,LlavaInstruct_Builder
|
17 |
-
__all__ = [
|
18 |
-
"CCSBUBuilder",
|
19 |
-
"LaionBuilder",
|
20 |
-
"CCSBUAlignBuilder",
|
21 |
-
"WebvidBuilder",
|
22 |
-
"LlavaInstruct_Builder",
|
23 |
-
"WebvidInstruct_Builder"
|
24 |
-
|
25 |
-
]
|
26 |
-
|
27 |
-
|
28 |
-
def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
|
29 |
-
"""
|
30 |
-
Example
|
31 |
-
|
32 |
-
>>> dataset = load_dataset("coco_caption", cfg=None)
|
33 |
-
>>> splits = dataset.keys()
|
34 |
-
>>> print([len(dataset[split]) for split in splits])
|
35 |
-
|
36 |
-
"""
|
37 |
-
if cfg_path is None:
|
38 |
-
cfg = None
|
39 |
-
else:
|
40 |
-
cfg = load_dataset_config(cfg_path)
|
41 |
-
|
42 |
-
try:
|
43 |
-
builder = registry.get_builder_class(name)(cfg)
|
44 |
-
except TypeError:
|
45 |
-
print(
|
46 |
-
f"Dataset {name} not found. Available datasets:\n"
|
47 |
-
+ ", ".join([str(k) for k in dataset_zoo.get_names()])
|
48 |
-
)
|
49 |
-
exit(1)
|
50 |
-
|
51 |
-
if vis_path is not None:
|
52 |
-
if data_type is None:
|
53 |
-
# use default data type in the config
|
54 |
-
data_type = builder.config.data_type
|
55 |
-
|
56 |
-
assert (
|
57 |
-
data_type in builder.config.build_info
|
58 |
-
), f"Invalid data_type {data_type} for {name}."
|
59 |
-
|
60 |
-
builder.config.build_info.get(data_type).storage = vis_path
|
61 |
-
|
62 |
-
dataset = builder.build_datasets()
|
63 |
-
return dataset
|
64 |
-
|
65 |
-
|
66 |
-
class DatasetZoo:
|
67 |
-
def __init__(self) -> None:
|
68 |
-
self.dataset_zoo = {
|
69 |
-
k: list(v.DATASET_CONFIG_DICT.keys())
|
70 |
-
for k, v in sorted(registry.mapping["builder_name_mapping"].items())
|
71 |
-
}
|
72 |
-
|
73 |
-
def get_names(self):
|
74 |
-
return list(self.dataset_zoo.keys())
|
75 |
-
|
76 |
-
|
77 |
-
dataset_zoo = DatasetZoo()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/builders/base_dataset_builder.py
DELETED
@@ -1,236 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
This file is from
|
3 |
-
Copyright (c) 2022, salesforce.com, inc.
|
4 |
-
All rights reserved.
|
5 |
-
SPDX-License-Identifier: BSD-3-Clause
|
6 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
7 |
-
"""
|
8 |
-
|
9 |
-
import logging
|
10 |
-
import os
|
11 |
-
import shutil
|
12 |
-
import warnings
|
13 |
-
|
14 |
-
from omegaconf import OmegaConf
|
15 |
-
import torch.distributed as dist
|
16 |
-
from torchvision.datasets.utils import download_url
|
17 |
-
|
18 |
-
import sonique.Video_LLaMA.video_llama.common.utils as utils
|
19 |
-
from sonique.Video_LLaMA.video_llama.common.dist_utils import is_dist_avail_and_initialized, is_main_process
|
20 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
21 |
-
from sonique.Video_LLaMA.video_llama.processors.base_processor import BaseProcessor
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
class BaseDatasetBuilder:
|
26 |
-
train_dataset_cls, eval_dataset_cls = None, None
|
27 |
-
|
28 |
-
def __init__(self, cfg=None):
|
29 |
-
super().__init__()
|
30 |
-
|
31 |
-
if cfg is None:
|
32 |
-
# help to create datasets from default config.
|
33 |
-
self.config = load_dataset_config(self.default_config_path())
|
34 |
-
elif isinstance(cfg, str):
|
35 |
-
self.config = load_dataset_config(cfg)
|
36 |
-
else:
|
37 |
-
# when called from task.build_dataset()
|
38 |
-
self.config = cfg
|
39 |
-
|
40 |
-
self.data_type = self.config.data_type
|
41 |
-
|
42 |
-
self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
|
43 |
-
self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
|
44 |
-
|
45 |
-
def build_datasets(self):
|
46 |
-
# download, split, etc...
|
47 |
-
# only called on 1 GPU/TPU in distributed
|
48 |
-
|
49 |
-
if is_main_process():
|
50 |
-
self._download_data()
|
51 |
-
|
52 |
-
if is_dist_avail_and_initialized():
|
53 |
-
dist.barrier()
|
54 |
-
|
55 |
-
# at this point, all the annotations and image/videos should be all downloaded to the specified locations.
|
56 |
-
logging.info("Building datasets...")
|
57 |
-
datasets = self.build() # dataset['train'/'val'/'test']
|
58 |
-
|
59 |
-
return datasets
|
60 |
-
|
61 |
-
def build_processors(self):
|
62 |
-
vis_proc_cfg = self.config.get("vis_processor")
|
63 |
-
txt_proc_cfg = self.config.get("text_processor")
|
64 |
-
|
65 |
-
if vis_proc_cfg is not None:
|
66 |
-
vis_train_cfg = vis_proc_cfg.get("train")
|
67 |
-
vis_eval_cfg = vis_proc_cfg.get("eval")
|
68 |
-
|
69 |
-
self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
|
70 |
-
self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
|
71 |
-
|
72 |
-
if txt_proc_cfg is not None:
|
73 |
-
txt_train_cfg = txt_proc_cfg.get("train")
|
74 |
-
txt_eval_cfg = txt_proc_cfg.get("eval")
|
75 |
-
|
76 |
-
self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
|
77 |
-
self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
|
78 |
-
|
79 |
-
@staticmethod
|
80 |
-
def _build_proc_from_cfg(cfg):
|
81 |
-
return (
|
82 |
-
registry.get_processor_class(cfg.name).from_config(cfg)
|
83 |
-
if cfg is not None
|
84 |
-
else None
|
85 |
-
)
|
86 |
-
|
87 |
-
@classmethod
|
88 |
-
def default_config_path(cls, type="default"):
|
89 |
-
return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
|
90 |
-
|
91 |
-
def _download_data(self):
|
92 |
-
self._download_ann()
|
93 |
-
self._download_vis()
|
94 |
-
|
95 |
-
def _download_ann(self):
|
96 |
-
"""
|
97 |
-
Download annotation files if necessary.
|
98 |
-
All the vision-language datasets should have annotations of unified format.
|
99 |
-
|
100 |
-
storage_path can be:
|
101 |
-
(1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
|
102 |
-
(2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
|
103 |
-
|
104 |
-
Local annotation paths should be relative.
|
105 |
-
"""
|
106 |
-
anns = self.config.build_info.annotations
|
107 |
-
|
108 |
-
splits = anns.keys()
|
109 |
-
|
110 |
-
cache_root = registry.get_path("cache_root")
|
111 |
-
|
112 |
-
for split in splits:
|
113 |
-
info = anns[split]
|
114 |
-
|
115 |
-
urls, storage_paths = info.get("url", None), info.storage
|
116 |
-
|
117 |
-
if isinstance(urls, str):
|
118 |
-
urls = [urls]
|
119 |
-
if isinstance(storage_paths, str):
|
120 |
-
storage_paths = [storage_paths]
|
121 |
-
|
122 |
-
assert len(urls) == len(storage_paths)
|
123 |
-
|
124 |
-
for url_or_filename, storage_path in zip(urls, storage_paths):
|
125 |
-
# if storage_path is relative, make it full by prefixing with cache_root.
|
126 |
-
if not os.path.isabs(storage_path):
|
127 |
-
storage_path = os.path.join(cache_root, storage_path)
|
128 |
-
|
129 |
-
dirname = os.path.dirname(storage_path)
|
130 |
-
if not os.path.exists(dirname):
|
131 |
-
os.makedirs(dirname)
|
132 |
-
|
133 |
-
if os.path.isfile(url_or_filename):
|
134 |
-
src, dst = url_or_filename, storage_path
|
135 |
-
if not os.path.exists(dst):
|
136 |
-
shutil.copyfile(src=src, dst=dst)
|
137 |
-
else:
|
138 |
-
logging.info("Using existing file {}.".format(dst))
|
139 |
-
else:
|
140 |
-
if os.path.isdir(storage_path):
|
141 |
-
# if only dirname is provided, suffix with basename of URL.
|
142 |
-
raise ValueError(
|
143 |
-
"Expecting storage_path to be a file path, got directory {}".format(
|
144 |
-
storage_path
|
145 |
-
)
|
146 |
-
)
|
147 |
-
else:
|
148 |
-
filename = os.path.basename(storage_path)
|
149 |
-
|
150 |
-
download_url(url=url_or_filename, root=dirname, filename=filename)
|
151 |
-
|
152 |
-
def _download_vis(self):
|
153 |
-
|
154 |
-
storage_path = self.config.build_info.get(self.data_type).storage
|
155 |
-
storage_path = utils.get_cache_path(storage_path)
|
156 |
-
|
157 |
-
if not os.path.exists(storage_path):
|
158 |
-
warnings.warn(
|
159 |
-
f"""
|
160 |
-
The specified path {storage_path} for visual inputs does not exist.
|
161 |
-
Please provide a correct path to the visual inputs or
|
162 |
-
refer to datasets/download_scripts/README.md for downloading instructions.
|
163 |
-
"""
|
164 |
-
)
|
165 |
-
|
166 |
-
def build(self):
|
167 |
-
"""
|
168 |
-
Create by split datasets inheriting torch.utils.data.Datasets.
|
169 |
-
|
170 |
-
# build() can be dataset-specific. Overwrite to customize.
|
171 |
-
"""
|
172 |
-
self.build_processors()
|
173 |
-
|
174 |
-
build_info = self.config.build_info
|
175 |
-
|
176 |
-
ann_info = build_info.annotations
|
177 |
-
vis_info = build_info.get(self.data_type)
|
178 |
-
|
179 |
-
datasets = dict()
|
180 |
-
for split in ann_info.keys():
|
181 |
-
if split not in ["train", "val", "test"]:
|
182 |
-
continue
|
183 |
-
|
184 |
-
is_train = split == "train"
|
185 |
-
|
186 |
-
# processors
|
187 |
-
vis_processor = (
|
188 |
-
self.vis_processors["train"]
|
189 |
-
if is_train
|
190 |
-
else self.vis_processors["eval"]
|
191 |
-
)
|
192 |
-
text_processor = (
|
193 |
-
self.text_processors["train"]
|
194 |
-
if is_train
|
195 |
-
else self.text_processors["eval"]
|
196 |
-
)
|
197 |
-
|
198 |
-
# annotation path
|
199 |
-
ann_paths = ann_info.get(split).storage
|
200 |
-
if isinstance(ann_paths, str):
|
201 |
-
ann_paths = [ann_paths]
|
202 |
-
|
203 |
-
abs_ann_paths = []
|
204 |
-
for ann_path in ann_paths:
|
205 |
-
if not os.path.isabs(ann_path):
|
206 |
-
ann_path = utils.get_cache_path(ann_path)
|
207 |
-
abs_ann_paths.append(ann_path)
|
208 |
-
ann_paths = abs_ann_paths
|
209 |
-
|
210 |
-
# visual data storage path
|
211 |
-
vis_path = os.path.join(vis_info.storage, split)
|
212 |
-
|
213 |
-
if not os.path.isabs(vis_path):
|
214 |
-
# vis_path = os.path.join(utils.get_cache_path(), vis_path)
|
215 |
-
vis_path = utils.get_cache_path(vis_path)
|
216 |
-
|
217 |
-
if not os.path.exists(vis_path):
|
218 |
-
warnings.warn("storage path {} does not exist.".format(vis_path))
|
219 |
-
|
220 |
-
# create datasets
|
221 |
-
dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
|
222 |
-
datasets[split] = dataset_cls(
|
223 |
-
vis_processor=vis_processor,
|
224 |
-
text_processor=text_processor,
|
225 |
-
ann_paths=ann_paths,
|
226 |
-
vis_root=vis_path,
|
227 |
-
)
|
228 |
-
|
229 |
-
return datasets
|
230 |
-
|
231 |
-
|
232 |
-
def load_dataset_config(cfg_path):
|
233 |
-
cfg = OmegaConf.load(cfg_path).datasets
|
234 |
-
cfg = cfg[list(cfg.keys())[0]]
|
235 |
-
|
236 |
-
return cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/builders/image_text_pair_builder.py
DELETED
@@ -1,106 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import logging
|
3 |
-
import warnings
|
4 |
-
|
5 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
6 |
-
from sonique.Video_LLaMA.video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
|
7 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.laion_dataset import LaionDataset
|
8 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
|
9 |
-
|
10 |
-
|
11 |
-
@registry.register_builder("cc_sbu")
|
12 |
-
class CCSBUBuilder(BaseDatasetBuilder):
|
13 |
-
train_dataset_cls = CCSBUDataset
|
14 |
-
|
15 |
-
DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
|
16 |
-
|
17 |
-
def _download_ann(self):
|
18 |
-
pass
|
19 |
-
|
20 |
-
def _download_vis(self):
|
21 |
-
pass
|
22 |
-
|
23 |
-
def build(self):
|
24 |
-
self.build_processors()
|
25 |
-
|
26 |
-
build_info = self.config.build_info
|
27 |
-
|
28 |
-
datasets = dict()
|
29 |
-
split = "train"
|
30 |
-
|
31 |
-
# create datasets
|
32 |
-
# [NOTE] return inner_datasets (wds.DataPipeline)
|
33 |
-
dataset_cls = self.train_dataset_cls
|
34 |
-
datasets[split] = dataset_cls(
|
35 |
-
vis_processor=self.vis_processors[split],
|
36 |
-
text_processor=self.text_processors[split],
|
37 |
-
location=build_info.storage,
|
38 |
-
).inner_dataset
|
39 |
-
|
40 |
-
return datasets
|
41 |
-
|
42 |
-
|
43 |
-
@registry.register_builder("laion")
|
44 |
-
class LaionBuilder(BaseDatasetBuilder):
|
45 |
-
train_dataset_cls = LaionDataset
|
46 |
-
|
47 |
-
DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
|
48 |
-
|
49 |
-
def _download_ann(self):
|
50 |
-
pass
|
51 |
-
|
52 |
-
def _download_vis(self):
|
53 |
-
pass
|
54 |
-
|
55 |
-
def build(self):
|
56 |
-
self.build_processors()
|
57 |
-
|
58 |
-
build_info = self.config.build_info
|
59 |
-
|
60 |
-
datasets = dict()
|
61 |
-
split = "train"
|
62 |
-
|
63 |
-
# create datasets
|
64 |
-
# [NOTE] return inner_datasets (wds.DataPipeline)
|
65 |
-
dataset_cls = self.train_dataset_cls
|
66 |
-
datasets[split] = dataset_cls(
|
67 |
-
vis_processor=self.vis_processors[split],
|
68 |
-
text_processor=self.text_processors[split],
|
69 |
-
location=build_info.storage,
|
70 |
-
).inner_dataset
|
71 |
-
|
72 |
-
return datasets
|
73 |
-
|
74 |
-
|
75 |
-
@registry.register_builder("cc_sbu_align")
|
76 |
-
class CCSBUAlignBuilder(BaseDatasetBuilder):
|
77 |
-
train_dataset_cls = CCSBUAlignDataset
|
78 |
-
|
79 |
-
DATASET_CONFIG_DICT = {
|
80 |
-
"default": "configs/datasets/cc_sbu/align.yaml",
|
81 |
-
}
|
82 |
-
|
83 |
-
def build_datasets(self):
|
84 |
-
# at this point, all the annotations and image/videos should be all downloaded to the specified locations.
|
85 |
-
logging.info("Building datasets...")
|
86 |
-
self.build_processors()
|
87 |
-
|
88 |
-
build_info = self.config.build_info
|
89 |
-
storage_path = build_info.storage
|
90 |
-
|
91 |
-
datasets = dict()
|
92 |
-
|
93 |
-
if not os.path.exists(storage_path):
|
94 |
-
warnings.warn("storage path {} does not exist.".format(storage_path))
|
95 |
-
|
96 |
-
# create datasets
|
97 |
-
dataset_cls = self.train_dataset_cls
|
98 |
-
datasets['train'] = dataset_cls(
|
99 |
-
vis_processor=self.vis_processors["train"],
|
100 |
-
text_processor=self.text_processors["train"],
|
101 |
-
ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
|
102 |
-
vis_root=os.path.join(storage_path, 'image'),
|
103 |
-
)
|
104 |
-
|
105 |
-
return datasets
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/builders/instruct_builder.py
DELETED
@@ -1,79 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import logging
|
3 |
-
import warnings
|
4 |
-
|
5 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
6 |
-
from sonique.Video_LLaMA.video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
|
7 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.laion_dataset import LaionDataset
|
8 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.llava_instruct_dataset import Instruct_Dataset
|
9 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset
|
10 |
-
|
11 |
-
@registry.register_builder("instruct")
|
12 |
-
class Instruct_Builder(BaseDatasetBuilder):
|
13 |
-
train_dataset_cls = Instruct_Dataset
|
14 |
-
|
15 |
-
DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"}
|
16 |
-
|
17 |
-
def _download_ann(self):
|
18 |
-
pass
|
19 |
-
|
20 |
-
def _download_vis(self):
|
21 |
-
pass
|
22 |
-
|
23 |
-
def build(self):
|
24 |
-
self.build_processors()
|
25 |
-
datasets = dict()
|
26 |
-
split = "train"
|
27 |
-
|
28 |
-
build_info = self.config.build_info
|
29 |
-
dataset_cls = self.train_dataset_cls
|
30 |
-
if self.config.num_video_query_token:
|
31 |
-
num_video_query_token = self.config.num_video_query_token
|
32 |
-
else:
|
33 |
-
num_video_query_token = 32
|
34 |
-
|
35 |
-
if self.config.tokenizer_name:
|
36 |
-
tokenizer_name = self.config.tokenizer_name
|
37 |
-
else:
|
38 |
-
tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/'
|
39 |
-
|
40 |
-
|
41 |
-
datasets[split] = dataset_cls(
|
42 |
-
vis_processor=self.vis_processors[split],
|
43 |
-
text_processor=self.text_processors[split],
|
44 |
-
vis_root=build_info.videos_dir,
|
45 |
-
ann_root=build_info.anno_dir,
|
46 |
-
num_video_query_token = num_video_query_token,
|
47 |
-
tokenizer_name = tokenizer_name,
|
48 |
-
data_type = self.config.data_type,
|
49 |
-
model_type = self.config.model_type
|
50 |
-
)
|
51 |
-
|
52 |
-
return datasets
|
53 |
-
|
54 |
-
@registry.register_builder("webvid_instruct")
|
55 |
-
class WebvidInstruct_Builder(Instruct_Builder):
|
56 |
-
train_dataset_cls = Video_Instruct_Dataset
|
57 |
-
|
58 |
-
DATASET_CONFIG_DICT = {
|
59 |
-
"default": "configs/datasets/instruct/webvid_instruct.yaml",
|
60 |
-
}
|
61 |
-
|
62 |
-
@registry.register_builder("webvid_instruct_zh")
|
63 |
-
class WebvidInstruct_zh_Builder(Instruct_Builder):
|
64 |
-
train_dataset_cls = Video_Instruct_Dataset
|
65 |
-
|
66 |
-
DATASET_CONFIG_DICT = {
|
67 |
-
"default": "configs/datasets/instruct/webvid_instruct.yaml",
|
68 |
-
}
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
@registry.register_builder("llava_instruct")
|
73 |
-
class LlavaInstruct_Builder(Instruct_Builder):
|
74 |
-
train_dataset_cls = Instruct_Dataset
|
75 |
-
|
76 |
-
DATASET_CONFIG_DICT = {
|
77 |
-
"default": "configs/datasets/instruct/llava_instruct.yaml",
|
78 |
-
}
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/builders/video_caption_builder.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import logging
|
3 |
-
import warnings
|
4 |
-
|
5 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
6 |
-
from sonique.Video_LLaMA.video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
|
7 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.webvid_datasets import WebvidDataset
|
8 |
-
|
9 |
-
@registry.register_builder("webvid")
|
10 |
-
class WebvidBuilder(BaseDatasetBuilder):
|
11 |
-
train_dataset_cls = WebvidDataset
|
12 |
-
DATASET_CONFIG_DICT = {"default": "configs/datasets/webvid/defaults.yaml"}
|
13 |
-
|
14 |
-
def _download_ann(self):
|
15 |
-
pass
|
16 |
-
|
17 |
-
def _download_vis(self):
|
18 |
-
pass
|
19 |
-
|
20 |
-
def build(self):
|
21 |
-
self.build_processors()
|
22 |
-
datasets = dict()
|
23 |
-
split = "train"
|
24 |
-
|
25 |
-
build_info = self.config.build_info
|
26 |
-
dataset_cls = self.train_dataset_cls
|
27 |
-
datasets[split] = dataset_cls(
|
28 |
-
vis_processor=self.vis_processors[split],
|
29 |
-
text_processor=self.text_processors[split],
|
30 |
-
vis_root=build_info.videos_dir,
|
31 |
-
ann_root=build_info.anno_dir
|
32 |
-
)
|
33 |
-
|
34 |
-
return datasets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/data_utils.py
DELETED
@@ -1,196 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import gzip
|
9 |
-
import logging
|
10 |
-
import os
|
11 |
-
import random as rnd
|
12 |
-
import tarfile
|
13 |
-
import zipfile
|
14 |
-
import random
|
15 |
-
from typing import List
|
16 |
-
from tqdm import tqdm
|
17 |
-
|
18 |
-
import decord
|
19 |
-
from decord import VideoReader
|
20 |
-
import webdataset as wds
|
21 |
-
import numpy as np
|
22 |
-
import torch
|
23 |
-
from torch.utils.data.dataset import IterableDataset
|
24 |
-
|
25 |
-
from sonique.Video_LLaMA.video_llama.common.registry import registry
|
26 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.base_dataset import ConcatDataset
|
27 |
-
|
28 |
-
|
29 |
-
decord.bridge.set_bridge("torch")
|
30 |
-
MAX_INT = registry.get("MAX_INT")
|
31 |
-
|
32 |
-
|
33 |
-
class ChainDataset(wds.DataPipeline):
|
34 |
-
r"""Dataset for chaining multiple :class:`DataPipeline` s.
|
35 |
-
|
36 |
-
This class is useful to assemble different existing dataset streams. The
|
37 |
-
chaining operation is done on-the-fly, so concatenating large-scale
|
38 |
-
datasets with this class will be efficient.
|
39 |
-
|
40 |
-
Args:
|
41 |
-
datasets (iterable of IterableDataset): datasets to be chained together
|
42 |
-
"""
|
43 |
-
def __init__(self, datasets: List[wds.DataPipeline]) -> None:
|
44 |
-
super().__init__()
|
45 |
-
self.datasets = datasets
|
46 |
-
self.prob = []
|
47 |
-
self.names = []
|
48 |
-
for dataset in self.datasets:
|
49 |
-
if hasattr(dataset, 'name'):
|
50 |
-
self.names.append(dataset.name)
|
51 |
-
else:
|
52 |
-
self.names.append('Unknown')
|
53 |
-
if hasattr(dataset, 'sample_ratio'):
|
54 |
-
self.prob.append(dataset.sample_ratio)
|
55 |
-
else:
|
56 |
-
self.prob.append(1)
|
57 |
-
logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
|
58 |
-
|
59 |
-
def __iter__(self):
|
60 |
-
datastreams = [iter(dataset) for dataset in self.datasets]
|
61 |
-
while True:
|
62 |
-
select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
|
63 |
-
yield next(select_datastream)
|
64 |
-
|
65 |
-
|
66 |
-
def apply_to_sample(f, sample):
|
67 |
-
if len(sample) == 0:
|
68 |
-
return {}
|
69 |
-
|
70 |
-
def _apply(x):
|
71 |
-
if torch.is_tensor(x):
|
72 |
-
return f(x)
|
73 |
-
elif isinstance(x, dict):
|
74 |
-
return {key: _apply(value) for key, value in x.items()}
|
75 |
-
elif isinstance(x, list):
|
76 |
-
return [_apply(x) for x in x]
|
77 |
-
else:
|
78 |
-
return x
|
79 |
-
|
80 |
-
return _apply(sample)
|
81 |
-
|
82 |
-
|
83 |
-
def move_to_cuda(sample):
|
84 |
-
def _move_to_cuda(tensor):
|
85 |
-
return tensor.cuda()
|
86 |
-
|
87 |
-
return apply_to_sample(_move_to_cuda, sample)
|
88 |
-
|
89 |
-
|
90 |
-
def prepare_sample(samples, cuda_enabled=True):
|
91 |
-
if cuda_enabled:
|
92 |
-
samples = move_to_cuda(samples)
|
93 |
-
|
94 |
-
# TODO fp16 support
|
95 |
-
|
96 |
-
return samples
|
97 |
-
|
98 |
-
|
99 |
-
def reorg_datasets_by_split(datasets):
|
100 |
-
"""
|
101 |
-
Organizes datasets by split.
|
102 |
-
|
103 |
-
Args:
|
104 |
-
datasets: dict of torch.utils.data.Dataset objects by name.
|
105 |
-
|
106 |
-
Returns:
|
107 |
-
Dict of datasets by split {split_name: List[Datasets]}.
|
108 |
-
"""
|
109 |
-
# if len(datasets) == 1:
|
110 |
-
# return datasets[list(datasets.keys())[0]]
|
111 |
-
# else:
|
112 |
-
reorg_datasets = dict()
|
113 |
-
|
114 |
-
# reorganize by split
|
115 |
-
for _, dataset in datasets.items():
|
116 |
-
for split_name, dataset_split in dataset.items():
|
117 |
-
if split_name not in reorg_datasets:
|
118 |
-
reorg_datasets[split_name] = [dataset_split]
|
119 |
-
else:
|
120 |
-
reorg_datasets[split_name].append(dataset_split)
|
121 |
-
|
122 |
-
return reorg_datasets
|
123 |
-
|
124 |
-
|
125 |
-
def concat_datasets(datasets):
|
126 |
-
"""
|
127 |
-
Concatenates multiple datasets into a single dataset.
|
128 |
-
|
129 |
-
It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
|
130 |
-
generic IterableDataset because it requires creating separate samplers.
|
131 |
-
|
132 |
-
Now only supports conctenating training datasets and assuming validation and testing
|
133 |
-
have only a single dataset. This is because metrics should not be computed on the concatenated
|
134 |
-
datasets.
|
135 |
-
|
136 |
-
Args:
|
137 |
-
datasets: dict of torch.utils.data.Dataset objects by split.
|
138 |
-
|
139 |
-
Returns:
|
140 |
-
Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
|
141 |
-
"val" and "test" remain the same.
|
142 |
-
|
143 |
-
If the input training datasets contain both map-style and DataPipeline datasets, returns
|
144 |
-
a tuple, where the first element is a concatenated map-style dataset and the second
|
145 |
-
element is a chained DataPipeline dataset.
|
146 |
-
|
147 |
-
"""
|
148 |
-
# concatenate datasets in the same split
|
149 |
-
for split_name in datasets:
|
150 |
-
if split_name != "train":
|
151 |
-
assert (
|
152 |
-
len(datasets[split_name]) == 1
|
153 |
-
), "Do not support multiple {} datasets.".format(split_name)
|
154 |
-
datasets[split_name] = datasets[split_name][0]
|
155 |
-
else:
|
156 |
-
iterable_datasets, map_datasets = [], []
|
157 |
-
for dataset in datasets[split_name]:
|
158 |
-
if isinstance(dataset, wds.DataPipeline):
|
159 |
-
logging.info(
|
160 |
-
"Dataset {} is IterableDataset, can't be concatenated.".format(
|
161 |
-
dataset
|
162 |
-
)
|
163 |
-
)
|
164 |
-
iterable_datasets.append(dataset)
|
165 |
-
elif isinstance(dataset, IterableDataset):
|
166 |
-
raise NotImplementedError(
|
167 |
-
"Do not support concatenation of generic IterableDataset."
|
168 |
-
)
|
169 |
-
else:
|
170 |
-
map_datasets.append(dataset)
|
171 |
-
|
172 |
-
# if len(iterable_datasets) > 0:
|
173 |
-
# concatenate map-style datasets and iterable-style datasets separately
|
174 |
-
if len(iterable_datasets) > 1:
|
175 |
-
chained_datasets = (
|
176 |
-
ChainDataset(iterable_datasets)
|
177 |
-
)
|
178 |
-
elif len(iterable_datasets) == 1:
|
179 |
-
chained_datasets = iterable_datasets[0]
|
180 |
-
else:
|
181 |
-
chained_datasets = None
|
182 |
-
|
183 |
-
concat_datasets = (
|
184 |
-
ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
|
185 |
-
)
|
186 |
-
|
187 |
-
train_datasets = concat_datasets, chained_datasets
|
188 |
-
train_datasets = tuple([x for x in train_datasets if x is not None])
|
189 |
-
train_datasets = (
|
190 |
-
train_datasets[0] if len(train_datasets) == 1 else train_datasets
|
191 |
-
)
|
192 |
-
|
193 |
-
datasets[split_name] = train_datasets
|
194 |
-
|
195 |
-
return datasets
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/datasets/__init__.py
DELETED
File without changes
|
sonique/Video_LLaMA/video_llama/datasets/datasets/base_dataset.py
DELETED
@@ -1,68 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import json
|
9 |
-
from typing import Iterable
|
10 |
-
|
11 |
-
from torch.utils.data import Dataset, ConcatDataset
|
12 |
-
from torch.utils.data.dataloader import default_collate
|
13 |
-
|
14 |
-
|
15 |
-
class BaseDataset(Dataset):
|
16 |
-
def __init__(
|
17 |
-
self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
|
18 |
-
):
|
19 |
-
"""
|
20 |
-
vis_root (string): Root directory of images (e.g. coco/images/)
|
21 |
-
ann_root (string): directory to store the annotation file
|
22 |
-
"""
|
23 |
-
self.vis_root = vis_root
|
24 |
-
|
25 |
-
self.annotation = []
|
26 |
-
for ann_path in ann_paths:
|
27 |
-
self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
|
28 |
-
|
29 |
-
self.vis_processor = vis_processor
|
30 |
-
self.text_processor = text_processor
|
31 |
-
|
32 |
-
self._add_instance_ids()
|
33 |
-
|
34 |
-
def __len__(self):
|
35 |
-
return len(self.annotation)
|
36 |
-
|
37 |
-
def collater(self, samples):
|
38 |
-
return default_collate(samples)
|
39 |
-
|
40 |
-
def set_processors(self, vis_processor, text_processor):
|
41 |
-
self.vis_processor = vis_processor
|
42 |
-
self.text_processor = text_processor
|
43 |
-
|
44 |
-
def _add_instance_ids(self, key="instance_id"):
|
45 |
-
for idx, ann in enumerate(self.annotation):
|
46 |
-
ann[key] = str(idx)
|
47 |
-
|
48 |
-
|
49 |
-
class ConcatDataset(ConcatDataset):
|
50 |
-
def __init__(self, datasets: Iterable[Dataset]) -> None:
|
51 |
-
super().__init__(datasets)
|
52 |
-
|
53 |
-
def collater(self, samples):
|
54 |
-
# TODO For now only supports datasets with same underlying collater implementations
|
55 |
-
|
56 |
-
all_keys = set()
|
57 |
-
for s in samples:
|
58 |
-
all_keys.update(s)
|
59 |
-
|
60 |
-
shared_keys = all_keys
|
61 |
-
for s in samples:
|
62 |
-
shared_keys = shared_keys & set(s.keys())
|
63 |
-
|
64 |
-
samples_shared_keys = []
|
65 |
-
for s in samples:
|
66 |
-
samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
|
67 |
-
|
68 |
-
return self.datasets[0].collater(samples_shared_keys)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/datasets/caption_datasets.py
DELETED
@@ -1,85 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import os
|
9 |
-
from collections import OrderedDict
|
10 |
-
|
11 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.base_dataset import BaseDataset
|
12 |
-
from PIL import Image
|
13 |
-
|
14 |
-
|
15 |
-
class __DisplMixin:
|
16 |
-
def displ_item(self, index):
|
17 |
-
sample, ann = self.__getitem__(index), self.annotation[index]
|
18 |
-
|
19 |
-
return OrderedDict(
|
20 |
-
{
|
21 |
-
"file": ann["image"],
|
22 |
-
"caption": ann["caption"],
|
23 |
-
"image": sample["image"],
|
24 |
-
}
|
25 |
-
)
|
26 |
-
|
27 |
-
|
28 |
-
class CaptionDataset(BaseDataset, __DisplMixin):
|
29 |
-
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
|
30 |
-
"""
|
31 |
-
vis_root (string): Root directory of images (e.g. coco/images/)
|
32 |
-
ann_root (string): directory to store the annotation file
|
33 |
-
"""
|
34 |
-
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
|
35 |
-
|
36 |
-
self.img_ids = {}
|
37 |
-
n = 0
|
38 |
-
for ann in self.annotation:
|
39 |
-
img_id = ann["image_id"]
|
40 |
-
if img_id not in self.img_ids.keys():
|
41 |
-
self.img_ids[img_id] = n
|
42 |
-
n += 1
|
43 |
-
|
44 |
-
def __getitem__(self, index):
|
45 |
-
|
46 |
-
# TODO this assumes image input, not general enough
|
47 |
-
ann = self.annotation[index]
|
48 |
-
|
49 |
-
img_file = '{:0>12}.jpg'.format(ann["image_id"])
|
50 |
-
image_path = os.path.join(self.vis_root, img_file)
|
51 |
-
image = Image.open(image_path).convert("RGB")
|
52 |
-
|
53 |
-
image = self.vis_processor(image)
|
54 |
-
caption = self.text_processor(ann["caption"])
|
55 |
-
|
56 |
-
return {
|
57 |
-
"image": image,
|
58 |
-
"text_input": caption,
|
59 |
-
"image_id": self.img_ids[ann["image_id"]],
|
60 |
-
}
|
61 |
-
|
62 |
-
|
63 |
-
class CaptionEvalDataset(BaseDataset, __DisplMixin):
|
64 |
-
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
|
65 |
-
"""
|
66 |
-
vis_root (string): Root directory of images (e.g. coco/images/)
|
67 |
-
ann_root (string): directory to store the annotation file
|
68 |
-
split (string): val or test
|
69 |
-
"""
|
70 |
-
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
|
71 |
-
|
72 |
-
def __getitem__(self, index):
|
73 |
-
|
74 |
-
ann = self.annotation[index]
|
75 |
-
|
76 |
-
image_path = os.path.join(self.vis_root, ann["image"])
|
77 |
-
image = Image.open(image_path).convert("RGB")
|
78 |
-
|
79 |
-
image = self.vis_processor(image)
|
80 |
-
|
81 |
-
return {
|
82 |
-
"image": image,
|
83 |
-
"image_id": ann["image_id"],
|
84 |
-
"instance_id": ann["instance_id"],
|
85 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/datasets/cc_sbu_dataset.py
DELETED
@@ -1,49 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from PIL import Image
|
3 |
-
import webdataset as wds
|
4 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.base_dataset import BaseDataset
|
5 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.caption_datasets import CaptionDataset
|
6 |
-
|
7 |
-
|
8 |
-
class CCSBUDataset(BaseDataset):
|
9 |
-
def __init__(self, vis_processor, text_processor, location):
|
10 |
-
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
|
11 |
-
|
12 |
-
self.inner_dataset = wds.DataPipeline(
|
13 |
-
wds.ResampledShards(location),
|
14 |
-
wds.tarfile_to_samples(handler=wds.warn_and_continue),
|
15 |
-
wds.shuffle(1000, handler=wds.warn_and_continue),
|
16 |
-
wds.decode("pilrgb", handler=wds.warn_and_continue),
|
17 |
-
wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
|
18 |
-
wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
|
19 |
-
wds.map(self.to_dict, handler=wds.warn_and_continue),
|
20 |
-
)
|
21 |
-
|
22 |
-
def to_dict(self, sample):
|
23 |
-
return {
|
24 |
-
"image": sample[0],
|
25 |
-
"text_input": self.text_processor(sample[1]["caption"]),
|
26 |
-
"type":'image',
|
27 |
-
}
|
28 |
-
|
29 |
-
|
30 |
-
class CCSBUAlignDataset(CaptionDataset):
|
31 |
-
|
32 |
-
def __getitem__(self, index):
|
33 |
-
|
34 |
-
# TODO this assumes image input, not general enough
|
35 |
-
ann = self.annotation[index]
|
36 |
-
|
37 |
-
img_file = '{}.jpg'.format(ann["image_id"])
|
38 |
-
image_path = os.path.join(self.vis_root, img_file)
|
39 |
-
image = Image.open(image_path).convert("RGB")
|
40 |
-
|
41 |
-
image = self.vis_processor(image)
|
42 |
-
caption = ann["caption"]
|
43 |
-
|
44 |
-
return {
|
45 |
-
"image": image,
|
46 |
-
"text_input": caption,
|
47 |
-
"image_id": self.img_ids[ann["image_id"]],
|
48 |
-
"type":'image',
|
49 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/datasets/dataloader_utils.py
DELETED
@@ -1,162 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import time
|
9 |
-
import random
|
10 |
-
import torch
|
11 |
-
from sonique.Video_LLaMA.video_llama.datasets.data_utils import move_to_cuda
|
12 |
-
from torch.utils.data import DataLoader
|
13 |
-
|
14 |
-
|
15 |
-
class MultiIterLoader:
|
16 |
-
"""
|
17 |
-
A simple wrapper for iterating over multiple iterators.
|
18 |
-
|
19 |
-
Args:
|
20 |
-
loaders (List[Loader]): List of Iterator loaders.
|
21 |
-
ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
|
22 |
-
"""
|
23 |
-
|
24 |
-
def __init__(self, loaders, ratios=None):
|
25 |
-
# assert all loaders has __next__ method
|
26 |
-
for loader in loaders:
|
27 |
-
assert hasattr(
|
28 |
-
loader, "__next__"
|
29 |
-
), "Loader {} has no __next__ method.".format(loader)
|
30 |
-
|
31 |
-
if ratios is None:
|
32 |
-
ratios = [1.0] * len(loaders)
|
33 |
-
else:
|
34 |
-
assert len(ratios) == len(loaders)
|
35 |
-
ratios = [float(ratio) / sum(ratios) for ratio in ratios]
|
36 |
-
|
37 |
-
self.loaders = loaders
|
38 |
-
self.ratios = ratios
|
39 |
-
|
40 |
-
def __next__(self):
|
41 |
-
# random sample from each loader by ratio
|
42 |
-
loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
|
43 |
-
return next(self.loaders[loader_idx])
|
44 |
-
|
45 |
-
|
46 |
-
class PrefetchLoader(object):
|
47 |
-
"""
|
48 |
-
Modified from https://github.com/ChenRocks/UNITER.
|
49 |
-
|
50 |
-
overlap compute and cuda data transfer
|
51 |
-
(copied and then modified from nvidia apex)
|
52 |
-
"""
|
53 |
-
|
54 |
-
def __init__(self, loader):
|
55 |
-
self.loader = loader
|
56 |
-
self.stream = torch.cuda.Stream()
|
57 |
-
|
58 |
-
def __iter__(self):
|
59 |
-
loader_it = iter(self.loader)
|
60 |
-
self.preload(loader_it)
|
61 |
-
batch = self.next(loader_it)
|
62 |
-
while batch is not None:
|
63 |
-
is_tuple = isinstance(batch, tuple)
|
64 |
-
if is_tuple:
|
65 |
-
task, batch = batch
|
66 |
-
|
67 |
-
if is_tuple:
|
68 |
-
yield task, batch
|
69 |
-
else:
|
70 |
-
yield batch
|
71 |
-
batch = self.next(loader_it)
|
72 |
-
|
73 |
-
def __len__(self):
|
74 |
-
return len(self.loader)
|
75 |
-
|
76 |
-
def preload(self, it):
|
77 |
-
try:
|
78 |
-
self.batch = next(it)
|
79 |
-
except StopIteration:
|
80 |
-
self.batch = None
|
81 |
-
return
|
82 |
-
# if record_stream() doesn't work, another option is to make sure
|
83 |
-
# device inputs are created on the main stream.
|
84 |
-
# self.next_input_gpu = torch.empty_like(self.next_input,
|
85 |
-
# device='cuda')
|
86 |
-
# self.next_target_gpu = torch.empty_like(self.next_target,
|
87 |
-
# device='cuda')
|
88 |
-
# Need to make sure the memory allocated for next_* is not still in use
|
89 |
-
# by the main stream at the time we start copying to next_*:
|
90 |
-
# self.stream.wait_stream(torch.cuda.current_stream())
|
91 |
-
with torch.cuda.stream(self.stream):
|
92 |
-
self.batch = move_to_cuda(self.batch)
|
93 |
-
# more code for the alternative if record_stream() doesn't work:
|
94 |
-
# copy_ will record the use of the pinned source tensor in this
|
95 |
-
# side stream.
|
96 |
-
# self.next_input_gpu.copy_(self.next_input, non_blocking=True)
|
97 |
-
# self.next_target_gpu.copy_(self.next_target, non_blocking=True)
|
98 |
-
# self.next_input = self.next_input_gpu
|
99 |
-
# self.next_target = self.next_target_gpu
|
100 |
-
|
101 |
-
def next(self, it):
|
102 |
-
torch.cuda.current_stream().wait_stream(self.stream)
|
103 |
-
batch = self.batch
|
104 |
-
if batch is not None:
|
105 |
-
record_cuda_stream(batch)
|
106 |
-
self.preload(it)
|
107 |
-
return batch
|
108 |
-
|
109 |
-
def __getattr__(self, name):
|
110 |
-
method = self.loader.__getattribute__(name)
|
111 |
-
return method
|
112 |
-
|
113 |
-
|
114 |
-
def record_cuda_stream(batch):
|
115 |
-
if isinstance(batch, torch.Tensor):
|
116 |
-
batch.record_stream(torch.cuda.current_stream())
|
117 |
-
elif isinstance(batch, list) or isinstance(batch, tuple):
|
118 |
-
for t in batch:
|
119 |
-
record_cuda_stream(t)
|
120 |
-
elif isinstance(batch, dict):
|
121 |
-
for t in batch.values():
|
122 |
-
record_cuda_stream(t)
|
123 |
-
else:
|
124 |
-
pass
|
125 |
-
|
126 |
-
|
127 |
-
class IterLoader:
|
128 |
-
"""
|
129 |
-
A wrapper to convert DataLoader as an infinite iterator.
|
130 |
-
|
131 |
-
Modified from:
|
132 |
-
https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
|
133 |
-
"""
|
134 |
-
|
135 |
-
def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
|
136 |
-
self._dataloader = dataloader
|
137 |
-
self.iter_loader = iter(self._dataloader)
|
138 |
-
self._use_distributed = use_distributed
|
139 |
-
self._epoch = 0
|
140 |
-
|
141 |
-
@property
|
142 |
-
def epoch(self) -> int:
|
143 |
-
return self._epoch
|
144 |
-
|
145 |
-
def __next__(self):
|
146 |
-
try:
|
147 |
-
data = next(self.iter_loader)
|
148 |
-
except StopIteration:
|
149 |
-
self._epoch += 1
|
150 |
-
if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
|
151 |
-
self._dataloader.sampler.set_epoch(self._epoch)
|
152 |
-
time.sleep(2) # Prevent possible deadlock during epoch transition
|
153 |
-
self.iter_loader = iter(self._dataloader)
|
154 |
-
data = next(self.iter_loader)
|
155 |
-
|
156 |
-
return data
|
157 |
-
|
158 |
-
def __iter__(self):
|
159 |
-
return self
|
160 |
-
|
161 |
-
def __len__(self):
|
162 |
-
return len(self._dataloader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/datasets/laion_dataset.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import webdataset as wds
|
9 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.base_dataset import BaseDataset
|
10 |
-
|
11 |
-
|
12 |
-
class LaionDataset(BaseDataset):
|
13 |
-
def __init__(self, vis_processor, text_processor, location):
|
14 |
-
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
|
15 |
-
|
16 |
-
self.inner_dataset = wds.DataPipeline(
|
17 |
-
wds.ResampledShards(location),
|
18 |
-
wds.tarfile_to_samples(handler=wds.warn_and_continue),
|
19 |
-
wds.shuffle(1000, handler=wds.warn_and_continue),
|
20 |
-
wds.decode("pilrgb", handler=wds.warn_and_continue),
|
21 |
-
wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
|
22 |
-
wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
|
23 |
-
wds.map(self.to_dict, handler=wds.warn_and_continue),
|
24 |
-
)
|
25 |
-
|
26 |
-
def to_dict(self, sample):
|
27 |
-
return {
|
28 |
-
"image": sample[0],
|
29 |
-
"text_input": self.text_processor(sample[1]["caption"]),
|
30 |
-
}
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/datasets/llava_instruct_dataset.py
DELETED
@@ -1,312 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.base_dataset import BaseDataset
|
3 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.caption_datasets import CaptionDataset
|
4 |
-
import pandas as pd
|
5 |
-
import decord
|
6 |
-
from decord import VideoReader
|
7 |
-
import random
|
8 |
-
import torch
|
9 |
-
from torch.utils.data.dataloader import default_collate
|
10 |
-
from PIL import Image
|
11 |
-
from typing import Dict, Optional, Sequence
|
12 |
-
import transformers
|
13 |
-
import pathlib
|
14 |
-
import json
|
15 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
|
16 |
-
from sonique.Video_LLaMA.video_llama.conversation.conversation_video import Conversation,SeparatorStyle
|
17 |
-
DEFAULT_IMAGE_PATCH_TOKEN = '<ImageHere>'
|
18 |
-
DEFAULT_IMAGE_TOKEN = "<image>"
|
19 |
-
import copy
|
20 |
-
from sonique.Video_LLaMA.video_llama.processors import transforms_video,AlproVideoTrainProcessor
|
21 |
-
IGNORE_INDEX = -100
|
22 |
-
image_conversation = Conversation(
|
23 |
-
system="",
|
24 |
-
roles=("Human", "Assistant"),
|
25 |
-
messages=[],
|
26 |
-
offset=0,
|
27 |
-
sep_style=SeparatorStyle.SINGLE,
|
28 |
-
sep="###",
|
29 |
-
)
|
30 |
-
llama_v2_image_conversation = Conversation(
|
31 |
-
system=" ",
|
32 |
-
roles=("USER", "ASSISTANT"),
|
33 |
-
messages=(),
|
34 |
-
offset=0,
|
35 |
-
sep_style=SeparatorStyle.LLAMA_2,
|
36 |
-
sep="<s>",
|
37 |
-
sep2="</s>",
|
38 |
-
)
|
39 |
-
IGNORE_INDEX = -100
|
40 |
-
|
41 |
-
class Instruct_Dataset(BaseDataset):
|
42 |
-
def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'image', model_type='vicuna'):
|
43 |
-
"""
|
44 |
-
vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
|
45 |
-
ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
|
46 |
-
split (string): val or test
|
47 |
-
"""
|
48 |
-
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
|
49 |
-
|
50 |
-
data_path = pathlib.Path(ann_root)
|
51 |
-
with data_path.open(encoding='utf-8') as f:
|
52 |
-
self.annotation = json.load(f)
|
53 |
-
|
54 |
-
self.vis_root = vis_root
|
55 |
-
self.resize_size = 224
|
56 |
-
self.num_frm = 8
|
57 |
-
self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
|
58 |
-
self.tokenizer.pad_token = self.tokenizer.unk_token
|
59 |
-
self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
60 |
-
self.num_video_query_token = num_video_query_token
|
61 |
-
self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN]
|
62 |
-
|
63 |
-
self.transform = AlproVideoTrainProcessor(
|
64 |
-
image_size=self.resize_size, n_frms = self.num_frm
|
65 |
-
).transform
|
66 |
-
self.data_type = data_type
|
67 |
-
self.model_type = model_type
|
68 |
-
|
69 |
-
def _get_image_path(self, sample):
|
70 |
-
rel_video_fp ='COCO_train2014_' + sample['image']
|
71 |
-
full_video_fp = os.path.join(self.vis_root, rel_video_fp)
|
72 |
-
return full_video_fp
|
73 |
-
|
74 |
-
def __getitem__(self, index):
|
75 |
-
num_retries = 10 # skip error videos
|
76 |
-
for _ in range(num_retries):
|
77 |
-
try:
|
78 |
-
sample = self.annotation[index]
|
79 |
-
|
80 |
-
image_path = self._get_image_path(sample)
|
81 |
-
conversation_list = sample['conversations']
|
82 |
-
image = Image.open(image_path).convert("RGB")
|
83 |
-
|
84 |
-
image = self.vis_processor(image)
|
85 |
-
# text = self.text_processor(text)
|
86 |
-
sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token)
|
87 |
-
if self.model_type =='vicuna':
|
88 |
-
data_dict = preprocess(
|
89 |
-
sources,
|
90 |
-
self.tokenizer)
|
91 |
-
elif self.model_type =='llama_v2':
|
92 |
-
data_dict = preprocess_for_llama_v2(
|
93 |
-
sources,
|
94 |
-
self.tokenizer)
|
95 |
-
else:
|
96 |
-
print('not support')
|
97 |
-
raise('not support')
|
98 |
-
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
99 |
-
labels=data_dict["labels"][0])
|
100 |
-
|
101 |
-
# image exist in the data
|
102 |
-
data_dict['image'] = image
|
103 |
-
except:
|
104 |
-
print(f"Failed to load examples with image: {image_path}. "
|
105 |
-
f"Will randomly sample an example as a replacement.")
|
106 |
-
index = random.randint(0, len(self) - 1)
|
107 |
-
continue
|
108 |
-
break
|
109 |
-
else:
|
110 |
-
raise RuntimeError(f"Failed to fetch image after {num_retries} retries.")
|
111 |
-
# "image_id" is kept to stay compatible with the COCO evaluation format
|
112 |
-
return {
|
113 |
-
"image": image,
|
114 |
-
"text_input": data_dict["input_ids"],
|
115 |
-
"labels": data_dict["labels"],
|
116 |
-
"type":'image',
|
117 |
-
}
|
118 |
-
|
119 |
-
def __len__(self):
|
120 |
-
return len(self.annotation)
|
121 |
-
|
122 |
-
def collater(self, instances):
|
123 |
-
input_ids, labels = tuple([instance[key] for instance in instances]
|
124 |
-
for key in ("text_input", "labels"))
|
125 |
-
input_ids = torch.nn.utils.rnn.pad_sequence(
|
126 |
-
input_ids,
|
127 |
-
batch_first=True,
|
128 |
-
padding_value=self.tokenizer.pad_token_id)
|
129 |
-
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
130 |
-
batch_first=True,
|
131 |
-
padding_value=IGNORE_INDEX)
|
132 |
-
batch = dict(
|
133 |
-
input_ids=input_ids,
|
134 |
-
labels=labels,
|
135 |
-
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
136 |
-
)
|
137 |
-
|
138 |
-
if 'image' in instances[0]:
|
139 |
-
images = [instance['image'] for instance in instances]
|
140 |
-
if all(x is not None and x.shape == images[0].shape for x in images):
|
141 |
-
batch['images'] = torch.stack(images)
|
142 |
-
else:
|
143 |
-
batch['images'] = images
|
144 |
-
batch['conv_type'] = 'multi'
|
145 |
-
return batch
|
146 |
-
|
147 |
-
|
148 |
-
def preprocess_multimodal(
|
149 |
-
conversation_list: Sequence[str],
|
150 |
-
multimodal_cfg: dict,
|
151 |
-
cur_token_len: int,
|
152 |
-
) -> Dict:
|
153 |
-
# 将conversational list中
|
154 |
-
is_multimodal = True
|
155 |
-
# image_token_len = multimodal_cfg['image_token_len']
|
156 |
-
image_token_len = cur_token_len
|
157 |
-
|
158 |
-
for sentence in conversation_list:
|
159 |
-
replace_token = '<Image>'+DEFAULT_IMAGE_PATCH_TOKEN * image_token_len+'</Image>'
|
160 |
-
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
161 |
-
|
162 |
-
return [conversation_list]
|
163 |
-
|
164 |
-
def _add_speaker_and_signal(header, source, get_conversation=True):
|
165 |
-
"""Add speaker and start/end signal on each round."""
|
166 |
-
BEGIN_SIGNAL = "###"
|
167 |
-
END_SIGNAL = "\n"
|
168 |
-
conversation = header
|
169 |
-
for sentence in source:
|
170 |
-
from_str = sentence["from"]
|
171 |
-
if from_str.lower() == "human":
|
172 |
-
from_str = image_conversation.roles[0]
|
173 |
-
elif from_str.lower() == "gpt":
|
174 |
-
from_str = image_conversation.roles[1]
|
175 |
-
else:
|
176 |
-
from_str = 'unknown'
|
177 |
-
sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
|
178 |
-
sentence["value"] + END_SIGNAL)
|
179 |
-
if get_conversation:
|
180 |
-
conversation += sentence["value"]
|
181 |
-
conversation += BEGIN_SIGNAL
|
182 |
-
return conversation
|
183 |
-
|
184 |
-
def _tokenize_fn(strings: Sequence[str],
|
185 |
-
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
186 |
-
"""Tokenize a list of strings."""
|
187 |
-
tokenized_list = [
|
188 |
-
tokenizer(
|
189 |
-
text,
|
190 |
-
return_tensors="pt",
|
191 |
-
padding="longest",
|
192 |
-
max_length=512,
|
193 |
-
truncation=True,
|
194 |
-
) for text in strings
|
195 |
-
]
|
196 |
-
input_ids = labels = [
|
197 |
-
tokenized.input_ids[0] for tokenized in tokenized_list
|
198 |
-
]
|
199 |
-
input_ids_lens = labels_lens = [
|
200 |
-
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
201 |
-
for tokenized in tokenized_list
|
202 |
-
]
|
203 |
-
return dict(
|
204 |
-
input_ids=input_ids,
|
205 |
-
labels=labels,
|
206 |
-
input_ids_lens=input_ids_lens,
|
207 |
-
labels_lens=labels_lens,
|
208 |
-
)
|
209 |
-
|
210 |
-
def preprocess(
|
211 |
-
sources: Sequence[str],
|
212 |
-
tokenizer: transformers.PreTrainedTokenizer,
|
213 |
-
) -> Dict:
|
214 |
-
"""
|
215 |
-
Given a list of sources, each is a conversation list. This transform:
|
216 |
-
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
217 |
-
2. Concatenate conversations together;
|
218 |
-
3. Tokenize the concatenated conversation;
|
219 |
-
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
220 |
-
"""
|
221 |
-
# add end signal and concatenate together
|
222 |
-
conversations = []
|
223 |
-
for source in sources:
|
224 |
-
header = f"{image_conversation.system}\n\n"
|
225 |
-
conversation = _add_speaker_and_signal(header, source)
|
226 |
-
conversations.append(conversation)
|
227 |
-
# tokenize conversations
|
228 |
-
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
|
229 |
-
input_ids = conversations_tokenized["input_ids"]
|
230 |
-
targets = copy.deepcopy(input_ids)
|
231 |
-
for target, source in zip(targets, sources):
|
232 |
-
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
|
233 |
-
tokenizer)["input_ids_lens"]
|
234 |
-
speakers = [sentence["from"] for sentence in source]
|
235 |
-
_mask_targets(target, tokenized_lens, speakers)
|
236 |
-
|
237 |
-
return dict(input_ids=input_ids, labels=targets)
|
238 |
-
|
239 |
-
def preprocess_for_llama_v2(
|
240 |
-
sources: Sequence[str],
|
241 |
-
tokenizer: transformers.PreTrainedTokenizer,
|
242 |
-
) -> Dict:
|
243 |
-
"""
|
244 |
-
Given a list of sources, each is a conversation list. This transform:
|
245 |
-
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
246 |
-
2. Concatenate conversations together;
|
247 |
-
3. Tokenize the concatenated conversation;
|
248 |
-
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
249 |
-
"""
|
250 |
-
# add end signal and concatenate together
|
251 |
-
conversations = []
|
252 |
-
conv = copy.deepcopy(llama_v2_image_conversation.copy())
|
253 |
-
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
254 |
-
for source in sources:
|
255 |
-
# <s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n
|
256 |
-
header = f"<s>[INST] <<SYS>>\n{conv.system}\n</SYS>>\n\n"
|
257 |
-
|
258 |
-
if roles[source[0]["from"]] != conv.roles[0]:
|
259 |
-
# Skip the first one if it is not from human
|
260 |
-
source = source[1:]
|
261 |
-
conv.messages = []
|
262 |
-
for j, sentence in enumerate(source):
|
263 |
-
role = roles[sentence["from"]]
|
264 |
-
assert role == conv.roles[j % 2]
|
265 |
-
conv.append_message(role, sentence["value"])
|
266 |
-
conversations.append(conv.get_prompt())
|
267 |
-
|
268 |
-
input_ids = tokenizer(
|
269 |
-
conversations,
|
270 |
-
return_tensors="pt",
|
271 |
-
padding="longest",
|
272 |
-
max_length=512,
|
273 |
-
truncation=True,
|
274 |
-
).input_ids
|
275 |
-
targets = copy.deepcopy(input_ids)
|
276 |
-
|
277 |
-
|
278 |
-
sep = "[/INST] "
|
279 |
-
for conversation, target in zip(conversations, targets):
|
280 |
-
# total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
281 |
-
rounds = conversation.split(conv.sep2)
|
282 |
-
cur_len = 1
|
283 |
-
target[:cur_len] = IGNORE_INDEX
|
284 |
-
for i, rou in enumerate(rounds):
|
285 |
-
if rou == "":
|
286 |
-
break
|
287 |
-
|
288 |
-
parts = rou.split(sep)
|
289 |
-
if len(parts) != 2:
|
290 |
-
break
|
291 |
-
parts[0] += sep
|
292 |
-
|
293 |
-
|
294 |
-
round_len = len(tokenizer(rou).input_ids)
|
295 |
-
instruction_len = len(tokenizer(parts[0]).input_ids) - 2 # 为什么减去2,speical token 的数目
|
296 |
-
|
297 |
-
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
298 |
-
|
299 |
-
cur_len += round_len
|
300 |
-
target[cur_len:] = IGNORE_INDEX
|
301 |
-
|
302 |
-
return dict(input_ids=input_ids, labels=targets)
|
303 |
-
|
304 |
-
def _mask_targets(target, tokenized_lens, speakers):
|
305 |
-
# cur_idx = 0
|
306 |
-
cur_idx = tokenized_lens[0]
|
307 |
-
tokenized_lens = tokenized_lens[1:]
|
308 |
-
target[:cur_idx] = IGNORE_INDEX
|
309 |
-
for tokenized_len, speaker in zip(tokenized_lens, speakers):
|
310 |
-
if speaker == "human":
|
311 |
-
target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
|
312 |
-
cur_idx += tokenized_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/datasets/video_instruct_dataset.py
DELETED
@@ -1,335 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.base_dataset import BaseDataset
|
3 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.caption_datasets import CaptionDataset
|
4 |
-
import pandas as pd
|
5 |
-
import decord
|
6 |
-
from decord import VideoReader
|
7 |
-
import random
|
8 |
-
import torch
|
9 |
-
from torch.utils.data.dataloader import default_collate
|
10 |
-
from PIL import Image
|
11 |
-
from typing import Dict, Optional, Sequence
|
12 |
-
import transformers
|
13 |
-
import pathlib
|
14 |
-
import json
|
15 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
|
16 |
-
import copy
|
17 |
-
from sonique.Video_LLaMA.video_llama.processors import transforms_video,AlproVideoTrainProcessor
|
18 |
-
from torchvision import transforms
|
19 |
-
from sonique.Video_LLaMA.video_llama.processors.video_processor import ToTHWC,ToUint8,load_video
|
20 |
-
from sonique.Video_LLaMA.video_llama.conversation.conversation_video import Conversation,SeparatorStyle
|
21 |
-
|
22 |
-
DEFAULT_IMAGE_PATCH_TOKEN = '<ImageHere>'
|
23 |
-
video_conversation = Conversation(
|
24 |
-
system="",
|
25 |
-
roles=("Human", "Assistant"),
|
26 |
-
messages=[],
|
27 |
-
offset=0,
|
28 |
-
sep_style=SeparatorStyle.SINGLE,
|
29 |
-
sep="###",
|
30 |
-
)
|
31 |
-
llama_v2_video_conversation = Conversation(
|
32 |
-
system=" ",
|
33 |
-
roles=("USER", "ASSISTANT"),
|
34 |
-
messages=(),
|
35 |
-
offset=0,
|
36 |
-
sep_style=SeparatorStyle.LLAMA_2,
|
37 |
-
sep="<s>",
|
38 |
-
sep2="</s>",
|
39 |
-
)
|
40 |
-
IGNORE_INDEX = -100
|
41 |
-
|
42 |
-
class Video_Instruct_Dataset(BaseDataset):
|
43 |
-
def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'video', model_type='vicuna'):
|
44 |
-
"""
|
45 |
-
vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
|
46 |
-
ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
|
47 |
-
split (string): val or test
|
48 |
-
"""
|
49 |
-
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
|
50 |
-
|
51 |
-
data_path = pathlib.Path(ann_root)
|
52 |
-
with data_path.open(encoding='utf-8') as f:
|
53 |
-
self.annotation = json.load(f)
|
54 |
-
|
55 |
-
self.num_video_query_token = num_video_query_token
|
56 |
-
self.vis_root = vis_root
|
57 |
-
self.resize_size = 224
|
58 |
-
self.num_frm = 8
|
59 |
-
self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
|
60 |
-
self.tokenizer.pad_token = self.tokenizer.unk_token
|
61 |
-
self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
62 |
-
self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN]
|
63 |
-
|
64 |
-
self.transform = AlproVideoTrainProcessor(
|
65 |
-
image_size=self.resize_size, n_frms = self.num_frm
|
66 |
-
).transform
|
67 |
-
self.data_type = data_type
|
68 |
-
self.model_type = model_type
|
69 |
-
|
70 |
-
def _get_video_path(self, sample):
|
71 |
-
rel_video_fp = sample['video']
|
72 |
-
full_video_fp = os.path.join(self.vis_root, rel_video_fp)
|
73 |
-
return full_video_fp
|
74 |
-
|
75 |
-
def __getitem__(self, index):
|
76 |
-
num_retries = 10 # skip error videos
|
77 |
-
for _ in range(num_retries):
|
78 |
-
try:
|
79 |
-
sample = self.annotation[index]
|
80 |
-
|
81 |
-
video_path = self._get_video_path(sample)
|
82 |
-
conversation_list = sample['QA']
|
83 |
-
|
84 |
-
video, msg = load_video(
|
85 |
-
video_path=video_path,
|
86 |
-
n_frms=self.num_frm,
|
87 |
-
height=self.resize_size,
|
88 |
-
width=self.resize_size,
|
89 |
-
sampling ="uniform", return_msg = True
|
90 |
-
)
|
91 |
-
video = self.transform(video)
|
92 |
-
if 'cn' in self.data_type:
|
93 |
-
msg = ""
|
94 |
-
# 添加视频<DEFAULT_IMAGE_PATCH_TOKEN>,以及msg到convsation list 0
|
95 |
-
sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token,msg = msg)
|
96 |
-
new_sources = convert_source_vicuna_format(sources)
|
97 |
-
|
98 |
-
if self.model_type =='vicuna':
|
99 |
-
data_dict = preprocess(
|
100 |
-
new_sources,
|
101 |
-
self.tokenizer)
|
102 |
-
elif self.model_type =='llama_v2':
|
103 |
-
data_dict = preprocess_for_llama_v2(
|
104 |
-
new_sources,
|
105 |
-
self.tokenizer)
|
106 |
-
else:
|
107 |
-
print('not support')
|
108 |
-
raise('not support')
|
109 |
-
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
110 |
-
labels=data_dict["labels"][0])
|
111 |
-
# image exist in the data
|
112 |
-
data_dict['image'] = video
|
113 |
-
except:
|
114 |
-
print(f"Failed to load examples with video: {video_path}. "
|
115 |
-
f"Will randomly sample an example as a replacement.")
|
116 |
-
index = random.randint(0, len(self) - 1)
|
117 |
-
continue
|
118 |
-
break
|
119 |
-
else:
|
120 |
-
raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
|
121 |
-
# "image_id" is kept to stay compatible with the COCO evaluation format
|
122 |
-
return {
|
123 |
-
"image": video,
|
124 |
-
"text_input": data_dict["input_ids"],
|
125 |
-
"labels": data_dict["labels"],
|
126 |
-
"type":'video',
|
127 |
-
}
|
128 |
-
|
129 |
-
def __len__(self):
|
130 |
-
return len(self.annotation)
|
131 |
-
|
132 |
-
def collater(self, instances):
|
133 |
-
input_ids, labels = tuple([instance[key] for instance in instances]
|
134 |
-
for key in ("text_input", "labels"))
|
135 |
-
input_ids = torch.nn.utils.rnn.pad_sequence(
|
136 |
-
input_ids,
|
137 |
-
batch_first=True,
|
138 |
-
padding_value=self.tokenizer.pad_token_id)
|
139 |
-
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
140 |
-
batch_first=True,
|
141 |
-
padding_value=IGNORE_INDEX)
|
142 |
-
batch = dict(
|
143 |
-
input_ids=input_ids,
|
144 |
-
labels=labels,
|
145 |
-
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
146 |
-
)
|
147 |
-
|
148 |
-
if 'image' in instances[0]:
|
149 |
-
images = [instance['image'] for instance in instances]
|
150 |
-
if all(x is not None and x.shape == images[0].shape for x in images):
|
151 |
-
batch['images'] = torch.stack(images)
|
152 |
-
else:
|
153 |
-
batch['images'] = images
|
154 |
-
batch['conv_type'] = 'multi'
|
155 |
-
return batch
|
156 |
-
|
157 |
-
def convert_source_vicuna_format(sources):
|
158 |
-
new_sources = []
|
159 |
-
for source in sources:
|
160 |
-
new_source = []
|
161 |
-
for i, sentence in enumerate(source):
|
162 |
-
role_0_msg = sentence['q']
|
163 |
-
role_1_msg = sentence['a']
|
164 |
-
new_source.append({
|
165 |
-
'from':'human',
|
166 |
-
'value': role_0_msg,
|
167 |
-
})
|
168 |
-
new_source.append({
|
169 |
-
'from':'gpt',
|
170 |
-
'value': role_1_msg,
|
171 |
-
})
|
172 |
-
new_sources.append(new_source)
|
173 |
-
return new_sources
|
174 |
-
|
175 |
-
def preprocess_multimodal(
|
176 |
-
conversation_list: Sequence[str],
|
177 |
-
multimodal_cfg: dict,
|
178 |
-
cur_token_len: int,
|
179 |
-
msg=''
|
180 |
-
) -> Dict:
|
181 |
-
# 将conversational list中
|
182 |
-
is_multimodal = True
|
183 |
-
# image_token_len = multimodal_cfg['image_token_len']
|
184 |
-
image_token_len = cur_token_len
|
185 |
-
conversation_list[0]["q"] = "<Video>"+DEFAULT_IMAGE_PATCH_TOKEN * image_token_len +"</Video> " + msg + conversation_list[0]["q"]
|
186 |
-
return [conversation_list]
|
187 |
-
|
188 |
-
def _add_speaker_and_signal(header, source, get_conversation=True):
|
189 |
-
"""Add speaker and start/end signal on each round."""
|
190 |
-
BEGIN_SIGNAL = "###"
|
191 |
-
END_SIGNAL = "\n"
|
192 |
-
conversation = header
|
193 |
-
for sentence in source:
|
194 |
-
from_str = sentence["from"]
|
195 |
-
if from_str.lower() == "human":
|
196 |
-
from_str = video_conversation.roles[0]
|
197 |
-
elif from_str.lower() == "gpt":
|
198 |
-
from_str = video_conversation.roles[1]
|
199 |
-
else:
|
200 |
-
from_str = 'unknown'
|
201 |
-
sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
|
202 |
-
sentence["value"] + END_SIGNAL)
|
203 |
-
if get_conversation:
|
204 |
-
conversation += sentence["value"]
|
205 |
-
conversation += BEGIN_SIGNAL
|
206 |
-
return conversation
|
207 |
-
|
208 |
-
def _tokenize_fn(strings: Sequence[str],
|
209 |
-
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
210 |
-
"""Tokenize a list of strings."""
|
211 |
-
tokenized_list = [
|
212 |
-
tokenizer(
|
213 |
-
text,
|
214 |
-
return_tensors="pt",
|
215 |
-
padding="longest",
|
216 |
-
max_length=512,
|
217 |
-
truncation=True,
|
218 |
-
) for text in strings
|
219 |
-
]
|
220 |
-
input_ids = labels = [
|
221 |
-
tokenized.input_ids[0] for tokenized in tokenized_list
|
222 |
-
]
|
223 |
-
input_ids_lens = labels_lens = [
|
224 |
-
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
225 |
-
for tokenized in tokenized_list
|
226 |
-
]
|
227 |
-
return dict(
|
228 |
-
input_ids=input_ids,
|
229 |
-
labels=labels,
|
230 |
-
input_ids_lens=input_ids_lens,
|
231 |
-
labels_lens=labels_lens,
|
232 |
-
)
|
233 |
-
|
234 |
-
def preprocess(
|
235 |
-
sources: Sequence[str],
|
236 |
-
tokenizer: transformers.PreTrainedTokenizer,
|
237 |
-
) -> Dict:
|
238 |
-
"""
|
239 |
-
Given a list of sources, each is a conversation list. This transform:
|
240 |
-
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
241 |
-
2. Concatenate conversations together;
|
242 |
-
3. Tokenize the concatenated conversation;
|
243 |
-
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
244 |
-
"""
|
245 |
-
# add end signal and concatenate together
|
246 |
-
conversations = []
|
247 |
-
for source in sources:
|
248 |
-
header = f"{video_conversation.system}\n\n"
|
249 |
-
conversation = _add_speaker_and_signal(header, source)
|
250 |
-
conversations.append(conversation)
|
251 |
-
# tokenize conversations
|
252 |
-
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
|
253 |
-
input_ids = conversations_tokenized["input_ids"]
|
254 |
-
targets = copy.deepcopy(input_ids)
|
255 |
-
for target, source in zip(targets, sources):
|
256 |
-
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
|
257 |
-
tokenizer)["input_ids_lens"]
|
258 |
-
speakers = [sentence["from"] for sentence in source]
|
259 |
-
_mask_targets(target, tokenized_lens, speakers)
|
260 |
-
|
261 |
-
return dict(input_ids=input_ids, labels=targets)
|
262 |
-
|
263 |
-
def preprocess_for_llama_v2(
|
264 |
-
sources: Sequence[str],
|
265 |
-
tokenizer: transformers.PreTrainedTokenizer,
|
266 |
-
) -> Dict:
|
267 |
-
"""
|
268 |
-
Given a list of sources, each is a conversation list. This transform:
|
269 |
-
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
270 |
-
2. Concatenate conversations together;
|
271 |
-
3. Tokenize the concatenated conversation;
|
272 |
-
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
273 |
-
"""
|
274 |
-
# add end signal and concatenate together
|
275 |
-
conversations = []
|
276 |
-
conv = copy.deepcopy(llama_v2_video_conversation.copy())
|
277 |
-
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
278 |
-
for source in sources:
|
279 |
-
# <s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n
|
280 |
-
header = f"<s>[INST] <<SYS>>\n{conv.system}\n</SYS>>\n\n"
|
281 |
-
|
282 |
-
if roles[source[0]["from"]] != conv.roles[0]:
|
283 |
-
# Skip the first one if it is not from human
|
284 |
-
source = source[1:]
|
285 |
-
conv.messages = []
|
286 |
-
for j, sentence in enumerate(source):
|
287 |
-
role = roles[sentence["from"]]
|
288 |
-
assert role == conv.roles[j % 2]
|
289 |
-
conv.append_message(role, sentence["value"])
|
290 |
-
conversations.append(conv.get_prompt())
|
291 |
-
|
292 |
-
input_ids = tokenizer(
|
293 |
-
conversations,
|
294 |
-
return_tensors="pt",
|
295 |
-
padding="longest",
|
296 |
-
max_length=512,
|
297 |
-
truncation=True,
|
298 |
-
).input_ids
|
299 |
-
targets = copy.deepcopy(input_ids)
|
300 |
-
|
301 |
-
|
302 |
-
sep = "[/INST] "
|
303 |
-
for conversation, target in zip(conversations, targets):
|
304 |
-
# total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
305 |
-
rounds = conversation.split(conv.sep2)
|
306 |
-
cur_len = 1
|
307 |
-
target[:cur_len] = IGNORE_INDEX
|
308 |
-
for i, rou in enumerate(rounds):
|
309 |
-
if rou == "":
|
310 |
-
break
|
311 |
-
|
312 |
-
parts = rou.split(sep)
|
313 |
-
if len(parts) != 2:
|
314 |
-
break
|
315 |
-
parts[0] += sep
|
316 |
-
|
317 |
-
|
318 |
-
round_len = len(tokenizer(rou).input_ids)
|
319 |
-
instruction_len = len(tokenizer(parts[0]).input_ids) - 2 # 为什么减去2,speical token 的数目
|
320 |
-
|
321 |
-
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
322 |
-
|
323 |
-
cur_len += round_len
|
324 |
-
target[cur_len:] = IGNORE_INDEX
|
325 |
-
|
326 |
-
return dict(input_ids=input_ids, labels=targets)
|
327 |
-
def _mask_targets(target, tokenized_lens, speakers):
|
328 |
-
# cur_idx = 0
|
329 |
-
cur_idx = tokenized_lens[0]
|
330 |
-
tokenized_lens = tokenized_lens[1:]
|
331 |
-
target[:cur_idx] = IGNORE_INDEX
|
332 |
-
for tokenized_len, speaker in zip(tokenized_lens, speakers):
|
333 |
-
if speaker == "human":
|
334 |
-
target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
|
335 |
-
cur_idx += tokenized_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/datasets/datasets/webvid_datasets.py
DELETED
@@ -1,122 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
3 |
-
All rights reserved.
|
4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
5 |
-
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
-
"""
|
7 |
-
|
8 |
-
import os
|
9 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.base_dataset import BaseDataset
|
10 |
-
from sonique.Video_LLaMA.video_llama.datasets.datasets.caption_datasets import CaptionDataset
|
11 |
-
import pandas as pd
|
12 |
-
import decord
|
13 |
-
from decord import VideoReader
|
14 |
-
import random
|
15 |
-
import torch
|
16 |
-
from torch.utils.data.dataloader import default_collate
|
17 |
-
class WebvidDataset(BaseDataset):
|
18 |
-
def __init__(self, vis_processor, text_processor, vis_root, ann_root):
|
19 |
-
"""
|
20 |
-
vis_root (string): Root directory of video (e.g. webvid_eval/video/)
|
21 |
-
ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
|
22 |
-
split (string): val or test
|
23 |
-
"""
|
24 |
-
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
|
25 |
-
|
26 |
-
|
27 |
-
# 读取一个路径下所有的
|
28 |
-
|
29 |
-
ts_df = []
|
30 |
-
for file_name in os.listdir(ann_root):
|
31 |
-
if file_name.endswith('.csv'):
|
32 |
-
df = pd.read_csv(os.path.join(ann_root, file_name))
|
33 |
-
ts_df.append(df)
|
34 |
-
|
35 |
-
merged_df = pd.concat(ts_df)
|
36 |
-
self.annotation = merged_df
|
37 |
-
self.vis_root = vis_root
|
38 |
-
self.resize_size = 224
|
39 |
-
self.num_frm = 8
|
40 |
-
self.frm_sampling_strategy = 'headtail'
|
41 |
-
|
42 |
-
def _get_video_path(self, sample):
|
43 |
-
rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
|
44 |
-
full_video_fp = os.path.join(self.vis_root, rel_video_fp)
|
45 |
-
return full_video_fp
|
46 |
-
|
47 |
-
def __getitem__(self, index):
|
48 |
-
num_retries = 10 # skip error videos
|
49 |
-
for _ in range(num_retries):
|
50 |
-
sample = self.annotation.iloc[index]
|
51 |
-
sample_dict = sample.to_dict()
|
52 |
-
video_id = sample_dict['videoid']
|
53 |
-
|
54 |
-
if 'name' in sample_dict.keys():
|
55 |
-
text = sample_dict['name'].strip()
|
56 |
-
else:
|
57 |
-
raise NotImplementedError("Un-supported text annotation format.")
|
58 |
-
|
59 |
-
# fetch video
|
60 |
-
video_path = self._get_video_path(sample_dict)
|
61 |
-
# if os.path.exists(video_path):
|
62 |
-
try:
|
63 |
-
video = self.vis_processor(video_path)
|
64 |
-
except:
|
65 |
-
print(f"Failed to load examples with video: {video_path}. "
|
66 |
-
f"Will randomly sample an example as a replacement.")
|
67 |
-
index = random.randint(0, len(self) - 1)
|
68 |
-
continue
|
69 |
-
caption = self.text_processor(text)
|
70 |
-
|
71 |
-
# print(video.size())
|
72 |
-
if video is None or caption is None \
|
73 |
-
or video.size()!=torch.Size([3,self.vis_processor.n_frms,224,224]):
|
74 |
-
print(f"Failed to load examples with video: {video_path}. "
|
75 |
-
f"Will randomly sample an example as a replacement.")
|
76 |
-
index = random.randint(0, len(self) - 1)
|
77 |
-
continue
|
78 |
-
else:
|
79 |
-
break
|
80 |
-
else:
|
81 |
-
raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
|
82 |
-
# "image_id" is kept to stay compatible with the COCO evaluation format
|
83 |
-
return {
|
84 |
-
"image": video,
|
85 |
-
"text_input": caption,
|
86 |
-
"type":'video',
|
87 |
-
}
|
88 |
-
|
89 |
-
def __len__(self):
|
90 |
-
return len(self.annotation)
|
91 |
-
|
92 |
-
# def collater(self, samples):
|
93 |
-
# new_result = {}
|
94 |
-
# new_result['image'] = default_collate( [sample["image"] for sample in samples])
|
95 |
-
# new_result['text_input'] = default_collate( [sample["text_input"] for sample in samples])
|
96 |
-
# return new_result
|
97 |
-
|
98 |
-
class WebvidDatasetEvalDataset(BaseDataset):
|
99 |
-
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
|
100 |
-
"""
|
101 |
-
vis_root (string): Root directory of images (e.g. coco/images/)
|
102 |
-
ann_root (string): directory to store the annotation file
|
103 |
-
split (string): val or test
|
104 |
-
"""
|
105 |
-
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
|
106 |
-
|
107 |
-
def __getitem__(self, index):
|
108 |
-
|
109 |
-
ann = self.annotation[index]
|
110 |
-
|
111 |
-
vname = ann["video"]
|
112 |
-
video_path = os.path.join(self.vis_root, vname)
|
113 |
-
|
114 |
-
video = self.vis_processor(video_path)
|
115 |
-
|
116 |
-
return {
|
117 |
-
"video": video,
|
118 |
-
"image_id": ann["image_id"],
|
119 |
-
"instance_id": ann["instance_id"],
|
120 |
-
}
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/models/ImageBind/.assets/bird_image.jpg
DELETED
Binary file (115 kB)
|
|
sonique/Video_LLaMA/video_llama/models/ImageBind/.assets/car_image.jpg
DELETED
Binary file (59.3 kB)
|
|
sonique/Video_LLaMA/video_llama/models/ImageBind/.assets/dog_image.jpg
DELETED
Binary file (86.1 kB)
|
|
sonique/Video_LLaMA/video_llama/models/ImageBind/CODE_OF_CONDUCT.md
DELETED
@@ -1,80 +0,0 @@
|
|
1 |
-
# Code of Conduct
|
2 |
-
|
3 |
-
## Our Pledge
|
4 |
-
|
5 |
-
In the interest of fostering an open and welcoming environment, we as
|
6 |
-
contributors and maintainers pledge to make participation in our project and
|
7 |
-
our community a harassment-free experience for everyone, regardless of age, body
|
8 |
-
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
9 |
-
level of experience, education, socio-economic status, nationality, personal
|
10 |
-
appearance, race, religion, or sexual identity and orientation.
|
11 |
-
|
12 |
-
## Our Standards
|
13 |
-
|
14 |
-
Examples of behavior that contributes to creating a positive environment
|
15 |
-
include:
|
16 |
-
|
17 |
-
* Using welcoming and inclusive language
|
18 |
-
* Being respectful of differing viewpoints and experiences
|
19 |
-
* Gracefully accepting constructive criticism
|
20 |
-
* Focusing on what is best for the community
|
21 |
-
* Showing empathy towards other community members
|
22 |
-
|
23 |
-
Examples of unacceptable behavior by participants include:
|
24 |
-
|
25 |
-
* The use of sexualized language or imagery and unwelcome sexual attention or
|
26 |
-
advances
|
27 |
-
* Trolling, insulting/derogatory comments, and personal or political attacks
|
28 |
-
* Public or private harassment
|
29 |
-
* Publishing others' private information, such as a physical or electronic
|
30 |
-
address, without explicit permission
|
31 |
-
* Other conduct which could reasonably be considered inappropriate in a
|
32 |
-
professional setting
|
33 |
-
|
34 |
-
## Our Responsibilities
|
35 |
-
|
36 |
-
Project maintainers are responsible for clarifying the standards of acceptable
|
37 |
-
behavior and are expected to take appropriate and fair corrective action in
|
38 |
-
response to any instances of unacceptable behavior.
|
39 |
-
|
40 |
-
Project maintainers have the right and responsibility to remove, edit, or
|
41 |
-
reject comments, commits, code, wiki edits, issues, and other contributions
|
42 |
-
that are not aligned to this Code of Conduct, or to ban temporarily or
|
43 |
-
permanently any contributor for other behaviors that they deem inappropriate,
|
44 |
-
threatening, offensive, or harmful.
|
45 |
-
|
46 |
-
## Scope
|
47 |
-
|
48 |
-
This Code of Conduct applies within all project spaces, and it also applies when
|
49 |
-
an individual is representing the project or its community in public spaces.
|
50 |
-
Examples of representing a project or community include using an official
|
51 |
-
project e-mail address, posting via an official social media account, or acting
|
52 |
-
as an appointed representative at an online or offline event. Representation of
|
53 |
-
a project may be further defined and clarified by project maintainers.
|
54 |
-
|
55 |
-
This Code of Conduct also applies outside the project spaces when there is a
|
56 |
-
reasonable belief that an individual's behavior may have a negative impact on
|
57 |
-
the project or its community.
|
58 |
-
|
59 |
-
## Enforcement
|
60 |
-
|
61 |
-
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
62 |
-
reported by contacting the project team at <[email protected]>. All
|
63 |
-
complaints will be reviewed and investigated and will result in a response that
|
64 |
-
is deemed necessary and appropriate to the circumstances. The project team is
|
65 |
-
obligated to maintain confidentiality with regard to the reporter of an incident.
|
66 |
-
Further details of specific enforcement policies may be posted separately.
|
67 |
-
|
68 |
-
Project maintainers who do not follow or enforce the Code of Conduct in good
|
69 |
-
faith may face temporary or permanent repercussions as determined by other
|
70 |
-
members of the project's leadership.
|
71 |
-
|
72 |
-
## Attribution
|
73 |
-
|
74 |
-
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
75 |
-
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
76 |
-
|
77 |
-
[homepage]: https://www.contributor-covenant.org
|
78 |
-
|
79 |
-
For answers to common questions about this code of conduct, see
|
80 |
-
https://www.contributor-covenant.org/faq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/models/ImageBind/CONTRIBUTING.md
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
# Contributing to ImageBind
|
2 |
-
We want to make contributing to this project as easy and transparent as
|
3 |
-
possible.
|
4 |
-
|
5 |
-
## Pull Requests
|
6 |
-
We actively welcome your pull requests.
|
7 |
-
|
8 |
-
1. Fork the repo and create your branch from `main`.
|
9 |
-
2. If you've added code that should be tested, add tests.
|
10 |
-
3. If you've changed APIs, update the documentation.
|
11 |
-
4. Ensure the test suite passes.
|
12 |
-
5. Make sure your code lints.
|
13 |
-
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
14 |
-
|
15 |
-
## Contributor License Agreement ("CLA")
|
16 |
-
In order to accept your pull request, we need you to submit a CLA. You only need
|
17 |
-
to do this once to work on any of Meta's open source projects.
|
18 |
-
|
19 |
-
Complete your CLA here: <https://code.facebook.com/cla>
|
20 |
-
|
21 |
-
## Issues
|
22 |
-
We use GitHub issues to track public bugs. Please ensure your description is
|
23 |
-
clear and has sufficient instructions to be able to reproduce the issue.
|
24 |
-
|
25 |
-
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
26 |
-
disclosure of security bugs. In those cases, please go through the process
|
27 |
-
outlined on that page and do not file a public issue.
|
28 |
-
|
29 |
-
## License
|
30 |
-
By contributing to Omnivore, you agree that your contributions will be licensed
|
31 |
-
under the [LICENSE](LICENSE) file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/models/ImageBind/LICENSE
DELETED
@@ -1,437 +0,0 @@
|
|
1 |
-
Attribution-NonCommercial-ShareAlike 4.0 International
|
2 |
-
|
3 |
-
=======================================================================
|
4 |
-
|
5 |
-
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
6 |
-
does not provide legal services or legal advice. Distribution of
|
7 |
-
Creative Commons public licenses does not create a lawyer-client or
|
8 |
-
other relationship. Creative Commons makes its licenses and related
|
9 |
-
information available on an "as-is" basis. Creative Commons gives no
|
10 |
-
warranties regarding its licenses, any material licensed under their
|
11 |
-
terms and conditions, or any related information. Creative Commons
|
12 |
-
disclaims all liability for damages resulting from their use to the
|
13 |
-
fullest extent possible.
|
14 |
-
|
15 |
-
Using Creative Commons Public Licenses
|
16 |
-
|
17 |
-
Creative Commons public licenses provide a standard set of terms and
|
18 |
-
conditions that creators and other rights holders may use to share
|
19 |
-
original works of authorship and other material subject to copyright
|
20 |
-
and certain other rights specified in the public license below. The
|
21 |
-
following considerations are for informational purposes only, are not
|
22 |
-
exhaustive, and do not form part of our licenses.
|
23 |
-
|
24 |
-
Considerations for licensors: Our public licenses are
|
25 |
-
intended for use by those authorized to give the public
|
26 |
-
permission to use material in ways otherwise restricted by
|
27 |
-
copyright and certain other rights. Our licenses are
|
28 |
-
irrevocable. Licensors should read and understand the terms
|
29 |
-
and conditions of the license they choose before applying it.
|
30 |
-
Licensors should also secure all rights necessary before
|
31 |
-
applying our licenses so that the public can reuse the
|
32 |
-
material as expected. Licensors should clearly mark any
|
33 |
-
material not subject to the license. This includes other CC-
|
34 |
-
licensed material, or material used under an exception or
|
35 |
-
limitation to copyright. More considerations for licensors:
|
36 |
-
wiki.creativecommons.org/Considerations_for_licensors
|
37 |
-
|
38 |
-
Considerations for the public: By using one of our public
|
39 |
-
licenses, a licensor grants the public permission to use the
|
40 |
-
licensed material under specified terms and conditions. If
|
41 |
-
the licensor's permission is not necessary for any reason--for
|
42 |
-
example, because of any applicable exception or limitation to
|
43 |
-
copyright--then that use is not regulated by the license. Our
|
44 |
-
licenses grant only permissions under copyright and certain
|
45 |
-
other rights that a licensor has authority to grant. Use of
|
46 |
-
the licensed material may still be restricted for other
|
47 |
-
reasons, including because others have copyright or other
|
48 |
-
rights in the material. A licensor may make special requests,
|
49 |
-
such as asking that all changes be marked or described.
|
50 |
-
Although not required by our licenses, you are encouraged to
|
51 |
-
respect those requests where reasonable. More considerations
|
52 |
-
for the public:
|
53 |
-
wiki.creativecommons.org/Considerations_for_licensees
|
54 |
-
|
55 |
-
=======================================================================
|
56 |
-
|
57 |
-
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
58 |
-
Public License
|
59 |
-
|
60 |
-
By exercising the Licensed Rights (defined below), You accept and agree
|
61 |
-
to be bound by the terms and conditions of this Creative Commons
|
62 |
-
Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
63 |
-
("Public License"). To the extent this Public License may be
|
64 |
-
interpreted as a contract, You are granted the Licensed Rights in
|
65 |
-
consideration of Your acceptance of these terms and conditions, and the
|
66 |
-
Licensor grants You such rights in consideration of benefits the
|
67 |
-
Licensor receives from making the Licensed Material available under
|
68 |
-
these terms and conditions.
|
69 |
-
|
70 |
-
|
71 |
-
Section 1 -- Definitions.
|
72 |
-
|
73 |
-
a. Adapted Material means material subject to Copyright and Similar
|
74 |
-
Rights that is derived from or based upon the Licensed Material
|
75 |
-
and in which the Licensed Material is translated, altered,
|
76 |
-
arranged, transformed, or otherwise modified in a manner requiring
|
77 |
-
permission under the Copyright and Similar Rights held by the
|
78 |
-
Licensor. For purposes of this Public License, where the Licensed
|
79 |
-
Material is a musical work, performance, or sound recording,
|
80 |
-
Adapted Material is always produced where the Licensed Material is
|
81 |
-
synched in timed relation with a moving image.
|
82 |
-
|
83 |
-
b. Adapter's License means the license You apply to Your Copyright
|
84 |
-
and Similar Rights in Your contributions to Adapted Material in
|
85 |
-
accordance with the terms and conditions of this Public License.
|
86 |
-
|
87 |
-
c. BY-NC-SA Compatible License means a license listed at
|
88 |
-
creativecommons.org/compatiblelicenses, approved by Creative
|
89 |
-
Commons as essentially the equivalent of this Public License.
|
90 |
-
|
91 |
-
d. Copyright and Similar Rights means copyright and/or similar rights
|
92 |
-
closely related to copyright including, without limitation,
|
93 |
-
performance, broadcast, sound recording, and Sui Generis Database
|
94 |
-
Rights, without regard to how the rights are labeled or
|
95 |
-
categorized. For purposes of this Public License, the rights
|
96 |
-
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
97 |
-
Rights.
|
98 |
-
|
99 |
-
e. Effective Technological Measures means those measures that, in the
|
100 |
-
absence of proper authority, may not be circumvented under laws
|
101 |
-
fulfilling obligations under Article 11 of the WIPO Copyright
|
102 |
-
Treaty adopted on December 20, 1996, and/or similar international
|
103 |
-
agreements.
|
104 |
-
|
105 |
-
f. Exceptions and Limitations means fair use, fair dealing, and/or
|
106 |
-
any other exception or limitation to Copyright and Similar Rights
|
107 |
-
that applies to Your use of the Licensed Material.
|
108 |
-
|
109 |
-
g. License Elements means the license attributes listed in the name
|
110 |
-
of a Creative Commons Public License. The License Elements of this
|
111 |
-
Public License are Attribution, NonCommercial, and ShareAlike.
|
112 |
-
|
113 |
-
h. Licensed Material means the artistic or literary work, database,
|
114 |
-
or other material to which the Licensor applied this Public
|
115 |
-
License.
|
116 |
-
|
117 |
-
i. Licensed Rights means the rights granted to You subject to the
|
118 |
-
terms and conditions of this Public License, which are limited to
|
119 |
-
all Copyright and Similar Rights that apply to Your use of the
|
120 |
-
Licensed Material and that the Licensor has authority to license.
|
121 |
-
|
122 |
-
j. Licensor means the individual(s) or entity(ies) granting rights
|
123 |
-
under this Public License.
|
124 |
-
|
125 |
-
k. NonCommercial means not primarily intended for or directed towards
|
126 |
-
commercial advantage or monetary compensation. For purposes of
|
127 |
-
this Public License, the exchange of the Licensed Material for
|
128 |
-
other material subject to Copyright and Similar Rights by digital
|
129 |
-
file-sharing or similar means is NonCommercial provided there is
|
130 |
-
no payment of monetary compensation in connection with the
|
131 |
-
exchange.
|
132 |
-
|
133 |
-
l. Share means to provide material to the public by any means or
|
134 |
-
process that requires permission under the Licensed Rights, such
|
135 |
-
as reproduction, public display, public performance, distribution,
|
136 |
-
dissemination, communication, or importation, and to make material
|
137 |
-
available to the public including in ways that members of the
|
138 |
-
public may access the material from a place and at a time
|
139 |
-
individually chosen by them.
|
140 |
-
|
141 |
-
m. Sui Generis Database Rights means rights other than copyright
|
142 |
-
resulting from Directive 96/9/EC of the European Parliament and of
|
143 |
-
the Council of 11 March 1996 on the legal protection of databases,
|
144 |
-
as amended and/or succeeded, as well as other essentially
|
145 |
-
equivalent rights anywhere in the world.
|
146 |
-
|
147 |
-
n. You means the individual or entity exercising the Licensed Rights
|
148 |
-
under this Public License. Your has a corresponding meaning.
|
149 |
-
|
150 |
-
|
151 |
-
Section 2 -- Scope.
|
152 |
-
|
153 |
-
a. License grant.
|
154 |
-
|
155 |
-
1. Subject to the terms and conditions of this Public License,
|
156 |
-
the Licensor hereby grants You a worldwide, royalty-free,
|
157 |
-
non-sublicensable, non-exclusive, irrevocable license to
|
158 |
-
exercise the Licensed Rights in the Licensed Material to:
|
159 |
-
|
160 |
-
a. reproduce and Share the Licensed Material, in whole or
|
161 |
-
in part, for NonCommercial purposes only; and
|
162 |
-
|
163 |
-
b. produce, reproduce, and Share Adapted Material for
|
164 |
-
NonCommercial purposes only.
|
165 |
-
|
166 |
-
2. Exceptions and Limitations. For the avoidance of doubt, where
|
167 |
-
Exceptions and Limitations apply to Your use, this Public
|
168 |
-
License does not apply, and You do not need to comply with
|
169 |
-
its terms and conditions.
|
170 |
-
|
171 |
-
3. Term. The term of this Public License is specified in Section
|
172 |
-
6(a).
|
173 |
-
|
174 |
-
4. Media and formats; technical modifications allowed. The
|
175 |
-
Licensor authorizes You to exercise the Licensed Rights in
|
176 |
-
all media and formats whether now known or hereafter created,
|
177 |
-
and to make technical modifications necessary to do so. The
|
178 |
-
Licensor waives and/or agrees not to assert any right or
|
179 |
-
authority to forbid You from making technical modifications
|
180 |
-
necessary to exercise the Licensed Rights, including
|
181 |
-
technical modifications necessary to circumvent Effective
|
182 |
-
Technological Measures. For purposes of this Public License,
|
183 |
-
simply making modifications authorized by this Section 2(a)
|
184 |
-
(4) never produces Adapted Material.
|
185 |
-
|
186 |
-
5. Downstream recipients.
|
187 |
-
|
188 |
-
a. Offer from the Licensor -- Licensed Material. Every
|
189 |
-
recipient of the Licensed Material automatically
|
190 |
-
receives an offer from the Licensor to exercise the
|
191 |
-
Licensed Rights under the terms and conditions of this
|
192 |
-
Public License.
|
193 |
-
|
194 |
-
b. Additional offer from the Licensor -- Adapted Material.
|
195 |
-
Every recipient of Adapted Material from You
|
196 |
-
automatically receives an offer from the Licensor to
|
197 |
-
exercise the Licensed Rights in the Adapted Material
|
198 |
-
under the conditions of the Adapter's License You apply.
|
199 |
-
|
200 |
-
c. No downstream restrictions. You may not offer or impose
|
201 |
-
any additional or different terms or conditions on, or
|
202 |
-
apply any Effective Technological Measures to, the
|
203 |
-
Licensed Material if doing so restricts exercise of the
|
204 |
-
Licensed Rights by any recipient of the Licensed
|
205 |
-
Material.
|
206 |
-
|
207 |
-
6. No endorsement. Nothing in this Public License constitutes or
|
208 |
-
may be construed as permission to assert or imply that You
|
209 |
-
are, or that Your use of the Licensed Material is, connected
|
210 |
-
with, or sponsored, endorsed, or granted official status by,
|
211 |
-
the Licensor or others designated to receive attribution as
|
212 |
-
provided in Section 3(a)(1)(A)(i).
|
213 |
-
|
214 |
-
b. Other rights.
|
215 |
-
|
216 |
-
1. Moral rights, such as the right of integrity, are not
|
217 |
-
licensed under this Public License, nor are publicity,
|
218 |
-
privacy, and/or other similar personality rights; however, to
|
219 |
-
the extent possible, the Licensor waives and/or agrees not to
|
220 |
-
assert any such rights held by the Licensor to the limited
|
221 |
-
extent necessary to allow You to exercise the Licensed
|
222 |
-
Rights, but not otherwise.
|
223 |
-
|
224 |
-
2. Patent and trademark rights are not licensed under this
|
225 |
-
Public License.
|
226 |
-
|
227 |
-
3. To the extent possible, the Licensor waives any right to
|
228 |
-
collect royalties from You for the exercise of the Licensed
|
229 |
-
Rights, whether directly or through a collecting society
|
230 |
-
under any voluntary or waivable statutory or compulsory
|
231 |
-
licensing scheme. In all other cases the Licensor expressly
|
232 |
-
reserves any right to collect such royalties, including when
|
233 |
-
the Licensed Material is used other than for NonCommercial
|
234 |
-
purposes.
|
235 |
-
|
236 |
-
|
237 |
-
Section 3 -- License Conditions.
|
238 |
-
|
239 |
-
Your exercise of the Licensed Rights is expressly made subject to the
|
240 |
-
following conditions.
|
241 |
-
|
242 |
-
a. Attribution.
|
243 |
-
|
244 |
-
1. If You Share the Licensed Material (including in modified
|
245 |
-
form), You must:
|
246 |
-
|
247 |
-
a. retain the following if it is supplied by the Licensor
|
248 |
-
with the Licensed Material:
|
249 |
-
|
250 |
-
i. identification of the creator(s) of the Licensed
|
251 |
-
Material and any others designated to receive
|
252 |
-
attribution, in any reasonable manner requested by
|
253 |
-
the Licensor (including by pseudonym if
|
254 |
-
designated);
|
255 |
-
|
256 |
-
ii. a copyright notice;
|
257 |
-
|
258 |
-
iii. a notice that refers to this Public License;
|
259 |
-
|
260 |
-
iv. a notice that refers to the disclaimer of
|
261 |
-
warranties;
|
262 |
-
|
263 |
-
v. a URI or hyperlink to the Licensed Material to the
|
264 |
-
extent reasonably practicable;
|
265 |
-
|
266 |
-
b. indicate if You modified the Licensed Material and
|
267 |
-
retain an indication of any previous modifications; and
|
268 |
-
|
269 |
-
c. indicate the Licensed Material is licensed under this
|
270 |
-
Public License, and include the text of, or the URI or
|
271 |
-
hyperlink to, this Public License.
|
272 |
-
|
273 |
-
2. You may satisfy the conditions in Section 3(a)(1) in any
|
274 |
-
reasonable manner based on the medium, means, and context in
|
275 |
-
which You Share the Licensed Material. For example, it may be
|
276 |
-
reasonable to satisfy the conditions by providing a URI or
|
277 |
-
hyperlink to a resource that includes the required
|
278 |
-
information.
|
279 |
-
3. If requested by the Licensor, You must remove any of the
|
280 |
-
information required by Section 3(a)(1)(A) to the extent
|
281 |
-
reasonably practicable.
|
282 |
-
|
283 |
-
b. ShareAlike.
|
284 |
-
|
285 |
-
In addition to the conditions in Section 3(a), if You Share
|
286 |
-
Adapted Material You produce, the following conditions also apply.
|
287 |
-
|
288 |
-
1. The Adapter's License You apply must be a Creative Commons
|
289 |
-
license with the same License Elements, this version or
|
290 |
-
later, or a BY-NC-SA Compatible License.
|
291 |
-
|
292 |
-
2. You must include the text of, or the URI or hyperlink to, the
|
293 |
-
Adapter's License You apply. You may satisfy this condition
|
294 |
-
in any reasonable manner based on the medium, means, and
|
295 |
-
context in which You Share Adapted Material.
|
296 |
-
|
297 |
-
3. You may not offer or impose any additional or different terms
|
298 |
-
or conditions on, or apply any Effective Technological
|
299 |
-
Measures to, Adapted Material that restrict exercise of the
|
300 |
-
rights granted under the Adapter's License You apply.
|
301 |
-
|
302 |
-
|
303 |
-
Section 4 -- Sui Generis Database Rights.
|
304 |
-
|
305 |
-
Where the Licensed Rights include Sui Generis Database Rights that
|
306 |
-
apply to Your use of the Licensed Material:
|
307 |
-
|
308 |
-
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
309 |
-
to extract, reuse, reproduce, and Share all or a substantial
|
310 |
-
portion of the contents of the database for NonCommercial purposes
|
311 |
-
only;
|
312 |
-
|
313 |
-
b. if You include all or a substantial portion of the database
|
314 |
-
contents in a database in which You have Sui Generis Database
|
315 |
-
Rights, then the database in which You have Sui Generis Database
|
316 |
-
Rights (but not its individual contents) is Adapted Material,
|
317 |
-
including for purposes of Section 3(b); and
|
318 |
-
|
319 |
-
c. You must comply with the conditions in Section 3(a) if You Share
|
320 |
-
all or a substantial portion of the contents of the database.
|
321 |
-
|
322 |
-
For the avoidance of doubt, this Section 4 supplements and does not
|
323 |
-
replace Your obligations under this Public License where the Licensed
|
324 |
-
Rights include other Copyright and Similar Rights.
|
325 |
-
|
326 |
-
|
327 |
-
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
328 |
-
|
329 |
-
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
330 |
-
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
331 |
-
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
332 |
-
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
333 |
-
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
334 |
-
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
335 |
-
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
336 |
-
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
337 |
-
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
338 |
-
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
339 |
-
|
340 |
-
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
341 |
-
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
342 |
-
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
343 |
-
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
344 |
-
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
345 |
-
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
346 |
-
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
347 |
-
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
348 |
-
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
349 |
-
|
350 |
-
c. The disclaimer of warranties and limitation of liability provided
|
351 |
-
above shall be interpreted in a manner that, to the extent
|
352 |
-
possible, most closely approximates an absolute disclaimer and
|
353 |
-
waiver of all liability.
|
354 |
-
|
355 |
-
|
356 |
-
Section 6 -- Term and Termination.
|
357 |
-
|
358 |
-
a. This Public License applies for the term of the Copyright and
|
359 |
-
Similar Rights licensed here. However, if You fail to comply with
|
360 |
-
this Public License, then Your rights under this Public License
|
361 |
-
terminate automatically.
|
362 |
-
|
363 |
-
b. Where Your right to use the Licensed Material has terminated under
|
364 |
-
Section 6(a), it reinstates:
|
365 |
-
|
366 |
-
1. automatically as of the date the violation is cured, provided
|
367 |
-
it is cured within 30 days of Your discovery of the
|
368 |
-
violation; or
|
369 |
-
|
370 |
-
2. upon express reinstatement by the Licensor.
|
371 |
-
|
372 |
-
For the avoidance of doubt, this Section 6(b) does not affect any
|
373 |
-
right the Licensor may have to seek remedies for Your violations
|
374 |
-
of this Public License.
|
375 |
-
|
376 |
-
c. For the avoidance of doubt, the Licensor may also offer the
|
377 |
-
Licensed Material under separate terms or conditions or stop
|
378 |
-
distributing the Licensed Material at any time; however, doing so
|
379 |
-
will not terminate this Public License.
|
380 |
-
|
381 |
-
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
382 |
-
License.
|
383 |
-
|
384 |
-
|
385 |
-
Section 7 -- Other Terms and Conditions.
|
386 |
-
|
387 |
-
a. The Licensor shall not be bound by any additional or different
|
388 |
-
terms or conditions communicated by You unless expressly agreed.
|
389 |
-
|
390 |
-
b. Any arrangements, understandings, or agreements regarding the
|
391 |
-
Licensed Material not stated herein are separate from and
|
392 |
-
independent of the terms and conditions of this Public License.
|
393 |
-
|
394 |
-
|
395 |
-
Section 8 -- Interpretation.
|
396 |
-
|
397 |
-
a. For the avoidance of doubt, this Public License does not, and
|
398 |
-
shall not be interpreted to, reduce, limit, restrict, or impose
|
399 |
-
conditions on any use of the Licensed Material that could lawfully
|
400 |
-
be made without permission under this Public License.
|
401 |
-
|
402 |
-
b. To the extent possible, if any provision of this Public License is
|
403 |
-
deemed unenforceable, it shall be automatically reformed to the
|
404 |
-
minimum extent necessary to make it enforceable. If the provision
|
405 |
-
cannot be reformed, it shall be severed from this Public License
|
406 |
-
without affecting the enforceability of the remaining terms and
|
407 |
-
conditions.
|
408 |
-
|
409 |
-
c. No term or condition of this Public License will be waived and no
|
410 |
-
failure to comply consented to unless expressly agreed to by the
|
411 |
-
Licensor.
|
412 |
-
|
413 |
-
d. Nothing in this Public License constitutes or may be interpreted
|
414 |
-
as a limitation upon, or waiver of, any privileges and immunities
|
415 |
-
that apply to the Licensor or You, including from the legal
|
416 |
-
processes of any jurisdiction or authority.
|
417 |
-
|
418 |
-
=======================================================================
|
419 |
-
|
420 |
-
Creative Commons is not a party to its public
|
421 |
-
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
422 |
-
its public licenses to material it publishes and in those instances
|
423 |
-
will be considered the “Licensor.” The text of the Creative Commons
|
424 |
-
public licenses is dedicated to the public domain under the CC0 Public
|
425 |
-
Domain Dedication. Except for the limited purpose of indicating that
|
426 |
-
material is shared under a Creative Commons public license or as
|
427 |
-
otherwise permitted by the Creative Commons policies published at
|
428 |
-
creativecommons.org/policies, Creative Commons does not authorize the
|
429 |
-
use of the trademark "Creative Commons" or any other trademark or logo
|
430 |
-
of Creative Commons without its prior written consent including,
|
431 |
-
without limitation, in connection with any unauthorized modifications
|
432 |
-
to any of its public licenses or any other arrangements,
|
433 |
-
understandings, or agreements concerning use of licensed material. For
|
434 |
-
the avoidance of doubt, this paragraph does not form part of the
|
435 |
-
public licenses.
|
436 |
-
|
437 |
-
Creative Commons may be contacted at creativecommons.org.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/models/ImageBind/README.md
DELETED
@@ -1,155 +0,0 @@
|
|
1 |
-
# ImageBind: One Embedding Space To Bind Them All
|
2 |
-
|
3 |
-
**[FAIR, Meta AI](https://ai.facebook.com/research/)**
|
4 |
-
|
5 |
-
Rohit Girdhar*,
|
6 |
-
Alaaeldin El-Nouby*,
|
7 |
-
Zhuang Liu,
|
8 |
-
Mannat Singh,
|
9 |
-
Kalyan Vasudev Alwala,
|
10 |
-
Armand Joulin,
|
11 |
-
Ishan Misra*
|
12 |
-
|
13 |
-
To appear at CVPR 2023 (*Highlighted paper*)
|
14 |
-
|
15 |
-
[[`Paper`](https://facebookresearch.github.io/ImageBind/paper)] [[`Blog`](https://ai.facebook.com/blog/imagebind-six-modalities-binding-ai/)] [[`Demo`](https://imagebind.metademolab.com/)] [[`Supplementary Video`](https://dl.fbaipublicfiles.com/imagebind/imagebind_video.mp4)] [[`BibTex`](#citing-imagebind)]
|
16 |
-
|
17 |
-
PyTorch implementation and pretrained models for ImageBind. For details, see the paper: **[ImageBind: One Embedding Space To Bind Them All](https://facebookresearch.github.io/ImageBind/paper)**.
|
18 |
-
|
19 |
-
ImageBind learns a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data. It enables novel emergent applications ‘out-of-the-box’ including cross-modal retrieval, composing modalities with arithmetic, cross-modal detection and generation.
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-

|
24 |
-
|
25 |
-
## ImageBind model
|
26 |
-
|
27 |
-
Emergent zero-shot classification performance.
|
28 |
-
|
29 |
-
<table style="margin: auto">
|
30 |
-
<tr>
|
31 |
-
<th>Model</th>
|
32 |
-
<th><span style="color:blue">IN1k</span></th>
|
33 |
-
<th><span style="color:purple">K400</span></th>
|
34 |
-
<th><span style="color:green">NYU-D</span></th>
|
35 |
-
<th><span style="color:LightBlue">ESC</span></th>
|
36 |
-
<th><span style="color:orange">LLVIP</span></th>
|
37 |
-
<th><span style="color:purple">Ego4D</span></th>
|
38 |
-
<th>download</th>
|
39 |
-
</tr>
|
40 |
-
<tr>
|
41 |
-
<td>imagebind_huge</td>
|
42 |
-
<td align="right">77.7</td>
|
43 |
-
<td align="right">50.0</td>
|
44 |
-
<td align="right">54.0</td>
|
45 |
-
<td align="right">66.9</td>
|
46 |
-
<td align="right">63.4</td>
|
47 |
-
<td align="right">25.0</td>
|
48 |
-
<td><a href="https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth">checkpoint</a></td>
|
49 |
-
</tr>
|
50 |
-
|
51 |
-
</table>
|
52 |
-
|
53 |
-
## Usage
|
54 |
-
|
55 |
-
Install pytorch 1.13+ and other 3rd party dependencies.
|
56 |
-
|
57 |
-
```shell
|
58 |
-
conda create --name imagebind python=3.8 -y
|
59 |
-
conda activate imagebind
|
60 |
-
|
61 |
-
pip install -r requirements.txt
|
62 |
-
```
|
63 |
-
|
64 |
-
For windows users, you might need to install `soundfile` for reading/writing audio files. (Thanks @congyue1977)
|
65 |
-
|
66 |
-
```
|
67 |
-
pip install soundfile
|
68 |
-
```
|
69 |
-
|
70 |
-
|
71 |
-
Extract and compare features across modalities (e.g. Image, Text and Audio).
|
72 |
-
|
73 |
-
```python
|
74 |
-
import data
|
75 |
-
import torch
|
76 |
-
from models import imagebind_model
|
77 |
-
from models.imagebind_model import ModalityType
|
78 |
-
|
79 |
-
text_list=["A dog.", "A car", "A bird"]
|
80 |
-
image_paths=[".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]
|
81 |
-
audio_paths=[".assets/dog_audio.wav", ".assets/car_audio.wav", ".assets/bird_audio.wav"]
|
82 |
-
|
83 |
-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
84 |
-
|
85 |
-
# Instantiate model
|
86 |
-
model = imagebind_model.imagebind_huge(pretrained=True)
|
87 |
-
model.eval()
|
88 |
-
model.to(device)
|
89 |
-
|
90 |
-
# Load data
|
91 |
-
inputs = {
|
92 |
-
ModalityType.TEXT: data.load_and_transform_text(text_list, device),
|
93 |
-
ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
|
94 |
-
ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
|
95 |
-
}
|
96 |
-
|
97 |
-
with torch.no_grad():
|
98 |
-
embeddings = model(inputs)
|
99 |
-
|
100 |
-
print(
|
101 |
-
"Vision x Text: ",
|
102 |
-
torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1),
|
103 |
-
)
|
104 |
-
print(
|
105 |
-
"Audio x Text: ",
|
106 |
-
torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1),
|
107 |
-
)
|
108 |
-
print(
|
109 |
-
"Vision x Audio: ",
|
110 |
-
torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1),
|
111 |
-
)
|
112 |
-
|
113 |
-
# Expected output:
|
114 |
-
#
|
115 |
-
# Vision x Text:
|
116 |
-
# tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05],
|
117 |
-
# [3.3836e-05, 9.9994e-01, 2.4118e-05],
|
118 |
-
# [4.7997e-05, 1.3496e-02, 9.8646e-01]])
|
119 |
-
#
|
120 |
-
# Audio x Text:
|
121 |
-
# tensor([[1., 0., 0.],
|
122 |
-
# [0., 1., 0.],
|
123 |
-
# [0., 0., 1.]])
|
124 |
-
#
|
125 |
-
# Vision x Audio:
|
126 |
-
# tensor([[0.8070, 0.1088, 0.0842],
|
127 |
-
# [0.1036, 0.7884, 0.1079],
|
128 |
-
# [0.0018, 0.0022, 0.9960]])
|
129 |
-
|
130 |
-
```
|
131 |
-
|
132 |
-
## Model card
|
133 |
-
Please see the [model card](model_card.md) for details.
|
134 |
-
|
135 |
-
## License
|
136 |
-
|
137 |
-
ImageBind code and model weights are released under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for additional details.
|
138 |
-
|
139 |
-
## Contributing
|
140 |
-
|
141 |
-
See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
|
142 |
-
|
143 |
-
## Citing ImageBind
|
144 |
-
|
145 |
-
If you find this repository useful, please consider giving a star :star: and citation
|
146 |
-
|
147 |
-
```
|
148 |
-
@inproceedings{girdhar2023imagebind,
|
149 |
-
title={ImageBind: One Embedding Space To Bind Them All},
|
150 |
-
author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang
|
151 |
-
and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan},
|
152 |
-
booktitle={CVPR},
|
153 |
-
year={2023}
|
154 |
-
}
|
155 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/models/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
-
size 1356917
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/models/ImageBind/data.py
DELETED
@@ -1,338 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
|
3 |
-
# All rights reserved.
|
4 |
-
|
5 |
-
# This source code is licensed under the license found in the
|
6 |
-
# LICENSE file in the root directory of this source tree.
|
7 |
-
|
8 |
-
import logging
|
9 |
-
import math
|
10 |
-
|
11 |
-
import torch
|
12 |
-
import torch.nn as nn
|
13 |
-
import torchaudio
|
14 |
-
from PIL import Image
|
15 |
-
from pytorchvideo import transforms as pv_transforms
|
16 |
-
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
|
17 |
-
from pytorchvideo.data.encoded_video import EncodedVideo
|
18 |
-
from torchvision import transforms
|
19 |
-
from torchvision.transforms._transforms_video import NormalizeVideo
|
20 |
-
|
21 |
-
from .models.multimodal_preprocessors import SimpleTokenizer
|
22 |
-
|
23 |
-
DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
|
24 |
-
|
25 |
-
BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
|
26 |
-
|
27 |
-
|
28 |
-
def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
|
29 |
-
# Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
|
30 |
-
waveform -= waveform.mean()
|
31 |
-
fbank = torchaudio.compliance.kaldi.fbank(
|
32 |
-
waveform,
|
33 |
-
htk_compat=True,
|
34 |
-
sample_frequency=sample_rate,
|
35 |
-
use_energy=False,
|
36 |
-
window_type="hanning",
|
37 |
-
num_mel_bins=num_mel_bins,
|
38 |
-
dither=0.0,
|
39 |
-
frame_length=25,
|
40 |
-
frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
|
41 |
-
)
|
42 |
-
# Convert to [mel_bins, num_frames] shape
|
43 |
-
fbank = fbank.transpose(0, 1)
|
44 |
-
# Pad to target_length
|
45 |
-
n_frames = fbank.size(1)
|
46 |
-
p = target_length - n_frames
|
47 |
-
# if p is too large (say >20%), flash a warning
|
48 |
-
if abs(p) / n_frames > 0.2:
|
49 |
-
logging.warning(
|
50 |
-
"Large gap between audio n_frames(%d) and "
|
51 |
-
"target_length (%d). Is the audio_target_length "
|
52 |
-
"setting correct?",
|
53 |
-
n_frames,
|
54 |
-
target_length,
|
55 |
-
)
|
56 |
-
# cut and pad
|
57 |
-
if p > 0:
|
58 |
-
fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
|
59 |
-
elif p < 0:
|
60 |
-
fbank = fbank[:, 0:target_length]
|
61 |
-
# Convert to [1, mel_bins, num_frames] shape, essentially like a 1
|
62 |
-
# channel image
|
63 |
-
fbank = fbank.unsqueeze(0)
|
64 |
-
return fbank
|
65 |
-
|
66 |
-
|
67 |
-
def get_clip_timepoints(clip_sampler, duration):
|
68 |
-
# Read out all clips in this video
|
69 |
-
all_clips_timepoints = []
|
70 |
-
is_last_clip = False
|
71 |
-
end = 0.0
|
72 |
-
while not is_last_clip:
|
73 |
-
start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
|
74 |
-
all_clips_timepoints.append((start, end))
|
75 |
-
return all_clips_timepoints
|
76 |
-
|
77 |
-
|
78 |
-
def load_and_transform_vision_data(image_paths, device):
|
79 |
-
if image_paths is None:
|
80 |
-
return None
|
81 |
-
|
82 |
-
image_ouputs = []
|
83 |
-
for image_path in image_paths:
|
84 |
-
data_transform = transforms.Compose(
|
85 |
-
[
|
86 |
-
transforms.Resize(
|
87 |
-
224, interpolation=transforms.InterpolationMode.BICUBIC
|
88 |
-
),
|
89 |
-
transforms.CenterCrop(224),
|
90 |
-
transforms.ToTensor(),
|
91 |
-
transforms.Normalize(
|
92 |
-
mean=(0.48145466, 0.4578275, 0.40821073),
|
93 |
-
std=(0.26862954, 0.26130258, 0.27577711),
|
94 |
-
),
|
95 |
-
]
|
96 |
-
)
|
97 |
-
with open(image_path, "rb") as fopen:
|
98 |
-
image = Image.open(fopen).convert("RGB")
|
99 |
-
|
100 |
-
image = data_transform(image).to(device)
|
101 |
-
image_ouputs.append(image)
|
102 |
-
return torch.stack(image_ouputs, dim=0)
|
103 |
-
|
104 |
-
|
105 |
-
def load_and_transform_text(text, device):
|
106 |
-
if text is None:
|
107 |
-
return None
|
108 |
-
tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
|
109 |
-
tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
|
110 |
-
tokens = torch.cat(tokens, dim=0)
|
111 |
-
return tokens
|
112 |
-
|
113 |
-
|
114 |
-
def load_and_transform_audio_data(
|
115 |
-
audio_paths,
|
116 |
-
device,
|
117 |
-
num_mel_bins=128,
|
118 |
-
target_length=204,
|
119 |
-
sample_rate=16000,
|
120 |
-
clip_duration=2,
|
121 |
-
clips_per_video=3,
|
122 |
-
mean=-4.268,
|
123 |
-
std=9.138,
|
124 |
-
):
|
125 |
-
if audio_paths is None:
|
126 |
-
return None
|
127 |
-
|
128 |
-
audio_outputs = []
|
129 |
-
clip_sampler = ConstantClipsPerVideoSampler(
|
130 |
-
clip_duration=clip_duration, clips_per_video=clips_per_video
|
131 |
-
)
|
132 |
-
|
133 |
-
for audio_path in audio_paths:
|
134 |
-
waveform, sr = torchaudio.load(audio_path)
|
135 |
-
if sample_rate != sr:
|
136 |
-
waveform = torchaudio.functional.resample(
|
137 |
-
waveform, orig_freq=sr, new_freq=sample_rate
|
138 |
-
)
|
139 |
-
all_clips_timepoints = get_clip_timepoints(
|
140 |
-
clip_sampler, waveform.size(1) / sample_rate
|
141 |
-
)
|
142 |
-
all_clips = []
|
143 |
-
for clip_timepoints in all_clips_timepoints:
|
144 |
-
waveform_clip = waveform[
|
145 |
-
:,
|
146 |
-
int(clip_timepoints[0] * sample_rate) : int(
|
147 |
-
clip_timepoints[1] * sample_rate
|
148 |
-
),
|
149 |
-
]
|
150 |
-
waveform_melspec = waveform2melspec(
|
151 |
-
waveform_clip, sample_rate, num_mel_bins, target_length
|
152 |
-
)
|
153 |
-
all_clips.append(waveform_melspec)
|
154 |
-
|
155 |
-
normalize = transforms.Normalize(mean=mean, std=std)
|
156 |
-
all_clips = [normalize(ac).to(device) for ac in all_clips]
|
157 |
-
|
158 |
-
all_clips = torch.stack(all_clips, dim=0)
|
159 |
-
audio_outputs.append(all_clips)
|
160 |
-
|
161 |
-
return torch.stack(audio_outputs, dim=0)
|
162 |
-
|
163 |
-
|
164 |
-
def crop_boxes(boxes, x_offset, y_offset):
|
165 |
-
"""
|
166 |
-
Perform crop on the bounding boxes given the offsets.
|
167 |
-
Args:
|
168 |
-
boxes (ndarray or None): bounding boxes to perform crop. The dimension
|
169 |
-
is `num boxes` x 4.
|
170 |
-
x_offset (int): cropping offset in the x axis.
|
171 |
-
y_offset (int): cropping offset in the y axis.
|
172 |
-
Returns:
|
173 |
-
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
174 |
-
`num boxes` x 4.
|
175 |
-
"""
|
176 |
-
cropped_boxes = boxes.copy()
|
177 |
-
cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
|
178 |
-
cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
|
179 |
-
|
180 |
-
return cropped_boxes
|
181 |
-
|
182 |
-
|
183 |
-
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
|
184 |
-
"""
|
185 |
-
Perform uniform spatial sampling on the images and corresponding boxes.
|
186 |
-
Args:
|
187 |
-
images (tensor): images to perform uniform crop. The dimension is
|
188 |
-
`num frames` x `channel` x `height` x `width`.
|
189 |
-
size (int): size of height and weight to crop the images.
|
190 |
-
spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
|
191 |
-
is larger than height. Or 0, 1, or 2 for top, center, and bottom
|
192 |
-
crop if height is larger than width.
|
193 |
-
boxes (ndarray or None): optional. Corresponding boxes to images.
|
194 |
-
Dimension is `num boxes` x 4.
|
195 |
-
scale_size (int): optinal. If not None, resize the images to scale_size before
|
196 |
-
performing any crop.
|
197 |
-
Returns:
|
198 |
-
cropped (tensor): images with dimension of
|
199 |
-
`num frames` x `channel` x `size` x `size`.
|
200 |
-
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
201 |
-
`num boxes` x 4.
|
202 |
-
"""
|
203 |
-
assert spatial_idx in [0, 1, 2]
|
204 |
-
ndim = len(images.shape)
|
205 |
-
if ndim == 3:
|
206 |
-
images = images.unsqueeze(0)
|
207 |
-
height = images.shape[2]
|
208 |
-
width = images.shape[3]
|
209 |
-
|
210 |
-
if scale_size is not None:
|
211 |
-
if width <= height:
|
212 |
-
width, height = scale_size, int(height / width * scale_size)
|
213 |
-
else:
|
214 |
-
width, height = int(width / height * scale_size), scale_size
|
215 |
-
images = torch.nn.functional.interpolate(
|
216 |
-
images,
|
217 |
-
size=(height, width),
|
218 |
-
mode="bilinear",
|
219 |
-
align_corners=False,
|
220 |
-
)
|
221 |
-
|
222 |
-
y_offset = int(math.ceil((height - size) / 2))
|
223 |
-
x_offset = int(math.ceil((width - size) / 2))
|
224 |
-
|
225 |
-
if height > width:
|
226 |
-
if spatial_idx == 0:
|
227 |
-
y_offset = 0
|
228 |
-
elif spatial_idx == 2:
|
229 |
-
y_offset = height - size
|
230 |
-
else:
|
231 |
-
if spatial_idx == 0:
|
232 |
-
x_offset = 0
|
233 |
-
elif spatial_idx == 2:
|
234 |
-
x_offset = width - size
|
235 |
-
cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
|
236 |
-
cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
237 |
-
if ndim == 3:
|
238 |
-
cropped = cropped.squeeze(0)
|
239 |
-
return cropped, cropped_boxes
|
240 |
-
|
241 |
-
|
242 |
-
class SpatialCrop(nn.Module):
|
243 |
-
"""
|
244 |
-
Convert the video into 3 smaller clips spatially. Must be used after the
|
245 |
-
temporal crops to get spatial crops, and should be used with
|
246 |
-
-2 in the spatial crop at the slowfast augmentation stage (so full
|
247 |
-
frames are passed in here). Will return a larger list with the
|
248 |
-
3x spatial crops as well.
|
249 |
-
"""
|
250 |
-
|
251 |
-
def __init__(self, crop_size: int = 224, num_crops: int = 3):
|
252 |
-
super().__init__()
|
253 |
-
self.crop_size = crop_size
|
254 |
-
if num_crops == 3:
|
255 |
-
self.crops_to_ext = [0, 1, 2]
|
256 |
-
self.flipped_crops_to_ext = []
|
257 |
-
elif num_crops == 1:
|
258 |
-
self.crops_to_ext = [1]
|
259 |
-
self.flipped_crops_to_ext = []
|
260 |
-
else:
|
261 |
-
raise NotImplementedError("Nothing else supported yet")
|
262 |
-
|
263 |
-
def forward(self, videos):
|
264 |
-
"""
|
265 |
-
Args:
|
266 |
-
videos: A list of C, T, H, W videos.
|
267 |
-
Returns:
|
268 |
-
videos: A list with 3x the number of elements. Each video converted
|
269 |
-
to C, T, H', W' by spatial cropping.
|
270 |
-
"""
|
271 |
-
assert isinstance(videos, list), "Must be a list of videos after temporal crops"
|
272 |
-
assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
|
273 |
-
res = []
|
274 |
-
for video in videos:
|
275 |
-
for spatial_idx in self.crops_to_ext:
|
276 |
-
res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
|
277 |
-
if not self.flipped_crops_to_ext:
|
278 |
-
continue
|
279 |
-
flipped_video = transforms.functional.hflip(video)
|
280 |
-
for spatial_idx in self.flipped_crops_to_ext:
|
281 |
-
res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
|
282 |
-
return res
|
283 |
-
|
284 |
-
|
285 |
-
def load_and_transform_video_data(
|
286 |
-
video_paths,
|
287 |
-
device,
|
288 |
-
clip_duration=2,
|
289 |
-
clips_per_video=5,
|
290 |
-
sample_rate=16000,
|
291 |
-
):
|
292 |
-
if video_paths is None:
|
293 |
-
return None
|
294 |
-
|
295 |
-
video_outputs = []
|
296 |
-
video_transform = transforms.Compose(
|
297 |
-
[
|
298 |
-
pv_transforms.ShortSideScale(224),
|
299 |
-
NormalizeVideo(
|
300 |
-
mean=(0.48145466, 0.4578275, 0.40821073),
|
301 |
-
std=(0.26862954, 0.26130258, 0.27577711),
|
302 |
-
),
|
303 |
-
]
|
304 |
-
)
|
305 |
-
|
306 |
-
clip_sampler = ConstantClipsPerVideoSampler(
|
307 |
-
clip_duration=clip_duration, clips_per_video=clips_per_video
|
308 |
-
)
|
309 |
-
frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
|
310 |
-
|
311 |
-
for video_path in video_paths:
|
312 |
-
video = EncodedVideo.from_path(
|
313 |
-
video_path,
|
314 |
-
decoder="decord",
|
315 |
-
decode_audio=False,
|
316 |
-
**{"sample_rate": sample_rate},
|
317 |
-
)
|
318 |
-
|
319 |
-
all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
|
320 |
-
|
321 |
-
all_video = []
|
322 |
-
for clip_timepoints in all_clips_timepoints:
|
323 |
-
# Read the clip, get frames
|
324 |
-
clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
|
325 |
-
if clip is None:
|
326 |
-
raise ValueError("No clip found")
|
327 |
-
video_clip = frame_sampler(clip["video"])
|
328 |
-
video_clip = video_clip / 255.0 # since this is float, need 0-1
|
329 |
-
|
330 |
-
all_video.append(video_clip)
|
331 |
-
|
332 |
-
all_video = [video_transform(clip) for clip in all_video]
|
333 |
-
all_video = SpatialCrop(224, num_crops=3)(all_video)
|
334 |
-
|
335 |
-
all_video = torch.stack(all_video, dim=0)
|
336 |
-
video_outputs.append(all_video)
|
337 |
-
|
338 |
-
return torch.stack(video_outputs, dim=0).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sonique/Video_LLaMA/video_llama/models/ImageBind/model_card.md
DELETED
@@ -1,94 +0,0 @@
|
|
1 |
-
# Model Card for ImageBind
|
2 |
-
|
3 |
-
Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images.
|
4 |
-
Input any of the six modalities and get the same sized embedding that can be used for cross-modal and multimodal tasks.
|
5 |
-
|
6 |
-
# Model Details
|
7 |
-
|
8 |
-
## Model Description
|
9 |
-
|
10 |
-
<!-- Provide a longer summary of what this model is/does. -->
|
11 |
-
Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images
|
12 |
-
|
13 |
-
- **Developed by:** Meta AI
|
14 |
-
- **Model type:** Multimodal model
|
15 |
-
- **Language(s) (NLP):** en
|
16 |
-
- **License:** CC BY-NC-SA 4.0
|
17 |
-
- **Resources for more information:**
|
18 |
-
- [GitHub Repo](https://github.com/facebookresearch/ImageBind)
|
19 |
-
|
20 |
-
|
21 |
-
# Uses
|
22 |
-
|
23 |
-
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
24 |
-
This model is intended only for research purposes. It provides a joint embedding space for different modalities -- image/video, text, audio, depth, IMU and thermal images.
|
25 |
-
We hope that these joint embeddings can be used for a variety of different cross-modal research, e.g., cross-modal retrieval and combining embeddings from different modalities.
|
26 |
-
|
27 |
-
## Out-of-Scope Use
|
28 |
-
|
29 |
-
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
30 |
-
<!-- If the user enters content, print that. If not, but they enter a task in the list, use that. If neither, say "more info needed." -->
|
31 |
-
|
32 |
-
This model is *NOT* intended to be used in any real world application -- commercial or otherwise.
|
33 |
-
It may produce harmful associations with different inputs.
|
34 |
-
The model needs to be investigated and likely re-trained on specific data for any such application.
|
35 |
-
The model is expected to work better on web-based visual data since it was trained on such data.
|
36 |
-
The text encoder is likely to work only on English language text because of the underlying training datasets.
|
37 |
-
|
38 |
-
# Bias, Risks, and Limitations
|
39 |
-
|
40 |
-
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
41 |
-
Open-domain joint embedding models are prone to producing specific biases, e.g., study from [CLIP](https://github.com/openai/CLIP/blob/main/model-card.md#bias-and-fairness).
|
42 |
-
Since our model uses such models as initialization, it will exhibit such biases too.
|
43 |
-
Moreover, for learning joint embeddings for other modalities such as audio, thermal, depth, and IMU we leverage datasets that are relatively small. These joint embeddings are thus limited to the concepts present in the datasets. For example, the thermal datasets we used are limited to outdoor street scenes, while the depth datasets are limited to indoor scenes.
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
# Training Details
|
48 |
-
|
49 |
-
## Training Data
|
50 |
-
|
51 |
-
<!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
52 |
-
|
53 |
-
ImageBind uses image-paired data for training -- (image, X) where X is one of text, audio, depth, IMU or thermal data.
|
54 |
-
In particular, we initialize and freeze the image and text encoders using an OpenCLIP ViT-H encoder.
|
55 |
-
We train audio embeddings using Audioset, depth embeddings using the SUN RGB-D dataset, IMU using the Ego4D dataset and thermal embeddings using the LLVIP dataset.
|
56 |
-
We provide the exact training data details in the paper.
|
57 |
-
|
58 |
-
|
59 |
-
## Training Procedure
|
60 |
-
|
61 |
-
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
62 |
-
Please refer to the research paper and github repo for exact details on this.
|
63 |
-
|
64 |
-
# Evaluation
|
65 |
-
|
66 |
-
## Testing Data, Factors & Metrics
|
67 |
-
|
68 |
-
We evaluate the model on a variety of different classification benchmarks for each modality.
|
69 |
-
The evaluation details are presented in the paper.
|
70 |
-
The models performance is measured using standard classification metrics such as accuracy and mAP.
|
71 |
-
|
72 |
-
# Citation
|
73 |
-
|
74 |
-
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
75 |
-
|
76 |
-
**BibTeX:**
|
77 |
-
```
|
78 |
-
@inproceedings{girdhar2023imagebind,
|
79 |
-
title={ImageBind: One Embedding Space To Bind Them All},
|
80 |
-
author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang
|
81 |
-
and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan},
|
82 |
-
booktitle={CVPR},
|
83 |
-
year={2023}
|
84 |
-
}
|
85 |
-
```
|
86 |
-
|
87 |
-
|
88 |
-
# Model Card Contact
|
89 |
-
|
90 |
-
Please reach out to the authors at: [email protected] [email protected] [email protected]
|
91 |
-
|
92 |
-
# How to Get Started with the Model
|
93 |
-
|
94 |
-
Our github repo provides a simple example to extract embeddings from images, audio etc.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|