import os
import glob

import cv2
import torch
import gradio as gr
from huggingface_hub import hf_hub_download

from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import img2tensor, tensor2img, InputPadder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
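
# Map the UI model names to their network classes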
model_dict = {
    'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
}

def img2vid(model_type, img0, img1, frame_ratio, iters):
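    """Interpolate between two images and write the result to an MP4.

    Recursively doubles the frame count `iters` times by predicting the
    temporal midpoint between every adjacent pair of frames, then writes
    the sequence at `frame_ratio` frames per second.
    """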
    model = model_dict[model_type]()
    model.to(device)
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
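    # Convert the input images to tensors and pad both spatial dimensions
    # to a multiple of 16, as the network expects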
    img0_t = img2tensor(img0).to(device)
    img1_t = img2tensor(img1).to(device)
    padder = InputPadder(img0_t.shape, 16)
    img0_t, img1_t = padder.pad(img0_t, img1_t)
    inputs = [img0_t, img1_t]
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
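    # Each pass inserts the t=0.5 midpoint between every adjacent pair of
    # frames, so `iters` passes turn the 2 inputs into 2**iters + 1 frames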
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [img0_t]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred']
            outputs += [imgt_pred, in_1]
        inputs = outputs
    # Unpad only after all passes: unpadding inside the loop would feed
    # frames of mismatched sizes back into the model on the next pass
    outputs = [padder.unpad(im) for im in outputs]
    out_path = 'results'
    os.makedirs(out_path, exist_ok=True)  # VideoWriter fails silently if the directory is missing
    size = outputs[0].shape[2:][::-1]  # (W, H) as expected by OpenCV
    writer = cv2.VideoWriter(f'{out_path}/demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), frame_ratio, size)
    for frame in outputs:
        frame = tensor2img(frame)
        # Gradio supplies RGB arrays; OpenCV writes BGR
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        writer.write(frame)
    writer.release()
    return f'{out_path}/demo.mp4'

def demo_img():
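    """Build the Gradio Blocks UI for the two-image interpolation demo."""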
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Image Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left;">
                    <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                        Description: Given two input images, generate a short interpolated video between them.
                    </h2>
                </div>
                """)
        with gr.Row():
            with gr.Column():
                img0 = gr.Image(label='Image0')
                img1 = gr.Image(label='Image1')
            with gr.Column():
                result = gr.Video(label="Generated Video")
                with gr.Accordion('Advanced options', open=False):
                    ratio = gr.Slider(label='Interpolation iterations (each doubles the frame count)',
                                      minimum=4,
                                      maximum=7,
                                      value=6,
                                      step=1)
                    frame_ratio = gr.Slider(label='Output FPS',
                                            minimum=8,
                                            maximum=64,
                                            value=16,
                                            step=1)
                    model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                          label='Model Select',
                                          value='AMT-S')
                run_button = gr.Button('Run')
        # Order must match img2vid's signature: (model_type, img0, img1, frame_ratio, iters)
        inputs = [
            model_type,
            img0,
            img1,
            frame_ratio,
            ratio,
        ]
        gr.Examples(examples=glob.glob("examples/*.png"),
                    inputs=img0,
                    label='Example images (drag them to input windows)',
                    run_on_click=False,
                    )
        run_button.click(fn=img2vid,
                         inputs=inputs,
                         outputs=result)
    return demo
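

# Hypothetical entry point (not in the original file): build the Blocks app and
# launch it locally. The upstream Space may wire demo_img() up elsewhere.
if __name__ == '__main__':
    demo_img().launch()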