# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import random
from functools import partial

import clip
import decord
import gradio as gr
import nncore
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as F
from decord import VideoReader
from nncore.engine import load_checkpoint
from nncore.nn import build_model

TITLE = 'R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'
TITLE_MD = '<h1 align="center">R<sup>2</sup>-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding</h1>'
DESCRIPTION_MD = 'R<sup>2</sup>-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding. Please find more details in our <a href="https://arxiv.org/abs/2404.00801" target="_blank">Tech Report</a> and <a href="https://github.com/yeliudev/R2-Tuning" target="_blank">GitHub Repo</a>.'
GUIDE_MD = '### User Guide:\n1. Upload a video or click "random" to sample one.\n2. Input a text query. A good practice is to write a sentence with 5~15 words.\n3. Click "submit" and you\'ll see the moment retrieval and highlight detection results on the right.'
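
# Model config and the QVHighlights-pretrained checkpoint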
CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'
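
# Example (video, query) pairs served by the 'Random' button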
# yapf:disable
EXAMPLES = [
    ('data/gTAvxnQtjXM_60.0_210.0.mp4', 'A man in a white t shirt wearing a backpack is showing a nearby cathedral.'),
    ('data/pA6Z-qYhSNg_210.0_360.0.mp4', 'Different Facebook posts on transgender bathrooms are shown.'),
    ('data/CkWOpyrAXdw_210.0_360.0.mp4', 'Indian girl cleaning her kitchen before cooking.'),
    ('data/ocLUzCNodj4_360.0_510.0.mp4', 'A woman stands in her bedroom in front of a mirror and talks.'),
    ('data/HkLfNhgP0TM_660.0_810.0.mp4', 'Woman lays down on the couch while talking to the camera.')
]
# yapf:enable
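

# Format a number of seconds as an 'MM:SS' string, clamping negative
# inputs to zero.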
def convert_time(seconds):
    minutes, seconds = divmod(round(max(seconds, 0)), 60)
    return f'{minutes:02d}:{seconds:02d}'
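

# Decode a video and prepare it as model input: sample frames at the frame
# rate specified in the config, center-crop to a square, resize to the CLIP
# input resolution, normalize with CLIP's image statistics, and flatten each
# frame into a single vector.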
def load_video(video_path, cfg):
    decord.bridge.set_bridge('torch')

    vr = VideoReader(video_path)

    # Subsample frames so that the effective frame rate matches cfg.data.val.fps
    stride = vr.get_avg_fps() / cfg.data.val.fps
    fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
    video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255

    # The 336px CLIP backbone expects 336x336 inputs; the others use 224x224
    size = 336 if '336px' in cfg.model.arch else 224

    # Center-crop to a square, resize, and normalize with CLIP's mean and std
    h, w = video.size(-2), video.size(-1)
    s = min(h, w)
    x, y = round((h - s) / 2), round((w - s) / 2)
    video = video[..., x:x + s, y:y + s]
    video = F.resize(video, size=(size, size))
    video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))

    # Flatten each frame into a vector and add a batch dimension
    video = video.reshape(video.size(0), -1).unsqueeze(0)

    return video
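

# Build the model from a config file and load the pretrained weights,
# downloading the checkpoint first if a URL is given.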
def init_model(config, checkpoint):
    cfg = nncore.Config.from_file(config)
    cfg.model.init = True

    if checkpoint.startswith('http'):
        checkpoint = nncore.download(checkpoint, out_dir='checkpoints')

    model = build_model(cfg.model, dist=False).eval()
    model = load_checkpoint(model, checkpoint, warning=False)

    return model, cfg
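

# Run inference on a single (video, query) pair and post-process the outputs
# into a moment retrieval table and a highlight (saliency) curve.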
def main(video, query, model, cfg):
    if not query:
        raise gr.Error('Text query cannot be empty.')

    try:
        video = load_video(video, cfg)
    except Exception:
        raise gr.Error('Failed to load the video.')

    query = clip.tokenize(query, truncate=True)

    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])

    with torch.inference_mode():
        pred = model(data)

    # Keep the top-5 predicted moments and convert their boundaries to MM:SS
    mr = pred['_out']['boundary'][:5].cpu().tolist()
    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]

    # Min-max normalize saliency scores into [0.05, 0.95] for plotting; the
    # x-axis is in seconds, assuming the config's 0.5 fps (2s frame stride)
    hd = pred['_out']['saliency'].cpu()
    hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
    hd = pd.DataFrame(dict(x=range(0, len(hd) * 2, 2), y=hd))

    return mr, hd
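

# Initialize the model once at startup and bind it to the inference function.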
model, cfg = init_model(CONFIG, WEIGHT)
fn = partial(main, model=model, cfg=cfg)
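
# Build the Gradio interface: video and query inputs on the left, moment
# retrieval and highlight detection results on the right.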
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(TITLE_MD)
    gr.Markdown(DESCRIPTION_MD)
    gr.Markdown(GUIDE_MD)

    with gr.Row():
        with gr.Column():
            video = gr.Video(label='Video')
            query = gr.Textbox(label='Text Query')

            with gr.Row():
                random_btn = gr.Button(value='🔮 Random')
                gr.ClearButton([video, query], value='🗑️ Reset')
                submit_btn = gr.Button(value='🚀 Submit')

        with gr.Column():
            mr = gr.DataFrame(
                headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval')
            hd = gr.LinePlot(
                x='x',
                y='y',
                x_title='Time (seconds)',
                y_title='Saliency Score',
                label='Highlight Detection')

    random_btn.click(lambda: random.choice(EXAMPLES), None, [video, query])
    submit_btn.click(fn, [video, query], [mr, hd])

demo.launch()