import copy
import json
import logging
import os
from typing import Literal

import cv2
import numpy as np
import torch
import torchvision.transforms as T
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from decord import VideoReader, cpu
from mmengine import print_log
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import InterpolationMode

from xtuner.dataset.huggingface import build_origin_dataset
from xtuner.registry import BUILDER

from .encode_fn import video_lisa_encode_fn
def _get_rawvideo_dec(video_path, select_frames=5):
    """Decode a video with decord and return `select_frames` RGB PIL frames."""
    if os.path.exists(video_path):
        vreader = VideoReader(video_path, ctx=cpu(0))
    elif os.path.exists(video_path.replace('mkv', 'mp4')):
        # Some annotations reference .mkv files that exist only as .mp4.
        vreader = VideoReader(video_path.replace('mkv', 'mp4'), ctx=cpu(0))
    else:
        raise FileNotFoundError(video_path)

    fps = vreader.get_avg_fps()
    f_start = 0
    f_end = len(vreader) - 1
    num_frames = f_end - f_start + 1
    assert num_frames > 0, \
        f'num_frames: {num_frames}, f_start: {f_start}, f_end: {f_end}, ' \
        f'fps: {fps}, video_path: {video_path}'

    if num_frames <= select_frames:
        # Short video: keep every frame.
        sample_pos = list(range(f_start, f_end + 1))
    else:
        # Split the video into `select_frames` equal segments and draw one
        # random frame from each, so coverage is uniform but not deterministic.
        split_point = np.linspace(0, num_frames, num=select_frames + 1, dtype=int)
        sample_pos = [
            np.random.randint(split_point[i], split_point[i + 1])
            for i in range(select_frames)
        ]
    patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
    return patch_images
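
# A minimal usage sketch for the sampler above (the path is a placeholder,
# not part of the dataset):
#
#     frames = _get_rawvideo_dec('/data/videos/example.mp4', select_frames=10)
#     assert all(isinstance(f, Image.Image) for f in frames)
#     # len(frames) == 10 when the video has more than 10 frames; otherwise
#     # one PIL image per decoded frame.
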
class VideoChatUniViDataset(Dataset):
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
    FAST_IMG_START_TOKEN = '<fast_img>'
    FAST_IMG_END_TOKEN = '</fast_img>'
    def __init__(self,
                 image_folder,
                 json_file,
                 extra_image_processor=None,
                 tokenizer=None,
                 sampled_frames=10,
                 offline_processed_text_folder=None,
                 template_map_fn=None,
                 max_length=2048,
                 lazy=True,
                 repeats=1,
                 special_tokens=None,
                 use_fast=False,
                 n_fast_images=50,
                 fast_pool_size=4,
                 arch_type: Literal['intern_vl', 'qwen', 'llava'] = 'intern_vl',
                 preprocessor=None,
                 ):
        assert lazy is True, 'Only lazy mode is supported.'
        self.tokenizer = BUILDER.build(tokenizer)
        self.sampled_frames = sampled_frames
        assert offline_processed_text_folder or (json_file and tokenizer)
        self.lazy = lazy

        self.max_length = max_length

        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            # Build the map fn from its config dict.
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)

        if offline_processed_text_folder and json_file:
            print_log(
                'Both `offline_processed_text_folder` and `json_file` are '
                'set, and we load dataset from `offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            raise NotImplementedError
        else:
            json_datas = self.json_file_preprocess(json_file)
            self.json_datas = json_datas
            json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
            if self.lazy:
                self.text_data = build_origin_dataset(json_data, 'train')
            else:
                raise NotImplementedError
        self.image_folder = image_folder
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)

        self.arch_type = arch_type
        if self.arch_type == 'qwen':
            self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
            self.IMG_START_TOKEN = '<|vision_start|>'
            self.IMG_END_TOKEN = '<|vision_end|>'
        elif self.arch_type == 'llava':
            self.IMG_CONTEXT_TOKEN = '<image>'
            self.IMG_START_TOKEN = ''
            self.IMG_END_TOKEN = ''

        self.repeats = repeats
        self._system = ''

        self.downsample_ratio = 0.5
        if self.arch_type == 'llava':
            self.downsample_ratio = 1
        self.image_size = 448
        if self.arch_type == 'llava':
            self.image_size = 336
        patch_size = 14
        self.patch_token = int(
            (self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
        if self.arch_type == 'qwen':
            # For qwen the real token count is derived from `image_grid_thw`
            # in __getitem__, so only a single placeholder token is emitted here.
            self.patch_token = 1

        if preprocessor is None:
            self.transformer = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Resize((self.image_size, self.image_size),
                         interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
            ])
            self.preprocessor = None
        else:
            self.transformer = None
            self.preprocessor = BUILDER.build(preprocessor)

        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.use_fast = use_fast
        self.n_fast_images = n_fast_images
        self.fast_pool_size = fast_pool_size

        # For visualization debugging.
        self.save_folder = './work_dirs/video_debug/'
        self.cur_number = 0

        print('Video Chat dataset, include {} items.'.format(len(self.text_data)))
    def __len__(self):
        return len(self.text_data) * self.repeats

    def modality_length(self):
        # Frame counts are unknown until decode time, so every sample reports
        # a constant proxy length.
        return [10000] * len(self.text_data)

    def real_len(self):
        return len(self.text_data)
    def json_file_preprocess(self, json_file):
        # Prepare expression annotation files.
        with open(json_file, 'r') as f:
            json_datas = json.load(f)
        return json_datas
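
    # `json_file` is expected to hold LLaVA-style records. A minimal sketch of
    # one entry (file name and text are illustrative only):
    #
    # [
    #     {
    #         "video": "clips/0001.mp4",
    #         "conversations": [
    #             {"from": "human", "value": "<image>\nWhat happens in the video?"},
    #             {"from": "gpt", "value": "A person opens a door and walks out."}
    #         ]
    #     }
    # ]
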
    def dataset_map_fn(self, data_dict, select_k=5):
        assert 'video' in data_dict
        # Decode the video into sampled frames.
        video_file = os.path.join(self.image_folder, data_dict['video'])
        images = _get_rawvideo_dec(video_file, select_frames=select_k)
        if self.use_fast:
            # The fast branch samples more frames at a lower token budget.
            fast_images = _get_rawvideo_dec(video_file, select_frames=self.n_fast_images)
        else:
            fast_images = None

        conversation = data_dict['conversations']

        # Prepare the text with image placeholder tokens.
        if self.use_fast:
            text_dict = self.prepare_text(
                select_k, conversation,
                num_image_tokens=self.patch_token,
                n_fast_images=len(fast_images),
            )
        else:
            text_dict = self.prepare_text(
                select_k, conversation,
                num_image_tokens=self.patch_token,
            )

        return {
            'images': images,
            'conversation': text_dict['conversation'],
            'fast_images': fast_images,
        }
    def prepare_text(self, n_frames, conversation, num_image_tokens=256, n_fast_images=0):
        if self.use_fast:
            fast_frame_token_str = (
                f'{self.FAST_IMG_START_TOKEN}'
                f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}'
                f'{self.FAST_IMG_END_TOKEN}' + '\n'
            )
        else:
            fast_frame_token_str = ''

        frame_token_str = (
            f'{self.IMG_START_TOKEN}'
            f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}'
            f'{self.IMG_END_TOKEN}'
        )

        questions = []
        answers = []
        for conv in conversation:
            if conv['from'] == 'human':
                # Strip the raw `<image>` placeholder; frame tokens are
                # inserted explicitly below.
                questions.append(conv['value'].replace('<image>', ''))
            else:
                answers.append(conv['value'])
        assert len(questions) == len(answers)

        qa_list = []
        for i, (question, answer) in enumerate(zip(questions, answers)):
            if i == 0:
                # Prepend one frame-token block per sampled frame (plus the
                # optional fast-branch block) to the first question only.
                frame_tokens = frame_token_str + '\n'
                frame_tokens = frame_tokens * n_frames
                frame_tokens = frame_tokens.strip()
                frame_tokens = fast_frame_token_str + frame_tokens
                qa_list.append({'from': 'human', 'value': frame_tokens + question})
            else:
                qa_list.append({'from': 'human', 'value': question})
            qa_list.append({'from': 'gpt', 'value': answer})

        input_text = ''
        conversation = []
        for msg in qa_list:
            if msg['from'] == 'human':
                input_text += msg['value']
            elif msg['from'] == 'gpt':
                conversation.append({'input': input_text, 'output': msg['value']})
                input_text = ''
            else:
                raise NotImplementedError

        # Add system information.
        conversation[0].update({'system': self._system})
        return {'conversation': conversation}
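
    # Sketch of `prepare_text` output for a single human/gpt turn with two
    # frames (intern_vl tokens, num_image_tokens shortened to 2 for brevity):
    #
    # {'conversation': [{
    #     'system': '',
    #     'input': '<img><IMG_CONTEXT><IMG_CONTEXT></img>\n'
    #              '<img><IMG_CONTEXT><IMG_CONTEXT></img>What happens here?',
    #     'output': 'A person opens a door.'}]}
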
    def __getitem__(self, index):
        index = index % self.real_len()
        selected_data_dict = copy.deepcopy(self.text_data[index])
        data_dict = self.dataset_map_fn(selected_data_dict, select_k=self.sampled_frames)

        assert 'images' in data_dict.keys()
        if self.use_fast:
            assert 'fast_images' in data_dict.keys()

        pixel_values = []
        num_video_tokens = None
        num_frame_tokens = None
        if data_dict.get('images', None) is not None:
            frames_files = data_dict['images']
            for frame_image in frames_files:
                frame_image = frame_image.convert('RGB')
                if self.preprocessor is not None:
                    # Keep PIL images; the preprocessor handles them in batch.
                    pass
                else:
                    frame_image = self.transformer(frame_image)
                pixel_values.append(frame_image)

            if self.preprocessor is not None:
                if self.arch_type == 'qwen':
                    _data_dict = self.preprocessor(
                        pixel_values, do_resize=True,
                        size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = torch.tensor(
                        _data_dict['pixel_values'], dtype=torch.float)
                    _data_dict['image_grid_thw'] = torch.tensor(
                        _data_dict['image_grid_thw'], dtype=torch.int)
                    # image_grid_thw holds the (t, h, w) patch grid per frame;
                    # apply the spatial downsample ratio for tokens per frame.
                    num_frame_tokens = int(
                        _data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
                    num_frames = _data_dict['image_grid_thw'].shape[0]
                    num_video_tokens = num_frame_tokens * num_frames
                elif self.arch_type == 'llava':
                    _data_dict = self.preprocessor(
                        pixel_values, do_resize=True,
                        size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
                    _data_dict['pixel_values'] = torch.tensor(
                        _data_dict['pixel_values'], dtype=torch.float)
                else:
                    raise NotImplementedError
                data_dict.update(_data_dict)
            else:
                pixel_values = torch.stack(pixel_values, dim=0)  # (n_f, 3, h, w)
                data_dict['pixel_values'] = pixel_values
        else:
            data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
            data_dict['masks'] = None

        if num_video_tokens is not None:
            # Expand the single placeholder token (patch_token == 1 for qwen)
            # to the true per-frame token count computed above.
            assert self.patch_token == 1
            input_str = data_dict['conversation'][0]['input']
            input_str = input_str.replace(
                self.IMG_CONTEXT_TOKEN, self.IMG_CONTEXT_TOKEN * num_frame_tokens)
            assert input_str.count(self.IMG_CONTEXT_TOKEN) == num_video_tokens
            data_dict['conversation'][0]['input'] = input_str

        result = self.template_map_fn(data_dict)
        data_dict.update(result)
        result = video_lisa_encode_fn(
            data_dict, tokenizer=self.tokenizer,
            max_length=self.max_length, with_image_token=True)
        data_dict.update(result)

        # Fast branch: encode the densely sampled frames with the transform.
        if self.use_fast:
            fast_pixel_values = []
            for frame_image in data_dict['fast_images']:
                frame_image = frame_image.convert('RGB')
                frame_image = self.transformer(frame_image)
                fast_pixel_values.append(frame_image)
            fast_pixel_values = torch.stack(fast_pixel_values, dim=0)  # (n_f, 3, h, w)
            data_dict['fast_pixel_values'] = fast_pixel_values

        data_dict['type'] = 'video'
        return data_dict
    def visualization_debug(self, data_dict):
        save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
        os.makedirs(save_folder, exist_ok=True)
        self.cur_number += 1

        # Images: un-normalize with the same statistics used in the transform,
        # then save each frame.
        pixel_values = data_dict['pixel_values']
        save_folder_image = os.path.join(save_folder, 'image')
        os.makedirs(save_folder_image, exist_ok=True)
        mean = torch.tensor(self.IMAGENET_MEAN).view(3, 1, 1)
        std = torch.tensor(self.IMAGENET_STD).view(3, 1, 1)
        for i_image, image_pixel_value in enumerate(pixel_values):
            image_pixel_value = image_pixel_value * std + mean
            image_pixel_value = (image_pixel_value * 255).clamp(0, 255)
            image_pixel_value = image_pixel_value.permute(1, 2, 0)
            image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
            # cv2 expects BGR channel order when writing.
            cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)),
                        cv2.cvtColor(image_pixel_value, cv2.COLOR_RGB2BGR))

        # Text: decode the tokenized sample for inspection.
        input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
        with open(os.path.join(save_folder, 'text.json'), 'w') as f:
            json.dump([input_text], f)
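
# A minimal construction sketch (paths and config values are illustrative;
# `template_map_fn_factory` and `PROMPT_TEMPLATE` refer to the usual xtuner
# helpers and are assumptions about the surrounding training config):
#
#     from transformers import AutoTokenizer
#     from xtuner.dataset.map_fns import template_map_fn_factory
#     from xtuner.utils import PROMPT_TEMPLATE
#
#     dataset = VideoChatUniViDataset(
#         image_folder='/data/videos',
#         json_file='/data/annotations/video_chat.json',
#         tokenizer=dict(type=AutoTokenizer.from_pretrained,
#                        pretrained_model_name_or_path='internlm/internlm2-chat-7b'),
#         template_map_fn=dict(type=template_map_fn_factory,
#                              template=PROMPT_TEMPLATE.internlm2_chat),
#         sampled_frames=10)
#     sample = dataset[0]
#     # sample['pixel_values'] -> (10, 3, 448, 448) float tensor
#     # sample['input_ids']    -> token ids including <IMG_CONTEXT> spans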