diff --git "a/Video_Captioning.ipynb" "b/Video_Captioning.ipynb" new file mode 100644--- /dev/null +++ "b/Video_Captioning.ipynb" @@ -0,0 +1,383 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "55df6d0d-71cf-4110-81ed-7c0d3ce58e43", + "metadata": {}, + "source": [ + "## Import" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "0abe9574-05f7-4684-b586-033827b89c32", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "74e70729-b658-4ffd-9d8b-ae42a2d1b212", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "from fairseq import utils, tasks\n", + "from fairseq import checkpoint_utils\n", + "from utils.eval_utils import eval_step\n", + "from tasks.mm_tasks.caption import CaptionTask\n", + "from models.unival import UnIVALModel\n", + "from PIL import Image\n", + "\n", + "import random\n", + "from torchvision.transforms import functional as F\n", + "from torchvision.transforms import InterpolationMode\n", + "\n", + "from matplotlib import pyplot as plt\n", + "\n", + "# turn on cuda if GPU is available\n", + "use_cuda = torch.cuda.is_available()\n", + "# use fp16 only when GPU is available\n", + "use_fp16 = False\n", + "import os " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ce03a870-2852-410e-97c4-59461d08f60a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + ".register_task_cls(cls)>" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Register refcoco task\n", + "tasks.register_task('video_caption', CaptionTask)" + ] + }, + { + "cell_type": "markdown", + "id": "58361680-3e90-4fff-962e-2ff67c1e7289", + "metadata": {}, + "source": [ + "### Load model" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "adb79611-7563-4fb6-a576-f31050f8438e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "self.sample_patch_num 784\n", + "self.sample_audio_patch_num None\n", + "self.sample_video_patch_num None\n", + "self.with_cls False\n", + "Loading: all_resnext101\n", + "use bn: \n", + "load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth\n", + "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])\n", + "unival\n", + "getattr(args, \"stop_on_max_len\", False) False\n" + ] + } + ], + "source": [ + "# Load pretrained ckpt & config\n", + "\n", + "checkpoint_path = '/data/mshukor/logs/ofa/best_models/unival_video_caption_stage_1/checkpoint_best.pt'\n", + "video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n", + "\n", + "overrides={\"eval_cider\":False, \"beam\":5, \"max_len_b\":22, \"no_repeat_ngram_size\":3, \"seed\":7, \"unnormalized\": False,\n", + " \"bpe_dir\":\"utils/BPE\", \"video_model_path\": video_model_path,}\n", + "\n", + "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n", + " utils.split_paths(checkpoint_path),\n", + " arg_overrides=overrides\n", + " )\n", + "\n", + "# Move models to GPU\n", + "for model in models:\n", + " model.eval()\n", + " if use_fp16:\n", + " model.half()\n", + " if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n", + " model.cuda()\n", + " model.prepare_for_inference_(cfg)\n", + "\n", + "# Initialize generator\n", + "generator = task.build_generator(models, cfg.generation)" + ] + }, + { + "cell_type": 
"markdown", + "id": "e79aad39-1424-47d5-8cd4-6ab77ea46fb4", + "metadata": {}, + "source": [ + "### Preprocess" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "576a3e84-a6aa-446d-adab-fef9499318fc", + "metadata": {}, + "outputs": [], + "source": [ + "# Image transform\n", + "from torchvision import transforms\n", + "mean = [0.5, 0.5, 0.5]\n", + "std = [0.5, 0.5, 0.5]\n", + "\n", + "\n", + "\n", + "type_transform = transforms.Lambda(lambda x: x.float().div(255.0))\n", + "patch_video_resize_transform = transforms.Compose([\n", + " transforms.CenterCrop(cfg.task.patch_frame_size),\n", + " type_transform, \n", + " transforms.Normalize(mean=mean, std=std),\n", + " ])\n", + "\n", + "# video process\n", + "from data.video_utils import VIDEO_READER_FUNCS\n", + "\n", + "video_reader = VIDEO_READER_FUNCS['decord'] \n", + "\n", + "def process_video(video_path, max_num_frames=16, num_frames=16, sample_type='rand',):\n", + " \n", + " # video \n", + " data_path = os.path.join(video_path)\n", + "\n", + " frames, frame_indices, video_duration = video_reader(\n", + " data_path, num_frames, sample_type, max_num_frames=max_num_frames\n", + " )\n", + "\n", + " patch_video = patch_video_resize_transform(frames)\n", + " patch_video = patch_video.permute(1, 0, 2, 3) # -> (C, T, h, w)\n", + "\n", + " return patch_video.unsqueeze(0)\n", + " \n", + "\n", + "# Text preprocess\n", + "bos_item = torch.LongTensor([task.src_dict.bos()])\n", + "eos_item = torch.LongTensor([task.src_dict.eos()])\n", + "pad_idx = task.src_dict.pad()\n", + "def encode_text(text, length=None, append_bos=False, append_eos=False):\n", + " s = task.tgt_dict.encode_line(\n", + " line=task.bpe.encode(text),\n", + " add_if_not_exist=False,\n", + " append_eos=False\n", + " ).long()\n", + " if length is not None:\n", + " s = s[:length]\n", + " if append_bos:\n", + " s = torch.cat([bos_item, s])\n", + " if append_eos:\n", + " s = torch.cat([s, eos_item])\n", + " return s\n", + "\n", + "# Construct input for caption task\n", + "def construct_sample(video_path):\n", + " \n", + " patch_video = process_video(video_path, max_num_frames=16, num_frames=cfg.task.num_frames, sample_type=cfg.task.sample_type,)\n", + " patch_image = torch.zeros((3, cfg.task.patch_image_size, cfg.task.patch_image_size)) \n", + " \n", + " patch_type = torch.tensor([1])\n", + " patch_mask = torch.tensor([True])\n", + " src_text = encode_text(\" what does the video describe?\", append_bos=True, append_eos=True).unsqueeze(0)\n", + " src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])\n", + " sample = {\n", + " \"id\":np.array(['42']),\n", + " \"net_input\": {\n", + " \"src_tokens\": src_text,\n", + " \"src_lengths\": src_length,\n", + " \"patch_videos\": patch_video,\n", + " \"patch_images\": patch_image,\n", + " \"patch_masks\": patch_mask,\n", + " \"patch_types\": patch_type,\n", + " }\n", + " }\n", + " return sample\n", + " \n", + "# Function to turn FP32 to FP16\n", + "def apply_half(t):\n", + " if t.dtype is torch.float32:\n", + " return t.to(dtype=torch.half)\n", + " return t" + ] + }, + { + "cell_type": "markdown", + "id": "f96f776e-9aa0-4271-b881-311851cc033c", + "metadata": {}, + "source": [ + "### Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "id": "6f8ddf8c-82e2-411c-baa3-850da02f1996", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 3, 16, 384, 384])\n" + ] + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] 
+  {
+   "cell_type": "markdown",
+   "id": "f96f776e-9aa0-4271-b881-311851cc033c",
+   "metadata": {},
+   "source": [
+    "### Inference"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 157,
+   "id": "6f8ddf8c-82e2-411c-baa3-850da02f1996",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([1, 3, 16, 384, 384])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       ""
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 157,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "save_dir = '/home/mshukor/ofa_adastra'\n",
+    "\n",
+    "video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7019.mp4'  # a man is sitting in a chair and talking\n",
+    "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7038.mp4'  # a person is cooking something in a pan\n",
+    "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7021.mp4'  # a group of people are playing baseball\n",
+    "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7068.mp4'  # a man and a woman are talking to each other\n",
+    "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7017.mp4'  # a person is playing a video game\n",
+    "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7014.mp4'  # a girl is singing on the voice\n",
+    "\n",
+    "# video_path = '/data/mshukor/data/video/msrvtt/examples/video1065.mp4'\n",
+    "\n",
+    "# limitation example\n",
+    "video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7055.mp4'  # a man is driving a car\n",
+    "\n",
+    "sample = construct_sample(video_path)\n",
+    "sample = utils.move_to_cuda(sample) if use_cuda else sample\n",
+    "sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 158,
+   "id": "4651039c-b8c0-4687-871e-b42cb13b2984",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([1], device='cuda:0')\n",
+      "torch.Size([1, 2048, 1, 12, 12])\n"
+     ]
+    }
+   ],
+   "source": [
+    "from utils.eval_utils import eval_caption\n",
+    "\n",
+    "# Generate the caption\n",
+    "with torch.no_grad():\n",
+    "    result, scores = eval_caption(task, generator, models, sample)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 159,
+   "id": "712150d4-f28c-4538-870f-b33f775725d5",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "a man is driving a car\n"
+     ]
+    }
+   ],
+   "source": [
+    "caption = result[0]['caption']\n",
+    "print(caption)\n",
+    "\n",
+    "# Display the clip next to its predicted caption\n",
+    "from IPython.display import Video\n",
+    "Video(video_path, embed=True)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "ofa",
+   "language": "python",
+   "name": "ofa"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}