---
language:
- en
tags:
- pytorch_model_hub_mixin
- animation
- video-frame-interpolation
- uncertainty-estimation
license: mit
pipeline_tag: image-to-image
---
# 🤖 Multi‑Input ResShift Diffusion VFI
<div align="left" style="display: flex; flex-direction: row; gap: 15px">
<a href='https://arxiv.org/pdf/2504.05402'><img src='https://img.shields.io/badge/arXiv-2504.05402-b31b1b.svg'></a>
<a href='https://github.com/VicFonch/Multi-Input-Resshift-Diffusion-VFI'><img src='https://img.shields.io/badge/Repo-Code-blue'></a>
<a href='https://colab.research.google.com/drive/1MGYycbNMW6Mxu5MUqw_RW_xxiVeHK5Aa#scrollTo=EKaYCioiP3tQ'><img src='https://img.shields.io/badge/Colab-Demo-Green'></a>
<a href='https://huggingface.co/spaces/vfontech/Multi-Input-Res-Diffusion-VFI'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face%20Space-Demo-g'></a>
</div>
## ⚙️ Setup
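Start by cloning the source code from GitHub and entering the repository directory, since the following commands expect `requirements.txt` in the current directory.
```bash
git clone https://github.com/VicFonch/Multi-Input-Resshift-Diffusion-VFI.git
cd Multi-Input-Resshift-Diffusion-VFI
```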
Create a conda environment and install the requirements:
```bash
conda create -n multi-input-resshift python=3.12
conda activate multi-input-resshift
pip install -r requirements.txt
```
**Note**: Make sure your system is compatible with **CUDA 12.4**. If it is not, install the [CuPy](https://docs.cupy.dev/en/stable/install.html) build that matches your installed CUDA version.
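As a rough aid for the note above, the sketch below maps a CUDA version string to the name of the matching prebuilt CuPy wheel. The helper `cupy_wheel_for` is hypothetical (it is not part of this repository), and the wheel names `cupy-cuda11x` and `cupy-cuda12x` are the ones listed in the CuPy installation guide at the time of writing; check that guide before installing. You can read your CUDA version from `nvidia-smi`, `nvcc --version`, or `torch.version.cuda`.

```python
def cupy_wheel_for(cuda_version: str) -> str:
    """Return the prebuilt CuPy wheel name for a CUDA version such as "12.4".

    Hypothetical helper for illustration; it only covers the CUDA 11.x and
    12.x wheels and raises for anything else.
    """
    major = int(cuda_version.split(".")[0])
    wheels = {11: "cupy-cuda11x", 12: "cupy-cuda12x"}
    if major not in wheels:
        raise ValueError(
            f"No wheel known to this sketch for CUDA {cuda_version}; "
            "see the CuPy installation guide."
        )
    return wheels[major]


# Print the pip command for the CUDA version mentioned in the note.
print(f"pip install {cupy_wheel_for('12.4')}")
```

If the wheel name differs from what `requirements.txt` installed, uninstall the existing CuPy package first so that two builds do not coexist in the same environment.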
## 🚀 Inference Example
```python
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from model.hub import MultiInputResShiftHub
from utils.utils import denorm

# Load the pretrained weights from the Hugging Face Hub and move the model to the GPU.
model = MultiInputResShiftHub.from_pretrained("vfontech/Multiple-Input-Resshift-VFI").cuda()
model.eval()

# Build the paths with os.path.join so they work on Windows, Linux, and macOS.
img0_path = os.path.join("_data", "example_images", "frame1.png")
img2_path = os.path.join("_data", "example_images", "frame3.png")

# Resize to 256x448 and normalize each channel from [0, 1] to [-1, 1].
mean = std = [0.5] * 3
transforms = Compose([
    Resize((256, 448)),
    ToTensor(),
    Normalize(mean=mean, std=std),
])

# Load both input frames as batched RGB tensors on the GPU.
img0 = transforms(Image.open(img0_path).convert("RGB")).unsqueeze(0).cuda()
img2 = transforms(Image.open(img2_path).convert("RGB")).unsqueeze(0).cuda()

# tau is the time value passed to the model; this example uses 0.5.
tau = 0.5
with torch.no_grad():  # inference only, so gradients are not needed
    img1 = model.reverse_process([img0, img2], tau)

# Show the first input, the interpolated frame, and the second input side by side.
plt.figure(figsize=(10, 5))
for i, frame in enumerate([img0, img1, img2], start=1):
    plt.subplot(1, 3, i)
    plt.imshow(denorm(frame, mean=mean, std=std).squeeze().permute(1, 2, 0).cpu().numpy())
plt.show()
```
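The example above produces a single frame at `tau = 0.5`. If you want several evenly spaced in-between frames, one option is to compute a list of `tau` values and call the model once per value. The sketch below only computes the values; `intermediate_taus` is a hypothetical helper, and it assumes (not confirmed by this README) that `reverse_process` accepts any `tau` strictly between 0 and 1.

```python
def intermediate_taus(num_frames: int) -> list[float]:
    """Return `num_frames` evenly spaced time values strictly between 0 and 1.

    Hypothetical helper: for num_frames=1 it returns [0.5], which matches the
    value used in the inference example.
    """
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")
    return [i / (num_frames + 1) for i in range(1, num_frames + 1)]


taus = intermediate_taus(3)
print(taus)

# Possible usage with the objects from the inference example (assumption:
# the model accepts these tau values):
#   frames = [model.reverse_process([img0, img2], t) for t in taus]
```

Each call runs the full reverse process, so the total inference time grows linearly with the number of requested frames.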