Upload 23 files
Browse files- .gitattributes +8 -0
- Colab_demo.ipynb +127 -0
- LICENSE +21 -0
- README.md +129 -13
- __pycache__/video_processing.cpython-312.pyc +0 -0
- app.py +294 -0
- demo/I0_0.png +3 -0
- demo/I0_1.png +3 -0
- demo/I0_slomo_clipped.gif +3 -0
- demo/I2_0.png +3 -0
- demo/I2_1.png +3 -0
- demo/I2_slomo_clipped.gif +3 -0
- demo/i0.png +3 -0
- demo/i1.png +3 -0
- inference_img.py +118 -0
- inference_img_SR.py +69 -0
- inference_video.py +290 -0
- inference_video_enhance.py +201 -0
- model/loss.py +128 -0
- model/pytorch_msssim/__init__.py +200 -0
- model/pytorch_msssim/__pycache__/__init__.cpython-312.pyc +0 -0
- model/warplayer.py +22 -0
- requirements.txt +7 -0
- video_processing.py +235 -0
.gitattributes
CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
demo/I0_0.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
demo/I0_1.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
demo/I0_slomo_clipped.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
demo/i0.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
demo/i1.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
demo/I2_0.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
demo/I2_1.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
demo/I2_slomo_clipped.gif filter=lfs diff=lfs merge=lfs -text
|
Colab_demo.ipynb
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"name": "Untitled0.ipynb",
|
7 |
+
"provenance": [],
|
8 |
+
"include_colab_link": true
|
9 |
+
},
|
10 |
+
"kernelspec": {
|
11 |
+
"name": "python3",
|
12 |
+
"display_name": "Python 3"
|
13 |
+
},
|
14 |
+
"accelerator": "GPU"
|
15 |
+
},
|
16 |
+
"cells": [
|
17 |
+
{
|
18 |
+
"cell_type": "markdown",
|
19 |
+
"metadata": {
|
20 |
+
"id": "view-in-github",
|
21 |
+
"colab_type": "text"
|
22 |
+
},
|
23 |
+
"source": [
|
24 |
+
"<a href=\"https://colab.research.google.com/github/hzwer/Practical-RIFE/blob/main/Colab_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"metadata": {
|
30 |
+
"id": "FypCcZkNNt2p"
|
31 |
+
},
|
32 |
+
"source": [
|
33 |
+
"%cd /content\n",
|
34 |
+
"!git clone https://github.com/hzwer/Practical-RIFE"
|
35 |
+
],
|
36 |
+
"execution_count": null,
|
37 |
+
"outputs": []
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"metadata": {
|
42 |
+
"id": "1wysVHxoN54f"
|
43 |
+
},
|
44 |
+
"source": [
|
45 |
+
"!gdown --id 1O5KfS3KzZCY3imeCr2LCsntLhutKuAqj\n",
|
46 |
+
"!7z e Practical-RIFE/RIFE_trained_model_v3.8.zip"
|
47 |
+
],
|
48 |
+
"execution_count": null,
|
49 |
+
"outputs": []
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "code",
|
53 |
+
"metadata": {
|
54 |
+
"id": "AhbHfRBJRAUt"
|
55 |
+
},
|
56 |
+
"source": [
|
57 |
+
"!mkdir /content/Practical-RIFE/train_log\n",
|
58 |
+
"!mv *.py /content/Practical-RIFE/train_log/\n",
|
59 |
+
"!mv *.pkl /content/Practical-RIFE/train_log/\n",
|
60 |
+
"%cd /content/Practical-RIFE/\n",
|
61 |
+
"!gdown --id 1i3xlKb7ax7Y70khcTcuePi6E7crO_dFc\n",
|
62 |
+
"!pip3 install -r requirements.txt"
|
63 |
+
],
|
64 |
+
"execution_count": null,
|
65 |
+
"outputs": []
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "markdown",
|
69 |
+
"metadata": {
|
70 |
+
"id": "rirngW5uRMdg"
|
71 |
+
},
|
72 |
+
"source": [
|
73 |
+
"Please upload your video to content/Practical-RIFE/video.mp4, or use our demo video."
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "code",
|
78 |
+
"metadata": {
|
79 |
+
"id": "dnLn4aHHPzN3"
|
80 |
+
},
|
81 |
+
"source": [
|
82 |
+
"!nvidia-smi\n",
|
83 |
+
"!python3 inference_video.py --exp=1 --video=demo.mp4 --montage --skip"
|
84 |
+
],
|
85 |
+
"execution_count": null,
|
86 |
+
"outputs": []
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "markdown",
|
90 |
+
"metadata": {
|
91 |
+
"id": "77KK6lxHgJhf"
|
92 |
+
},
|
93 |
+
"source": [
|
94 |
+
"Our demo.mp4 is 25FPS. You can adjust the parameters for your own perference.\n",
|
95 |
+
"For example: \n",
|
96 |
+
"--fps=60 --exp=1 --video=mydemo.avi --png"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"metadata": {
|
102 |
+
"id": "0zIBbVE3UfUD",
|
103 |
+
"cellView": "code"
|
104 |
+
},
|
105 |
+
"source": [
|
106 |
+
"from IPython.display import display, Image\n",
|
107 |
+
"import moviepy.editor as mpy\n",
|
108 |
+
"display(mpy.ipython_display('demo_4X_100fps.mp4', height=256, max_duration=100.))"
|
109 |
+
],
|
110 |
+
"execution_count": null,
|
111 |
+
"outputs": []
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"cell_type": "code",
|
115 |
+
"metadata": {
|
116 |
+
"id": "tWkJCNgP3zXA"
|
117 |
+
},
|
118 |
+
"source": [
|
119 |
+
"!python3 inference_img.py --img demo/I0_0.png demo/I0_1.png\n",
|
120 |
+
"ffmpeg -r 10 -f image2 -i output/img%d.png -s 448x256 -vf \"split[s0][s1];[s0]palettegen=stats_mode=single[p];[s1][p]paletteuse=new=1\" output/slomo.gif\n",
|
121 |
+
"# Image interpolation"
|
122 |
+
],
|
123 |
+
"execution_count": null,
|
124 |
+
"outputs": []
|
125 |
+
}
|
126 |
+
]
|
127 |
+
}
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2021 hzwer
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,13 +1,129 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Practical-RIFE
|
2 |
+
**[V4.0 Anime Demo Video](https://www.bilibili.com/video/BV1J3411t7qT?p=1&share_medium=iphone&share_plat=ios&share_session_id=7AE3DA72-D05C-43A0-9838-E2A80885BD4E&share_source=QQ&share_tag=s_i×tamp=1639643780&unique_k=rjqO0EK)** | **[迭代经验](https://zhuanlan.zhihu.com/p/721430631)** | **[迭代QA](https://github.com/hzwer/Practical-RIFE/issues/124)** | **[Colab](https://colab.research.google.com/drive/1BZmGSq15O4ZU5vPfzkv7jFNYahTm6qwT?usp=sharing)**
|
3 |
+
|
4 |
+
This project is based on [RIFE](https://github.com/hzwer/arXiv2020-RIFE) and [SAFA](https://github.com/megvii-research/WACV2024-SAFA). We aim to enhance their practicality for users by incorporating various features and designing new models. Since improving the PSNR index is not consistent with subjective perception. This project is intended for engineers and developers. For general users, we recommend the following software:
|
5 |
+
|
6 |
+
**[SVFI (中文)](https://github.com/YiWeiHuang-stack/Squirrel-Video-Frame-Interpolation) | [RIFE-App](https://grisk.itch.io/rife-app) | [FlowFrames](https://nmkd.itch.io/flowframes)**
|
7 |
+
|
8 |
+
Thanks to [SVFI team](https://github.com/Justin62628/Squirrel-RIFE) to support model testing on Animation.
|
9 |
+
|
10 |
+
[VapourSynth-RIFE](https://github.com/HolyWu/vs-rife) | [RIFE-ncnn-vulkan](https://github.com/nihui/rife-ncnn-vulkan) | [VapourSynth-RIFE-ncnn-Vulkan](https://github.com/styler00dollar/VapourSynth-RIFE-ncnn-Vulkan) | [vs-mlrt](https://github.com/AmusementClub/vs-mlrt) | [Drop frame fixer and FPS converter](https://github.com/may-son/RIFE-FixDropFrames-and-ConvertFPS)
|
11 |
+
|
12 |
+
## Frame Interpolation
|
13 |
+
2024.08 - We find that 4.24+ is quite suitable for post-processing of [some diffusion model generated videos](https://drive.google.com/drive/folders/1hSzUn10Era3JCaVz0Z5Eg4wT9R6eJ9U9?usp=sharing).
|
14 |
+
|
15 |
+
### Trained Model
|
16 |
+
The content of these links is under the same MIT license as this project. **lite** means using similar training framework, but lower computational cost model.
|
17 |
+
Currently, it is recommended to choose 4.25 by default for most scenes.
|
18 |
+
|
19 |
+
4.26 - 2024.09.21 | [Google Drive](https://drive.google.com/file/d/1gViYvvQrtETBgU1w8axZSsr7YUuw31uy/view?usp=sharing) [百度网盘](https://pan.baidu.com/s/1EZsG3IFO8C1e2uRVb_Npgg?pwd=smw8) || [4.25.lite - 2024.10.20](https://drive.google.com/file/d/1zlKblGuKNatulJNFf5jdB-emp9AqGK05/view?usp=share_link)
|
20 |
+
|
21 |
+
4.25 - 2024.09.19 | [Google Drive](https://drive.google.com/file/d/1ZKjcbmt1hypiFprJPIKW0Tt0lr_2i7bg/view?usp=sharing) [百度网盘](https://pan.baidu.com/s/1rpUX5uawusz2uwEdXtjRbw?pwd=mo6k) | I am trying using more flow blocks, so the scale_list will change accordingly. It seems that the anime scenes have been significantly improved.
|
22 |
+
|
23 |
+
4.22 - 2024.08.08 | [Google Drive](https://drive.google.com/file/d/1qh2DSA9a1eZUTtZG9U9RQKO7N7OaUJ0_/view?usp=share_link) [百度网盘](https://pan.baidu.com/s/1EA5BIHqOu35Rj4meg00G4g?pwd=hwym) || [4.22.lite](https://drive.google.com/file/d/1Smy6gY7BkS_RzCjPCbMEy-TsX8Ma5B0R/view?usp=sharing) || 4.21 - 2024.08.04 | [Google Drive](https://drive.google.com/file/d/1l5u6G8vEkPAT7cYYWwzB6OG8vwBYrxiS/view?usp=sharing) [百度网盘](https://pan.baidu.com/s/1TMjRFOwdLgsShKdGbTKW_g?pwd=4q6d)
|
24 |
+
|
25 |
+
4.20 - 2024.07.24 | [Google Drive](https://drive.google.com/file/d/11n3YR7-qCRZm9RDdwtqOTsgCJUHPuexA/view?usp=sharing) [百度网盘](https://pan.baidu.com/s/1v0b7ZTSj_VvLOfW-hQ_NZQ?pwd=ykkv)
|
26 |
+
|| 4.18 - 2024.07.03 | [Google Drive](https://drive.google.com/file/d/1octn-UVuEjXa_HlsIUbNeLTTvYCKbC_s/view?usp=sharing) [百度网盘](https://pan.baidu.com/s/1fqtxJyXSgUx-gE3rieuKxg?pwd=udr1)
|
27 |
+
|
28 |
+
4.17 - 2024.05.24 | [Google Drive](https://drive.google.com/file/d/1962p_lEWo_kLTEynarNaRYRNVdaiQG2k/view?usp=share_link) [百度网盘](https://pan.baidu.com/s/1bMzTYoJKZXsoxuSBmzj6VQ?pwd=as37) : Add gram loss from [FILM](https://github.com/google-research/frame-interpolation/blob/69f8708f08e62c2edf46a27616a4bfcf083e2076/losses/vgg19_loss.py) || [4.17.lite](https://drive.google.com/file/d/1e9Qb4rm20UAsO7h9VILDwrpvTSHWWW8b/view?usp=share_link)
|
29 |
+
|
30 |
+
4.15 - 2024.03.11 | [Google Drive](https://drive.google.com/file/d/1xlem7cfKoMaiLzjoeum8KIQTYO-9iqG5/view?usp=sharing) [百度网盘](https://pan.baidu.com/s/1IGNIX7JXGUwI_tfoafYHqA?pwd=bg0b) || [4.15.lite](https://drive.google.com/file/d/1BoOF-qSEnTPDjpKG1sBTa6k7Sv5_-k7z/view?usp=sharing) || 4.14 - 2024.01.08 | [Google Drive](https://drive.google.com/file/d/1BjuEY7CHZv1wzmwXSQP9ZTj0mLWu_4xy/view?usp=share_link) [百度网盘](https://pan.baidu.com/s/1d-W64lRsJTqNsgWoXYiaWQ?pwd=xawa) || [4.14.lite](https://drive.google.com/file/d/1eULia_onOtRXHMAW9VeDL8N2_7z8J1ba/view?usp=share_link)
|
31 |
+
|
32 |
+
v4.9.2 - 2023.11.01 | [Google Drive](https://drive.google.com/file/d/1UssCvbL8N-ty0xIKM5G5ZTEgp9o4w3hp/view?usp=sharing) [百度网盘](https://pan.baidu.com/s/18cbx3EP4HWgSa1vkcXvvyw?pwd=swr9) || v4.3 - 2022.8.17 | [Google Drive](https://drive.google.com/file/d/1xrNofTGMHdt9sQv7-EOG0EChl8hZW_cU/view?usp=sharing) [百度网盘](https://pan.baidu.com/s/12AUAeZLZf5E1_Zx6WkS3xw?pwd=q83a)
|
33 |
+
|
34 |
+
v3.8 - 2021.6.17 | [Google Drive](https://drive.google.com/file/d/1O5KfS3KzZCY3imeCr2LCsntLhutKuAqj/view?usp=sharing) [百度网盘](https://pan.baidu.com/s/1X-jpWBZWe-IQBoNAsxo2mA?pwd=kxr3) || v3.1 - 2021.5.17 | [Google Drive](https://drive.google.com/file/d/1xn4R3TQyFhtMXN2pa3lRB8cd4E1zckQe/view?usp=sharing) [百度网盘](https://pan.baidu.com/s/1W4p_Ni04HLI_jTy45sVodA?pwd=64bz)
|
35 |
+
|
36 |
+
[More Older Version](https://github.com/megvii-research/ECCV2022-RIFE/issues/41)
|
37 |
+
|
38 |
+
### Installation
|
39 |
+
python <= 3.11
|
40 |
+
```
|
41 |
+
git clone [email protected]:hzwer/Practical-RIFE.git
|
42 |
+
cd Practical-RIFE
|
43 |
+
pip3 install -r requirements.txt
|
44 |
+
```
|
45 |
+
Download a model from the model list and put *.py and flownet.pkl on train_log/
|
46 |
+
### Run
|
47 |
+
|
48 |
+
You can use our [demo video](https://drive.google.com/file/d/1i3xlKb7ax7Y70khcTcuePi6E7crO_dFc/view?usp=sharing) or your video.
|
49 |
+
```
|
50 |
+
python3 inference_video.py --multi=2 --video=video.mp4
|
51 |
+
```
|
52 |
+
(generate video_2X_xxfps.mp4)
|
53 |
+
```
|
54 |
+
python3 inference_video.py --multi=4 --video=video.mp4
|
55 |
+
```
|
56 |
+
(for 4X interpolation)
|
57 |
+
```
|
58 |
+
python3 inference_video.py --multi=2 --video=video.mp4 --scale=0.5
|
59 |
+
```
|
60 |
+
(If your video has high resolution, such as 4K, we recommend set --scale=0.5 (default 1.0))
|
61 |
+
```
|
62 |
+
python3 inference_video.py --multi=4 --img=input/
|
63 |
+
```
|
64 |
+
(to read video from pngs, like input/0.png ... input/612.png, ensure that the png names are numbers)
|
65 |
+
|
66 |
+
Parameter descriptions:
|
67 |
+
|
68 |
+
--img / --video: The input file address
|
69 |
+
|
70 |
+
--output: Output video name 'xxx.mp4'
|
71 |
+
|
72 |
+
--model: Directory with trained model files
|
73 |
+
|
74 |
+
--UHD: It is equivalent to setting scale=0.5
|
75 |
+
|
76 |
+
--montage: Splice the generated video with the original video, like [this demo](https://www.youtube.com/watch?v=kUQ7KK6MhHw)
|
77 |
+
|
78 |
+
--fps: Set output FPS manually
|
79 |
+
|
80 |
+
--ext: Set output video format, default: mp4
|
81 |
+
|
82 |
+
--multi: Interpolation frame rate multiplier
|
83 |
+
|
84 |
+
--exp: Set --multi to 2^(--exp)
|
85 |
+
|
86 |
+
--skip: It's no longer useful refer to [issue 207](https://github.com/hzwer/ECCV2022-RIFE/issues/207)
|
87 |
+
|
88 |
+
|
89 |
+
### Model training
|
90 |
+
The whole repo can be downloaded from [v4.0](https://drive.google.com/file/d/1zoSz7b8c6kUsnd4gYZ_6TrKxa7ghHJWW/view?usp=sharing), [v4.12](https://drive.google.com/file/d/1IHB35zhO4rr-JSMnpRvHhU9U65Z4giWv/view?usp=sharing), [v4.15](https://drive.google.com/file/d/19sUMZ-6H7g_hYDjTcqxYu9kE7TqnfS3k/view?usp=sharing), [v4.18](https://drive.google.com/file/d/1g8D2foww7DhGLIxtaDLr9fU3y-ByOw4B/view?usp=share_link), [v4.25](https://drive.google.com/file/d/1_l4OgBp3GrrHOcQB87xXCI7OtTzyeXZL/view?usp=share_link). However, we currently do not have the time to organize them well, they are for reference only.
|
91 |
+
|
92 |
+
## Video Enhancement
|
93 |
+
|
94 |
+
<img width="710" alt="image" src="https://github.com/hzwer/Practical-RIFE/assets/10103856/5bae134c-0747-4084-bbab-37b1595352f1">
|
95 |
+
|
96 |
+
We are developing a practical model of [SAFA](https://github.com/megvii-research/WACV2024-SAFA). Welcome to check its [demo](https://www.youtube.com/watch?v=QII2KQSBBwk) ([BiliBili](https://www.bilibili.com/video/BV1Up4y1d7kF/)) and provide advice.
|
97 |
+
|
98 |
+
v0.5 - 2023.12.26 | [Google Drive](https://drive.google.com/file/d/1OLO9hLV97ZQ4uRV2-aQqgnwhbKMMt6TX/view?usp=sharing)
|
99 |
+
|
100 |
+
```
|
101 |
+
python3 inference_video_enhance.py --video=demo.mp4
|
102 |
+
```
|
103 |
+
|
104 |
+
## Citation
|
105 |
+
|
106 |
+
```
|
107 |
+
@inproceedings{huang2022rife,
|
108 |
+
title={Real-Time Intermediate Flow Estimation for Video Frame Interpolation},
|
109 |
+
author={Huang, Zhewei and Zhang, Tianyuan and Heng, Wen and Shi, Boxin and Zhou, Shuchang},
|
110 |
+
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
|
111 |
+
year={2022}
|
112 |
+
}
|
113 |
+
```
|
114 |
+
```
|
115 |
+
@inproceedings{huang2024safa,
|
116 |
+
title={Scale-Adaptive Feature Aggregation for Efficient Space-Time Video Super-Resolution},
|
117 |
+
author={Huang, Zhewei and Huang, Ailin and Hu, Xiaotao and Hu, Chen and Xu, Jun and Zhou, Shuchang},
|
118 |
+
booktitle={Winter Conference on Applications of Computer Vision (WACV)},
|
119 |
+
year={2024}
|
120 |
+
}
|
121 |
+
```
|
122 |
+
|
123 |
+
## Reference
|
124 |
+
|
125 |
+
Optical Flow:
|
126 |
+
[ARFlow](https://github.com/lliuz/ARFlow) [pytorch-liteflownet](https://github.com/sniklaus/pytorch-liteflownet) [RAFT](https://github.com/princeton-vl/RAFT) [pytorch-PWCNet](https://github.com/sniklaus/pytorch-pwc)
|
127 |
+
|
128 |
+
Video Interpolation:
|
129 |
+
[DVF](https://github.com/lxx1991/pytorch-voxel-flow) [TOflow](https://github.com/Coldog2333/pytoflow) [SepConv](https://github.com/sniklaus/sepconv-slomo) [DAIN](https://github.com/baowenbo/DAIN) [CAIN](https://github.com/myungsub/CAIN) [MEMC-Net](https://github.com/baowenbo/MEMC-Net) [SoftSplat](https://github.com/sniklaus/softmax-splatting) [BMBC](https://github.com/JunHeum/BMBC) [EDSC](https://github.com/Xianhang/EDSC-pytorch) [EQVI](https://github.com/lyh-18/EQVI) [RIFE](https://github.com/hzwer/arXiv2020-RIFE)
|
__pycache__/video_processing.cpython-312.pyc
ADDED
Binary file (12.1 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image, ImageFilter
|
4 |
+
import cv2
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torchvision import transforms
|
9 |
+
import warnings
|
10 |
+
from video_processing import process_video
|
11 |
+
|
12 |
+
warnings.filterwarnings("ignore")
|
13 |
+
|
14 |
+
# ZeroGPU decorator (if available)
|
15 |
+
try:
|
16 |
+
import spaces
|
17 |
+
HAS_ZEROGPU = True
|
18 |
+
except ImportError:
|
19 |
+
HAS_ZEROGPU = False
|
20 |
+
# Create a dummy decorator if spaces is not available
|
21 |
+
def spaces_gpu(func):
|
22 |
+
return func
|
23 |
+
spaces = type('spaces', (), {'GPU': spaces_gpu})()
|
24 |
+
|
25 |
+
# VAAPI acceleration check
|
26 |
+
def check_vaapi_support():
|
27 |
+
"""Check if VAAPI is available for hardware acceleration"""
|
28 |
+
try:
|
29 |
+
# Check if VAAPI devices are available
|
30 |
+
vaapi_devices = [f for f in os.listdir('/dev/dri') if f.startswith('render')]
|
31 |
+
return len(vaapi_devices) > 0
|
32 |
+
except:
|
33 |
+
return False
|
34 |
+
|
35 |
+
HAS_VAAPI = check_vaapi_support()
|
36 |
+
|
37 |
+
class TorchUpscaler:
|
38 |
+
"""PyTorch-based upscaler that can use GPU acceleration"""
|
39 |
+
|
40 |
+
def __init__(self, device='auto'):
|
41 |
+
if device == 'auto':
|
42 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
43 |
+
else:
|
44 |
+
self.device = torch.device(device)
|
45 |
+
|
46 |
+
print(f"Using device: {self.device}")
|
47 |
+
|
48 |
+
def bicubic_torch(self, image_tensor, scale_factor):
|
49 |
+
"""GPU-accelerated bicubic upscaling using PyTorch"""
|
50 |
+
return F.interpolate(
|
51 |
+
image_tensor,
|
52 |
+
scale_factor=scale_factor,
|
53 |
+
mode='bicubic',
|
54 |
+
align_corners=False,
|
55 |
+
antialias=True
|
56 |
+
)
|
57 |
+
|
58 |
+
def lanczos_torch(self, image_tensor, scale_factor):
|
59 |
+
"""GPU-accelerated Lanczos-style upscaling"""
|
60 |
+
return F.interpolate(
|
61 |
+
image_tensor,
|
62 |
+
scale_factor=scale_factor,
|
63 |
+
mode='bicubic',
|
64 |
+
align_corners=False,
|
65 |
+
antialias=True
|
66 |
+
)
|
67 |
+
|
68 |
+
def esrgan_style_upscale(self, image_tensor, scale_factor):
|
69 |
+
"""Simple ESRGAN-style upscaling using convolutions"""
|
70 |
+
b, c, h, w = image_tensor.shape
|
71 |
+
upscaled = F.interpolate(image_tensor, scale_factor=scale_factor, mode='bicubic', align_corners=False)
|
72 |
+
kernel = torch.tensor([[[[-1, -1, -1],
|
73 |
+
[-1, 9, -1],
|
74 |
+
[-1, -1, -1]]]], dtype=torch.float32, device=self.device)
|
75 |
+
kernel = kernel.repeat(c, 1, 1, 1)
|
76 |
+
sharpened = F.conv2d(upscaled, kernel, padding=1, groups=c)
|
77 |
+
result = 0.8 * upscaled + 0.2 * sharpened
|
78 |
+
return torch.clamp(result, 0, 1)
|
79 |
+
|
80 |
+
class VAAPIUpscaler:
|
81 |
+
"""VAAPI hardware-accelerated upscaler"""
|
82 |
+
|
83 |
+
def __init__(self):
|
84 |
+
self.vaapi_available = HAS_VAAPI
|
85 |
+
if self.vaapi_available:
|
86 |
+
print("VAAPI hardware acceleration available")
|
87 |
+
else:
|
88 |
+
print("VAAPI hardware acceleration not available")
|
89 |
+
|
90 |
+
def upscale_vaapi(self, image_array, scale_factor, method):
|
91 |
+
"""Use VAAPI for hardware-accelerated upscaling"""
|
92 |
+
if not self.vaapi_available:
|
93 |
+
return None
|
94 |
+
try:
|
95 |
+
h, w = image_array.shape[:2]
|
96 |
+
new_h, new_w = int(h * scale_factor), int(w * scale_factor)
|
97 |
+
if method == "VAAPI_BICUBIC":
|
98 |
+
return cv2.resize(image_array, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
|
99 |
+
elif method == "VAAPI_LANCZOS":
|
100 |
+
return cv2.resize(image_array, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
|
101 |
+
except Exception as e:
|
102 |
+
print(f"VAAPI upscaling failed: {e}")
|
103 |
+
return None
|
104 |
+
|
105 |
+
torch_upscaler = TorchUpscaler()
|
106 |
+
vaapi_upscaler = VAAPIUpscaler()
|
107 |
+
|
108 |
+
@spaces.GPU if HAS_ZEROGPU else lambda x: x
|
109 |
+
def upscale_image_accelerated(image, scale_factor, method, enhance_quality, use_gpu_acceleration):
|
110 |
+
if image is None:
|
111 |
+
return None
|
112 |
+
|
113 |
+
original_width, original_height = image.size
|
114 |
+
new_width = int(original_width * scale_factor)
|
115 |
+
new_height = int(original_height * scale_factor)
|
116 |
+
|
117 |
+
try:
|
118 |
+
if use_gpu_acceleration and torch.cuda.is_available():
|
119 |
+
print("Using GPU acceleration")
|
120 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
121 |
+
image_tensor = transform(image).unsqueeze(0).to(torch_upscaler.device)
|
122 |
+
|
123 |
+
if method == "GPU_Bicubic":
|
124 |
+
upscaled_tensor = torch_upscaler.bicubic_torch(image_tensor, scale_factor)
|
125 |
+
elif method == "GPU_Lanczos":
|
126 |
+
upscaled_tensor = torch_upscaler.lanczos_torch(image_tensor, scale_factor)
|
127 |
+
elif method == "GPU_ESRGAN_Style":
|
128 |
+
upscaled_tensor = torch_upscaler.esrgan_style_upscale(image_tensor, scale_factor)
|
129 |
+
else:
|
130 |
+
upscaled_tensor = torch_upscaler.bicubic_torch(image_tensor, scale_factor)
|
131 |
+
|
132 |
+
upscaled = transforms.ToPILImage()(upscaled_tensor.squeeze(0).cpu())
|
133 |
+
|
134 |
+
elif method.startswith("VAAPI_") and HAS_VAAPI:
|
135 |
+
print("Using VAAPI acceleration")
|
136 |
+
img_array = np.array(image)
|
137 |
+
upscaled_array = vaapi_upscaler.upscale_vaapi(img_array, scale_factor, method)
|
138 |
+
upscaled = Image.fromarray(upscaled_array) if upscaled_array is not None else image.resize((new_width, new_height), Image.BICUBIC)
|
139 |
+
|
140 |
+
else:
|
141 |
+
print("Using CPU methods")
|
142 |
+
if method == "Bicubic":
|
143 |
+
upscaled = image.resize((new_width, new_height), Image.BICUBIC)
|
144 |
+
elif method == "Lanczos":
|
145 |
+
upscaled = image.resize((new_width, new_height), Image.LANCZOS)
|
146 |
+
else:
|
147 |
+
upscaled = image.resize((new_width, new_height), Image.BICUBIC)
|
148 |
+
|
149 |
+
if enhance_quality:
|
150 |
+
upscaled = upscaled.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))
|
151 |
+
|
152 |
+
return upscaled
|
153 |
+
|
154 |
+
except Exception as e:
|
155 |
+
print(f"Error during upscaling: {e}")
|
156 |
+
return image
|
157 |
+
|
158 |
+
def get_available_methods():
|
159 |
+
methods = ["Bicubic", "Lanczos"]
|
160 |
+
if torch.cuda.is_available():
|
161 |
+
methods.extend(["GPU_Bicubic", "GPU_Lanczos", "GPU_ESRGAN_Style"])
|
162 |
+
if HAS_VAAPI:
|
163 |
+
methods.extend(["VAAPI_BICUBIC", "VAAPI_LANCZOS"])
|
164 |
+
return methods
|
165 |
+
|
166 |
+
def get_system_info():
|
167 |
+
info = []
|
168 |
+
if torch.cuda.is_available():
|
169 |
+
gpu_name = torch.cuda.get_device_name(0)
|
170 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
171 |
+
info.append(f"🚀 CUDA GPU: {gpu_name} ({gpu_memory:.1f} GB)")
|
172 |
+
else:
|
173 |
+
info.append("❌ CUDA not available")
|
174 |
+
if HAS_ZEROGPU:
|
175 |
+
info.append("✅ ZeroGPU support enabled")
|
176 |
+
if HAS_VAAPI:
|
177 |
+
info.append("✅ VAAPI hardware acceleration available")
|
178 |
+
return "\n".join(info)
|
179 |
+
|
180 |
+
def process_and_info_accelerated(image, scale_factor, method, enhance_quality, use_gpu_acceleration):
|
181 |
+
if image is None:
|
182 |
+
return None, "Please upload an image first"
|
183 |
+
|
184 |
+
original_info = f"Original: {image.size[0]} × {image.size[1]} pixels"
|
185 |
+
result = upscale_image_accelerated(image, scale_factor, method, enhance_quality, use_gpu_acceleration)
|
186 |
+
if result is None:
|
187 |
+
return None, "Error processing image"
|
188 |
+
|
189 |
+
result_info = f"Upscaled: {result.size[0]} × {result.size[1]} pixels"
|
190 |
+
accel_info = "GPU/Hardware" if use_gpu_acceleration else "CPU"
|
191 |
+
|
192 |
+
combined_info = f"""
|
193 |
+
## Processing Details
|
194 |
+
{original_info}
|
195 |
+
{result_info}
|
196 |
+
**Scale Factor:** {scale_factor}x
|
197 |
+
**Method:** {method}
|
198 |
+
**Acceleration:** {accel_info}
|
199 |
+
**Quality Enhancement:** {'✅' if enhance_quality else '❌'}
|
200 |
+
|
201 |
+
## System Status
|
202 |
+
{get_system_info()}
|
203 |
+
"""
|
204 |
+
return result, combined_info
|
205 |
+
|
206 |
+
def create_accelerated_upscaler_ui():
|
207 |
+
available_methods = get_available_methods()
|
208 |
+
|
209 |
+
gr.Markdown("## 🚀 Accelerated Image Upscaler")
|
210 |
+
with gr.Row():
|
211 |
+
with gr.Column(scale=1):
|
212 |
+
input_image = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
|
213 |
+
scale_factor = gr.Slider(minimum=1.5, maximum=4.0, step=0.5, value=2.0, label="Scale Factor")
|
214 |
+
method = gr.Dropdown(choices=available_methods, value=available_methods[0], label="Upscaling Method")
|
215 |
+
use_gpu_acceleration = gr.Checkbox(label="Use GPU Acceleration", value=torch.cuda.is_available())
|
216 |
+
enhance_quality = gr.Checkbox(label="Apply Quality Enhancement", value=True)
|
217 |
+
process_btn = gr.Button("🚀 Upscale Image", variant="primary")
|
218 |
+
|
219 |
+
with gr.Column(scale=2):
|
220 |
+
output_image = gr.Image(label="Upscaled Image", type="pil")
|
221 |
+
image_info = gr.Markdown(value=f"## System Status\n{get_system_info()}", label="Processing Information")
|
222 |
+
|
223 |
+
process_btn.click(
|
224 |
+
fn=process_and_info_accelerated,
|
225 |
+
inputs=[input_image, scale_factor, method, enhance_quality, use_gpu_acceleration],
|
226 |
+
outputs=[output_image, image_info]
|
227 |
+
)
|
228 |
+
|
229 |
+
def create_video_interface_ui():
|
230 |
+
gr.Markdown("## 🚀 Video Upscaler and Frame Interpolator")
|
231 |
+
with gr.Row():
|
232 |
+
with gr.Column(scale=1):
|
233 |
+
input_video = gr.Video(label="Upload Video", sources=["upload"])
|
234 |
+
scale_factor = gr.Slider(minimum=1.5, maximum=4.0, step=0.5, value=2.0, label="Scale Factor")
|
235 |
+
multi = gr.Slider(minimum=2, maximum=8, step=1, value=2, label="Frame Multiplier")
|
236 |
+
use_gpu_acceleration = gr.Checkbox(label="Use GPU Acceleration", value=torch.cuda.is_available())
|
237 |
+
process_btn = gr.Button("🚀 Process Video", variant="primary")
|
238 |
+
|
239 |
+
with gr.Column(scale=2):
|
240 |
+
output_video = gr.Video(label="Processed Video")
|
241 |
+
processing_info = gr.Markdown(value=f"## System Status\n{get_system_info()}", label="Processing Information")
|
242 |
+
|
243 |
+
process_btn.click(
|
244 |
+
fn=process_video_wrapper,
|
245 |
+
inputs=[input_video, scale_factor, multi, use_gpu_acceleration],
|
246 |
+
outputs=[output_video, processing_info]
|
247 |
+
)
|
248 |
+
|
249 |
+
def process_video_wrapper(video_path, scale_factor, multi, use_gpu):
|
250 |
+
if video_path is None:
|
251 |
+
return None, "Please upload a video first"
|
252 |
+
|
253 |
+
output_path = "temp_output.mp4"
|
254 |
+
modelDir = 'rife/train_log'
|
255 |
+
|
256 |
+
processed_video_path = process_video(
|
257 |
+
video=video_path,
|
258 |
+
output=output_path,
|
259 |
+
modelDir=modelDir,
|
260 |
+
fp16=use_gpu,
|
261 |
+
UHD=False,
|
262 |
+
scale=scale_factor,
|
263 |
+
skip=False,
|
264 |
+
fps=None,
|
265 |
+
png=False,
|
266 |
+
ext='mp4',
|
267 |
+
exp=1,
|
268 |
+
multi=multi
|
269 |
+
)
|
270 |
+
|
271 |
+
info = f"""
|
272 |
+
## Processing Details
|
273 |
+
**Scale Factor:** {scale_factor}x
|
274 |
+
**Frame Multiplier:** {multi}x
|
275 |
+
**Acceleration:** {'GPU' if use_gpu else 'CPU'}
|
276 |
+
|
277 |
+
## System Status
|
278 |
+
{get_system_info()}
|
279 |
+
"""
|
280 |
+
return processed_video_path, info
|
281 |
+
|
282 |
+
with gr.Blocks(title="Accelerated Media Processor", theme=gr.themes.Soft()) as demo:
|
283 |
+
with gr.Tab("Image Upscaler"):
|
284 |
+
create_accelerated_upscaler_ui()
|
285 |
+
with gr.Tab("Video Processing"):
|
286 |
+
create_video_interface_ui()
|
287 |
+
|
288 |
+
if __name__ == "__main__":
|
289 |
+
demo.launch(
|
290 |
+
server_name="0.0.0.0",
|
291 |
+
server_port=7860,
|
292 |
+
share=False,
|
293 |
+
debug=True
|
294 |
+
)
|
demo/I0_0.png
ADDED
![]() |
Git LFS Details
|
demo/I0_1.png
ADDED
![]() |
Git LFS Details
|
demo/I0_slomo_clipped.gif
ADDED
![]() |
Git LFS Details
|
demo/I2_0.png
ADDED
![]() |
Git LFS Details
|
demo/I2_1.png
ADDED
![]() |
Git LFS Details
|
demo/I2_slomo_clipped.gif
ADDED
![]() |
Git LFS Details
|
demo/i0.png
ADDED
![]() |
Git LFS Details
|
demo/i1.png
ADDED
![]() |
Git LFS Details
|
inference_img.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
from torch.nn import functional as F
|
6 |
+
import warnings
|
7 |
+
warnings.filterwarnings("ignore")
|
8 |
+
|
9 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
+
torch.set_grad_enabled(False)
|
11 |
+
if torch.cuda.is_available():
|
12 |
+
torch.backends.cudnn.enabled = True
|
13 |
+
torch.backends.cudnn.benchmark = True
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
|
16 |
+
parser.add_argument('--img', dest='img', nargs=2, required=True)
|
17 |
+
parser.add_argument('--exp', default=4, type=int)
|
18 |
+
parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
|
19 |
+
parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold')
|
20 |
+
parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
|
21 |
+
parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
|
22 |
+
|
23 |
+
args = parser.parse_args()
|
24 |
+
|
25 |
+
try:
|
26 |
+
try:
|
27 |
+
from model.RIFE_HDv2 import Model
|
28 |
+
model = Model()
|
29 |
+
model.load_model(args.modelDir, -1)
|
30 |
+
print("Loaded v2.x HD model.")
|
31 |
+
except:
|
32 |
+
from train_log.RIFE_HDv3 import Model
|
33 |
+
model = Model()
|
34 |
+
model.load_model(args.modelDir, -1)
|
35 |
+
print("Loaded v3.x HD model.")
|
36 |
+
except:
|
37 |
+
from model.RIFE_HD import Model
|
38 |
+
model = Model()
|
39 |
+
model.load_model(args.modelDir, -1)
|
40 |
+
print("Loaded v1.x HD model")
|
41 |
+
if not hasattr(model, 'version'):
|
42 |
+
model.version = 0
|
43 |
+
model.eval()
|
44 |
+
model.device()
|
45 |
+
|
46 |
+
if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
|
47 |
+
img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
|
48 |
+
img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
|
49 |
+
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
|
50 |
+
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
|
51 |
+
|
52 |
+
else:
|
53 |
+
img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED)
|
54 |
+
img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED)
|
55 |
+
img0 = cv2.resize(img0, (448, 256))
|
56 |
+
img1 = cv2.resize(img1, (448, 256))
|
57 |
+
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
|
58 |
+
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
|
59 |
+
|
60 |
+
n, c, h, w = img0.shape
|
61 |
+
ph = ((h - 1) // 64 + 1) * 64
|
62 |
+
pw = ((w - 1) // 64 + 1) * 64
|
63 |
+
padding = (0, pw - w, 0, ph - h)
|
64 |
+
img0 = F.pad(img0, padding)
|
65 |
+
img1 = F.pad(img1, padding)
|
66 |
+
|
67 |
+
|
68 |
+
if args.ratio:
|
69 |
+
if model.version >= 3.9:
|
70 |
+
img_list = [img0, model.inference(img0, img1, args.ratio), img1]
|
71 |
+
else:
|
72 |
+
img0_ratio = 0.0
|
73 |
+
img1_ratio = 1.0
|
74 |
+
if args.ratio <= img0_ratio + args.rthreshold / 2:
|
75 |
+
middle = img0
|
76 |
+
elif args.ratio >= img1_ratio - args.rthreshold / 2:
|
77 |
+
middle = img1
|
78 |
+
else:
|
79 |
+
tmp_img0 = img0
|
80 |
+
tmp_img1 = img1
|
81 |
+
for inference_cycle in range(args.rmaxcycles):
|
82 |
+
middle = model.inference(tmp_img0, tmp_img1)
|
83 |
+
middle_ratio = ( img0_ratio + img1_ratio ) / 2
|
84 |
+
if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2):
|
85 |
+
break
|
86 |
+
if args.ratio > middle_ratio:
|
87 |
+
tmp_img0 = middle
|
88 |
+
img0_ratio = middle_ratio
|
89 |
+
else:
|
90 |
+
tmp_img1 = middle
|
91 |
+
img1_ratio = middle_ratio
|
92 |
+
img_list.append(middle)
|
93 |
+
img_list.append(img1)
|
94 |
+
else:
|
95 |
+
if model.version >= 3.9:
|
96 |
+
img_list = [img0]
|
97 |
+
n = 2 ** args.exp
|
98 |
+
for i in range(n-1):
|
99 |
+
img_list.append(model.inference(img0, img1, (i+1) * 1. / n))
|
100 |
+
img_list.append(img1)
|
101 |
+
else:
|
102 |
+
img_list = [img0, img1]
|
103 |
+
for i in range(args.exp):
|
104 |
+
tmp = []
|
105 |
+
for j in range(len(img_list) - 1):
|
106 |
+
mid = model.inference(img_list[j], img_list[j + 1])
|
107 |
+
tmp.append(img_list[j])
|
108 |
+
tmp.append(mid)
|
109 |
+
tmp.append(img1)
|
110 |
+
img_list = tmp
|
111 |
+
|
112 |
+
if not os.path.exists('output'):
|
113 |
+
os.mkdir('output')
|
114 |
+
for i in range(len(img_list)):
|
115 |
+
if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
|
116 |
+
cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
|
117 |
+
else:
|
118 |
+
cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
|
inference_img_SR.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
from torch.nn import functional as F
|
6 |
+
import warnings
|
7 |
+
warnings.filterwarnings("ignore")
|
8 |
+
|
9 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
+
torch.set_grad_enabled(False)
|
11 |
+
if torch.cuda.is_available():
|
12 |
+
torch.backends.cudnn.enabled = True
|
13 |
+
torch.backends.cudnn.benchmark = True
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser(description='STVSR for a pair of images')
|
16 |
+
parser.add_argument('--img', dest='img', nargs=2, required=True)
|
17 |
+
parser.add_argument('--exp', default=2, type=int)
|
18 |
+
parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
|
19 |
+
parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
|
20 |
+
|
21 |
+
args = parser.parse_args()
|
22 |
+
|
23 |
+
from train_log.model import Model
|
24 |
+
model = Model()
|
25 |
+
model.device()
|
26 |
+
model.load_model('train_log')
|
27 |
+
model.eval()
|
28 |
+
|
29 |
+
if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
|
30 |
+
img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
|
31 |
+
img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
|
32 |
+
img0 = cv2.resize(img0, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
|
33 |
+
img1 = cv2.resize(img1, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
|
34 |
+
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
|
35 |
+
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
|
36 |
+
else:
|
37 |
+
img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED)
|
38 |
+
img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED)
|
39 |
+
img0 = cv2.resize(img0, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
|
40 |
+
img1 = cv2.resize(img1, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
|
41 |
+
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
|
42 |
+
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
|
43 |
+
|
44 |
+
n, c, h, w = img0.shape
|
45 |
+
ph = ((h - 1) // 32 + 1) * 32
|
46 |
+
pw = ((w - 1) // 32 + 1) * 32
|
47 |
+
padding = (0, pw - w, 0, ph - h)
|
48 |
+
img0 = F.pad(img0, padding)
|
49 |
+
img1 = F.pad(img1, padding)
|
50 |
+
|
51 |
+
if args.ratio:
|
52 |
+
print('ratio={}'.format(args.ratio))
|
53 |
+
img_list = model.inference(img0, img1, timestep=args.ratio)
|
54 |
+
else:
|
55 |
+
n = 2 ** args.exp - 1
|
56 |
+
time_list = [0]
|
57 |
+
for i in range(n):
|
58 |
+
time_list.append((i+1) * 1. / (n+1))
|
59 |
+
time_list.append(1)
|
60 |
+
print(time_list)
|
61 |
+
img_list = model.inference(img0, img1, timestep=time_list)
|
62 |
+
|
63 |
+
if not os.path.exists('output'):
|
64 |
+
os.mkdir('output')
|
65 |
+
for i in range(len(img_list)):
|
66 |
+
if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
|
67 |
+
cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
|
68 |
+
else:
|
69 |
+
cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
|
inference_video.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
from torch.nn import functional as F
|
8 |
+
import warnings
|
9 |
+
import _thread
|
10 |
+
import skvideo.io
|
11 |
+
from queue import Queue, Empty
|
12 |
+
from model.pytorch_msssim import ssim_matlab
|
13 |
+
|
14 |
+
warnings.filterwarnings("ignore")
|
15 |
+
|
16 |
+
def transferAudio(sourceVideo, targetVideo):
|
17 |
+
import shutil
|
18 |
+
import moviepy.editor
|
19 |
+
tempAudioFileName = "./temp/audio.mkv"
|
20 |
+
|
21 |
+
# split audio from original video file and store in "temp" directory
|
22 |
+
if True:
|
23 |
+
|
24 |
+
# clear old "temp" directory if it exits
|
25 |
+
if os.path.isdir("temp"):
|
26 |
+
# remove temp directory
|
27 |
+
shutil.rmtree("temp")
|
28 |
+
# create new "temp" directory
|
29 |
+
os.makedirs("temp")
|
30 |
+
# extract audio from video
|
31 |
+
os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName))
|
32 |
+
|
33 |
+
targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
|
34 |
+
os.rename(targetVideo, targetNoAudio)
|
35 |
+
# combine audio file and new video file
|
36 |
+
os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
|
37 |
+
|
38 |
+
if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac
|
39 |
+
tempAudioFileName = "./temp/audio.m4a"
|
40 |
+
os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName))
|
41 |
+
os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
|
42 |
+
if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format
|
43 |
+
os.rename(targetNoAudio, targetVideo)
|
44 |
+
print("Audio transfer failed. Interpolated video will have no audio")
|
45 |
+
else:
|
46 |
+
print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
|
47 |
+
|
48 |
+
# remove audio-less video
|
49 |
+
os.remove(targetNoAudio)
|
50 |
+
else:
|
51 |
+
os.remove(targetNoAudio)
|
52 |
+
|
53 |
+
# remove temp directory
|
54 |
+
shutil.rmtree("temp")
|
55 |
+
|
56 |
+
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
|
57 |
+
parser.add_argument('--video', dest='video', type=str, default=None)
|
58 |
+
parser.add_argument('--output', dest='output', type=str, default=None)
|
59 |
+
parser.add_argument('--img', dest='img', type=str, default=None)
|
60 |
+
parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
|
61 |
+
parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
|
62 |
+
parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
|
63 |
+
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
|
64 |
+
parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video')
|
65 |
+
parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
|
66 |
+
parser.add_argument('--fps', dest='fps', type=int, default=None)
|
67 |
+
parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs')
|
68 |
+
parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension')
|
69 |
+
parser.add_argument('--exp', dest='exp', type=int, default=1)
|
70 |
+
parser.add_argument('--multi', dest='multi', type=int, default=2)
|
71 |
+
|
72 |
+
args = parser.parse_args()
|
73 |
+
if args.exp != 1:
|
74 |
+
args.multi = (2 ** args.exp)
|
75 |
+
assert (not args.video is None or not args.img is None)
|
76 |
+
if args.skip:
|
77 |
+
print("skip flag is abandoned, please refer to issue #207.")
|
78 |
+
if args.UHD and args.scale==1.0:
|
79 |
+
args.scale = 0.5
|
80 |
+
assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
|
81 |
+
if not args.img is None:
|
82 |
+
args.png = True
|
83 |
+
|
84 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
85 |
+
torch.set_grad_enabled(False)
|
86 |
+
if torch.cuda.is_available():
|
87 |
+
torch.backends.cudnn.enabled = True
|
88 |
+
torch.backends.cudnn.benchmark = True
|
89 |
+
if(args.fp16):
|
90 |
+
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
91 |
+
|
92 |
+
from train_log.RIFE_HDv3 import Model
|
93 |
+
model = Model()
|
94 |
+
if not hasattr(model, 'version'):
|
95 |
+
model.version = 0
|
96 |
+
model.load_model(args.modelDir, -1)
|
97 |
+
print("Loaded 3.x/4.x HD model.")
|
98 |
+
model.eval()
|
99 |
+
model.device()
|
100 |
+
|
101 |
+
if not args.video is None:
|
102 |
+
videoCapture = cv2.VideoCapture(args.video)
|
103 |
+
fps = videoCapture.get(cv2.CAP_PROP_FPS)
|
104 |
+
tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
|
105 |
+
videoCapture.release()
|
106 |
+
if args.fps is None:
|
107 |
+
fpsNotAssigned = True
|
108 |
+
args.fps = fps * args.multi
|
109 |
+
else:
|
110 |
+
fpsNotAssigned = False
|
111 |
+
videogen = skvideo.io.vreader(args.video)
|
112 |
+
lastframe = next(videogen)
|
113 |
+
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
|
114 |
+
video_path_wo_ext, ext = os.path.splitext(args.video)
|
115 |
+
print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))
|
116 |
+
if args.png == False and fpsNotAssigned == True:
|
117 |
+
print("The audio will be merged after interpolation process")
|
118 |
+
else:
|
119 |
+
print("Will not merge audio because using png or fps flag!")
|
120 |
+
else:
|
121 |
+
videogen = []
|
122 |
+
for f in os.listdir(args.img):
|
123 |
+
if 'png' in f:
|
124 |
+
videogen.append(f)
|
125 |
+
tot_frame = len(videogen)
|
126 |
+
videogen.sort(key= lambda x:int(x[:-4]))
|
127 |
+
lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
|
128 |
+
videogen = videogen[1:]
|
129 |
+
h, w, _ = lastframe.shape
|
130 |
+
vid_out_name = None
|
131 |
+
vid_out = None
|
132 |
+
if args.png:
|
133 |
+
if not os.path.exists('vid_out'):
|
134 |
+
os.mkdir('vid_out')
|
135 |
+
else:
|
136 |
+
if args.output is not None:
|
137 |
+
vid_out_name = args.output
|
138 |
+
else:
|
139 |
+
vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.multi, int(np.round(args.fps)), args.ext)
|
140 |
+
vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))
|
141 |
+
|
142 |
+
def clear_write_buffer(user_args, write_buffer):
|
143 |
+
cnt = 0
|
144 |
+
while True:
|
145 |
+
item = write_buffer.get()
|
146 |
+
if item is None:
|
147 |
+
break
|
148 |
+
if user_args.png:
|
149 |
+
cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
|
150 |
+
cnt += 1
|
151 |
+
else:
|
152 |
+
vid_out.write(item[:, :, ::-1])
|
153 |
+
|
154 |
+
def build_read_buffer(user_args, read_buffer, videogen):
|
155 |
+
try:
|
156 |
+
for frame in videogen:
|
157 |
+
if not user_args.img is None:
|
158 |
+
frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
|
159 |
+
if user_args.montage:
|
160 |
+
frame = frame[:, left: left + w]
|
161 |
+
read_buffer.put(frame)
|
162 |
+
except:
|
163 |
+
pass
|
164 |
+
read_buffer.put(None)
|
165 |
+
|
166 |
+
def make_inference(I0, I1, n):
|
167 |
+
global model
|
168 |
+
if model.version >= 3.9:
|
169 |
+
res = []
|
170 |
+
for i in range(n):
|
171 |
+
res.append(model.inference(I0, I1, (i+1) * 1. / (n+1), args.scale))
|
172 |
+
return res
|
173 |
+
else:
|
174 |
+
middle = model.inference(I0, I1, args.scale)
|
175 |
+
if n == 1:
|
176 |
+
return [middle]
|
177 |
+
first_half = make_inference(I0, middle, n=n//2)
|
178 |
+
second_half = make_inference(middle, I1, n=n//2)
|
179 |
+
if n%2:
|
180 |
+
return [*first_half, middle, *second_half]
|
181 |
+
else:
|
182 |
+
return [*first_half, *second_half]
|
183 |
+
|
184 |
+
def pad_image(img):
|
185 |
+
if(args.fp16):
|
186 |
+
return F.pad(img, padding).half()
|
187 |
+
else:
|
188 |
+
return F.pad(img, padding)
|
189 |
+
|
190 |
+
if args.montage:
|
191 |
+
left = w // 4
|
192 |
+
w = w // 2
|
193 |
+
tmp = max(128, int(128 / args.scale))
|
194 |
+
ph = ((h - 1) // tmp + 1) * tmp
|
195 |
+
pw = ((w - 1) // tmp + 1) * tmp
|
196 |
+
padding = (0, pw - w, 0, ph - h)
|
197 |
+
pbar = tqdm(total=tot_frame)
|
198 |
+
if args.montage:
|
199 |
+
lastframe = lastframe[:, left: left + w]
|
200 |
+
write_buffer = Queue(maxsize=500)
|
201 |
+
read_buffer = Queue(maxsize=500)
|
202 |
+
_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
|
203 |
+
_thread.start_new_thread(clear_write_buffer, (args, write_buffer))
|
204 |
+
|
205 |
+
I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
|
206 |
+
I1 = pad_image(I1)
|
207 |
+
temp = None # save lastframe when processing static frame
|
208 |
+
|
209 |
+
while True:
|
210 |
+
if temp is not None:
|
211 |
+
frame = temp
|
212 |
+
temp = None
|
213 |
+
else:
|
214 |
+
frame = read_buffer.get()
|
215 |
+
if frame is None:
|
216 |
+
break
|
217 |
+
I0 = I1
|
218 |
+
I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
|
219 |
+
I1 = pad_image(I1)
|
220 |
+
I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
|
221 |
+
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
|
222 |
+
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
223 |
+
|
224 |
+
break_flag = False
|
225 |
+
if ssim > 0.996:
|
226 |
+
frame = read_buffer.get() # read a new frame
|
227 |
+
if frame is None:
|
228 |
+
break_flag = True
|
229 |
+
frame = lastframe
|
230 |
+
else:
|
231 |
+
temp = frame
|
232 |
+
I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
|
233 |
+
I1 = pad_image(I1)
|
234 |
+
I1 = model.inference(I0, I1, scale=args.scale)
|
235 |
+
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
|
236 |
+
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
237 |
+
frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
|
238 |
+
|
239 |
+
if ssim < 0.2:
|
240 |
+
output = []
|
241 |
+
for i in range(args.multi - 1):
|
242 |
+
output.append(I0)
|
243 |
+
'''
|
244 |
+
output = []
|
245 |
+
step = 1 / args.multi
|
246 |
+
alpha = 0
|
247 |
+
for i in range(args.multi - 1):
|
248 |
+
alpha += step
|
249 |
+
beta = 1-alpha
|
250 |
+
output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
|
251 |
+
'''
|
252 |
+
else:
|
253 |
+
output = make_inference(I0, I1, args.multi - 1)
|
254 |
+
|
255 |
+
if args.montage:
|
256 |
+
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
|
257 |
+
for mid in output:
|
258 |
+
mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
|
259 |
+
write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
|
260 |
+
else:
|
261 |
+
write_buffer.put(lastframe)
|
262 |
+
for mid in output:
|
263 |
+
mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
|
264 |
+
write_buffer.put(mid[:h, :w])
|
265 |
+
pbar.update(1)
|
266 |
+
lastframe = frame
|
267 |
+
if break_flag:
|
268 |
+
break
|
269 |
+
|
270 |
+
if args.montage:
|
271 |
+
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
|
272 |
+
else:
|
273 |
+
write_buffer.put(lastframe)
|
274 |
+
write_buffer.put(None)
|
275 |
+
|
276 |
+
import time
|
277 |
+
while(not write_buffer.empty()):
|
278 |
+
time.sleep(0.1)
|
279 |
+
pbar.close()
|
280 |
+
if not vid_out is None:
|
281 |
+
vid_out.release()
|
282 |
+
|
283 |
+
# move audio to new video file if appropriate
|
284 |
+
if args.png == False and fpsNotAssigned == True and not args.video is None:
|
285 |
+
try:
|
286 |
+
transferAudio(args.video, vid_out_name)
|
287 |
+
except:
|
288 |
+
print("Audio transfer failed. Interpolated video will have no audio")
|
289 |
+
targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
|
290 |
+
os.rename(targetNoAudio, vid_out_name)
|
inference_video_enhance.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
from torch.nn import functional as F
|
8 |
+
import warnings
|
9 |
+
import _thread
|
10 |
+
import skvideo.io
|
11 |
+
from queue import Queue, Empty
|
12 |
+
from model.pytorch_msssim import ssim_matlab
|
13 |
+
|
14 |
+
warnings.filterwarnings("ignore")
|
15 |
+
|
16 |
+
def transferAudio(sourceVideo, targetVideo):
|
17 |
+
import shutil
|
18 |
+
import moviepy.editor
|
19 |
+
tempAudioFileName = "./temp/audio.mkv"
|
20 |
+
|
21 |
+
# split audio from original video file and store in "temp" directory
|
22 |
+
if True:
|
23 |
+
|
24 |
+
# clear old "temp" directory if it exits
|
25 |
+
if os.path.isdir("temp"):
|
26 |
+
# remove temp directory
|
27 |
+
shutil.rmtree("temp")
|
28 |
+
# create new "temp" directory
|
29 |
+
os.makedirs("temp")
|
30 |
+
# extract audio from video
|
31 |
+
os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName))
|
32 |
+
|
33 |
+
targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
|
34 |
+
os.rename(targetVideo, targetNoAudio)
|
35 |
+
# combine audio file and new video file
|
36 |
+
os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
|
37 |
+
|
38 |
+
if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac
|
39 |
+
tempAudioFileName = "./temp/audio.m4a"
|
40 |
+
os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName))
|
41 |
+
os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
|
42 |
+
if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format
|
43 |
+
os.rename(targetNoAudio, targetVideo)
|
44 |
+
print("Audio transfer failed. Interpolated video will have no audio")
|
45 |
+
else:
|
46 |
+
print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
|
47 |
+
|
48 |
+
# remove audio-less video
|
49 |
+
os.remove(targetNoAudio)
|
50 |
+
else:
|
51 |
+
os.remove(targetNoAudio)
|
52 |
+
|
53 |
+
# remove temp directory
|
54 |
+
shutil.rmtree("temp")
|
55 |
+
|
56 |
+
parser = argparse.ArgumentParser(description='Video SR')
|
57 |
+
parser.add_argument('--video', dest='video', type=str, default=None)
|
58 |
+
parser.add_argument('--output', dest='output', type=str, default=None)
|
59 |
+
parser.add_argument('--img', dest='img', type=str, default=None)
|
60 |
+
parser.add_argument('--model', dest='modelDir', type=str, default='train_log_SAFA', help='directory with trained model files')
|
61 |
+
parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
|
62 |
+
parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs')
|
63 |
+
parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension')
|
64 |
+
|
65 |
+
args = parser.parse_args()
|
66 |
+
assert (not args.video is None or not args.img is None)
|
67 |
+
if not args.img is None:
|
68 |
+
args.png = True
|
69 |
+
|
70 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
71 |
+
torch.set_grad_enabled(False)
|
72 |
+
if torch.cuda.is_available():
|
73 |
+
torch.backends.cudnn.enabled = True
|
74 |
+
torch.backends.cudnn.benchmark = True
|
75 |
+
if(args.fp16):
|
76 |
+
print('set fp16')
|
77 |
+
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
78 |
+
|
79 |
+
try:
|
80 |
+
from train_log_SAFA.model import Model
|
81 |
+
except:
|
82 |
+
print("Please download our model from model list")
|
83 |
+
model = Model()
|
84 |
+
model.device()
|
85 |
+
model.load_model(args.modelDir)
|
86 |
+
print("Loaded SAFA model.")
|
87 |
+
model.eval()
|
88 |
+
|
89 |
+
if not args.video is None:
|
90 |
+
videoCapture = cv2.VideoCapture(args.video)
|
91 |
+
fps = videoCapture.get(cv2.CAP_PROP_FPS)
|
92 |
+
tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
|
93 |
+
videoCapture.release()
|
94 |
+
fpsNotAssigned = True
|
95 |
+
videogen = skvideo.io.vreader(args.video)
|
96 |
+
lastframe = next(videogen)
|
97 |
+
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
|
98 |
+
video_path_wo_ext, ext = os.path.splitext(args.video)
|
99 |
+
if args.png == False and fpsNotAssigned == True:
|
100 |
+
print("The audio will be merged after interpolation process")
|
101 |
+
else:
|
102 |
+
print("Will not merge audio because using png or fps flag!")
|
103 |
+
else:
|
104 |
+
videogen = []
|
105 |
+
for f in os.listdir(args.img):
|
106 |
+
if 'png' in f:
|
107 |
+
videogen.append(f)
|
108 |
+
tot_frame = len(videogen)
|
109 |
+
videogen.sort(key= lambda x:int(x[:-4]))
|
110 |
+
lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
|
111 |
+
videogen = videogen[1:]
|
112 |
+
|
113 |
+
h, w, _ = lastframe.shape
|
114 |
+
|
115 |
+
vid_out_name = None
|
116 |
+
vid_out = None
|
117 |
+
if args.png:
|
118 |
+
if not os.path.exists('vid_out'):
|
119 |
+
os.mkdir('vid_out')
|
120 |
+
else:
|
121 |
+
if args.output is not None:
|
122 |
+
vid_out_name = args.output
|
123 |
+
else:
|
124 |
+
vid_out_name = '{}_2X{}'.format(video_path_wo_ext, ext)
|
125 |
+
vid_out = cv2.VideoWriter(vid_out_name, fourcc, fps, (w, h))
|
126 |
+
|
127 |
+
def clear_write_buffer(user_args, write_buffer):
|
128 |
+
cnt = 0
|
129 |
+
while True:
|
130 |
+
item = write_buffer.get()
|
131 |
+
if item is None:
|
132 |
+
break
|
133 |
+
if user_args.png:
|
134 |
+
cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
|
135 |
+
cnt += 1
|
136 |
+
else:
|
137 |
+
vid_out.write(item[:, :, ::-1])
|
138 |
+
|
139 |
+
def build_read_buffer(user_args, read_buffer, videogen):
|
140 |
+
for frame in videogen:
|
141 |
+
if not user_args.img is None:
|
142 |
+
frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
|
143 |
+
# if user_args.montage:
|
144 |
+
# frame = frame[:, left: left + w]
|
145 |
+
read_buffer.put(frame)
|
146 |
+
read_buffer.put(None)
|
147 |
+
|
148 |
+
def pad_image(img):
|
149 |
+
if(args.fp16):
|
150 |
+
return F.pad(img, padding, mode='reflect').half()
|
151 |
+
else:
|
152 |
+
return F.pad(img, padding, mode='reflect')
|
153 |
+
|
154 |
+
tmp = 64
|
155 |
+
ph = ((h - 1) // tmp + 1) * tmp
|
156 |
+
pw = ((w - 1) // tmp + 1) * tmp
|
157 |
+
padding = (0, pw - w, 0, ph - h)
|
158 |
+
pbar = tqdm(total=tot_frame)
|
159 |
+
write_buffer = Queue(maxsize=500)
|
160 |
+
read_buffer = Queue(maxsize=500)
|
161 |
+
_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
|
162 |
+
_thread.start_new_thread(clear_write_buffer, (args, write_buffer))
|
163 |
+
|
164 |
+
while True:
|
165 |
+
frame = read_buffer.get()
|
166 |
+
if frame is None:
|
167 |
+
break
|
168 |
+
# lastframe_2x = cv2.resize(lastframe, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
|
169 |
+
# frame_2x = cv2.resize(frame, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
|
170 |
+
I0 = pad_image(torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
|
171 |
+
I1 = pad_image(torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
|
172 |
+
I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
|
173 |
+
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
|
174 |
+
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
175 |
+
if ssim < 0.2:
|
176 |
+
out = [model.inference(I0, I0, [0])[0], model.inference(I1, I1, [0])[0]]
|
177 |
+
else:
|
178 |
+
out = model.inference(I0, I1, [0, 1])
|
179 |
+
assert(len(out) == 2)
|
180 |
+
write_buffer.put((out[0][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
|
181 |
+
write_buffer.put((out[1][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
|
182 |
+
lastframe = read_buffer.get()
|
183 |
+
if lastframe is None:
|
184 |
+
break
|
185 |
+
pbar.update(2)
|
186 |
+
|
187 |
+
import time
|
188 |
+
while(not write_buffer.empty()):
|
189 |
+
time.sleep(0.1)
|
190 |
+
pbar.close()
|
191 |
+
if not vid_out is None:
|
192 |
+
vid_out.release()
|
193 |
+
|
194 |
+
# move audio to new video file if appropriate
|
195 |
+
if args.png == False and fpsNotAssigned == True and not args.video is None:
|
196 |
+
try:
|
197 |
+
transferAudio(args.video, vid_out_name)
|
198 |
+
except:
|
199 |
+
print("Audio transfer failed. Interpolated video will have no audio")
|
200 |
+
targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
|
201 |
+
os.rename(targetNoAudio, vid_out_name)
|
model/loss.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torchvision.models as models
|
6 |
+
|
7 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
8 |
+
|
9 |
+
|
10 |
+
class EPE(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super(EPE, self).__init__()
|
13 |
+
|
14 |
+
def forward(self, flow, gt, loss_mask):
|
15 |
+
loss_map = (flow - gt.detach()) ** 2
|
16 |
+
loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
|
17 |
+
return (loss_map * loss_mask)
|
18 |
+
|
19 |
+
|
20 |
+
class Ternary(nn.Module):
|
21 |
+
def __init__(self):
|
22 |
+
super(Ternary, self).__init__()
|
23 |
+
patch_size = 7
|
24 |
+
out_channels = patch_size * patch_size
|
25 |
+
self.w = np.eye(out_channels).reshape(
|
26 |
+
(patch_size, patch_size, 1, out_channels))
|
27 |
+
self.w = np.transpose(self.w, (3, 2, 0, 1))
|
28 |
+
self.w = torch.tensor(self.w).float().to(device)
|
29 |
+
|
30 |
+
def transform(self, img):
|
31 |
+
patches = F.conv2d(img, self.w, padding=3, bias=None)
|
32 |
+
transf = patches - img
|
33 |
+
transf_norm = transf / torch.sqrt(0.81 + transf**2)
|
34 |
+
return transf_norm
|
35 |
+
|
36 |
+
def rgb2gray(self, rgb):
|
37 |
+
r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
|
38 |
+
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
|
39 |
+
return gray
|
40 |
+
|
41 |
+
def hamming(self, t1, t2):
|
42 |
+
dist = (t1 - t2) ** 2
|
43 |
+
dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
|
44 |
+
return dist_norm
|
45 |
+
|
46 |
+
def valid_mask(self, t, padding):
|
47 |
+
n, _, h, w = t.size()
|
48 |
+
inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
|
49 |
+
mask = F.pad(inner, [padding] * 4)
|
50 |
+
return mask
|
51 |
+
|
52 |
+
def forward(self, img0, img1):
|
53 |
+
img0 = self.transform(self.rgb2gray(img0))
|
54 |
+
img1 = self.transform(self.rgb2gray(img1))
|
55 |
+
return self.hamming(img0, img1) * self.valid_mask(img0, 1)
|
56 |
+
|
57 |
+
|
58 |
+
class SOBEL(nn.Module):
|
59 |
+
def __init__(self):
|
60 |
+
super(SOBEL, self).__init__()
|
61 |
+
self.kernelX = torch.tensor([
|
62 |
+
[1, 0, -1],
|
63 |
+
[2, 0, -2],
|
64 |
+
[1, 0, -1],
|
65 |
+
]).float()
|
66 |
+
self.kernelY = self.kernelX.clone().T
|
67 |
+
self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
|
68 |
+
self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)
|
69 |
+
|
70 |
+
def forward(self, pred, gt):
|
71 |
+
N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
|
72 |
+
img_stack = torch.cat(
|
73 |
+
[pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
|
74 |
+
sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
|
75 |
+
sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
|
76 |
+
pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:]
|
77 |
+
pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:]
|
78 |
+
|
79 |
+
L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y)
|
80 |
+
loss = (L1X+L1Y)
|
81 |
+
return loss
|
82 |
+
|
83 |
+
class MeanShift(nn.Conv2d):
|
84 |
+
def __init__(self, data_mean, data_std, data_range=1, norm=True):
|
85 |
+
c = len(data_mean)
|
86 |
+
super(MeanShift, self).__init__(c, c, kernel_size=1)
|
87 |
+
std = torch.Tensor(data_std)
|
88 |
+
self.weight.data = torch.eye(c).view(c, c, 1, 1)
|
89 |
+
if norm:
|
90 |
+
self.weight.data.div_(std.view(c, 1, 1, 1))
|
91 |
+
self.bias.data = -1 * data_range * torch.Tensor(data_mean)
|
92 |
+
self.bias.data.div_(std)
|
93 |
+
else:
|
94 |
+
self.weight.data.mul_(std.view(c, 1, 1, 1))
|
95 |
+
self.bias.data = data_range * torch.Tensor(data_mean)
|
96 |
+
self.requires_grad = False
|
97 |
+
|
98 |
+
class VGGPerceptualLoss(torch.nn.Module):
|
99 |
+
def __init__(self, rank=0):
|
100 |
+
super(VGGPerceptualLoss, self).__init__()
|
101 |
+
blocks = []
|
102 |
+
pretrained = True
|
103 |
+
self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
|
104 |
+
self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
|
105 |
+
for param in self.parameters():
|
106 |
+
param.requires_grad = False
|
107 |
+
|
108 |
+
def forward(self, X, Y, indices=None):
|
109 |
+
X = self.normalize(X)
|
110 |
+
Y = self.normalize(Y)
|
111 |
+
indices = [2, 7, 12, 21, 30]
|
112 |
+
weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
|
113 |
+
k = 0
|
114 |
+
loss = 0
|
115 |
+
for i in range(indices[-1]):
|
116 |
+
X = self.vgg_pretrained_features[i](X)
|
117 |
+
Y = self.vgg_pretrained_features[i](Y)
|
118 |
+
if (i+1) in indices:
|
119 |
+
loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
|
120 |
+
k += 1
|
121 |
+
return loss
|
122 |
+
|
123 |
+
if __name__ == '__main__':
|
124 |
+
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
|
125 |
+
img1 = torch.tensor(np.random.normal(
|
126 |
+
0, 1, (3, 3, 256, 256))).float().to(device)
|
127 |
+
ternary_loss = Ternary()
|
128 |
+
print(ternary_loss(img0, img1).shape)
|
model/pytorch_msssim/__init__.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from math import exp
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
7 |
+
|
8 |
+
def gaussian(window_size, sigma):
|
9 |
+
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
|
10 |
+
return gauss/gauss.sum()
|
11 |
+
|
12 |
+
|
13 |
+
def create_window(window_size, channel=1):
|
14 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
15 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
|
16 |
+
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
17 |
+
return window
|
18 |
+
|
19 |
+
def create_window_3d(window_size, channel=1):
|
20 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
21 |
+
_2D_window = _1D_window.mm(_1D_window.t())
|
22 |
+
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
|
23 |
+
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
|
24 |
+
return window
|
25 |
+
|
26 |
+
|
27 |
+
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
|
28 |
+
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
|
29 |
+
if val_range is None:
|
30 |
+
if torch.max(img1) > 128:
|
31 |
+
max_val = 255
|
32 |
+
else:
|
33 |
+
max_val = 1
|
34 |
+
|
35 |
+
if torch.min(img1) < -0.5:
|
36 |
+
min_val = -1
|
37 |
+
else:
|
38 |
+
min_val = 0
|
39 |
+
L = max_val - min_val
|
40 |
+
else:
|
41 |
+
L = val_range
|
42 |
+
|
43 |
+
padd = 0
|
44 |
+
(_, channel, height, width) = img1.size()
|
45 |
+
if window is None:
|
46 |
+
real_size = min(window_size, height, width)
|
47 |
+
window = create_window(real_size, channel=channel).to(img1.device)
|
48 |
+
|
49 |
+
# mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
|
50 |
+
# mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
|
51 |
+
mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
|
52 |
+
mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
|
53 |
+
|
54 |
+
mu1_sq = mu1.pow(2)
|
55 |
+
mu2_sq = mu2.pow(2)
|
56 |
+
mu1_mu2 = mu1 * mu2
|
57 |
+
|
58 |
+
sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
|
59 |
+
sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
|
60 |
+
sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
|
61 |
+
|
62 |
+
C1 = (0.01 * L) ** 2
|
63 |
+
C2 = (0.03 * L) ** 2
|
64 |
+
|
65 |
+
v1 = 2.0 * sigma12 + C2
|
66 |
+
v2 = sigma1_sq + sigma2_sq + C2
|
67 |
+
cs = torch.mean(v1 / v2) # contrast sensitivity
|
68 |
+
|
69 |
+
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
|
70 |
+
|
71 |
+
if size_average:
|
72 |
+
ret = ssim_map.mean()
|
73 |
+
else:
|
74 |
+
ret = ssim_map.mean(1).mean(1).mean(1)
|
75 |
+
|
76 |
+
if full:
|
77 |
+
return ret, cs
|
78 |
+
return ret
|
79 |
+
|
80 |
+
|
81 |
+
def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
|
82 |
+
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
|
83 |
+
if val_range is None:
|
84 |
+
if torch.max(img1) > 128:
|
85 |
+
max_val = 255
|
86 |
+
else:
|
87 |
+
max_val = 1
|
88 |
+
|
89 |
+
if torch.min(img1) < -0.5:
|
90 |
+
min_val = -1
|
91 |
+
else:
|
92 |
+
min_val = 0
|
93 |
+
L = max_val - min_val
|
94 |
+
else:
|
95 |
+
L = val_range
|
96 |
+
|
97 |
+
padd = 0
|
98 |
+
(_, _, height, width) = img1.size()
|
99 |
+
if window is None:
|
100 |
+
real_size = min(window_size, height, width)
|
101 |
+
window = create_window_3d(real_size, channel=1).to(img1.device)
|
102 |
+
# Channel is set to 1 since we consider color images as volumetric images
|
103 |
+
|
104 |
+
img1 = img1.unsqueeze(1)
|
105 |
+
img2 = img2.unsqueeze(1)
|
106 |
+
|
107 |
+
mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
|
108 |
+
mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
|
109 |
+
|
110 |
+
mu1_sq = mu1.pow(2)
|
111 |
+
mu2_sq = mu2.pow(2)
|
112 |
+
mu1_mu2 = mu1 * mu2
|
113 |
+
|
114 |
+
sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
|
115 |
+
sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
|
116 |
+
sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
|
117 |
+
|
118 |
+
C1 = (0.01 * L) ** 2
|
119 |
+
C2 = (0.03 * L) ** 2
|
120 |
+
|
121 |
+
v1 = 2.0 * sigma12 + C2
|
122 |
+
v2 = sigma1_sq + sigma2_sq + C2
|
123 |
+
cs = torch.mean(v1 / v2) # contrast sensitivity
|
124 |
+
|
125 |
+
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
|
126 |
+
|
127 |
+
if size_average:
|
128 |
+
ret = ssim_map.mean()
|
129 |
+
else:
|
130 |
+
ret = ssim_map.mean(1).mean(1).mean(1)
|
131 |
+
|
132 |
+
if full:
|
133 |
+
return ret, cs
|
134 |
+
return ret
|
135 |
+
|
136 |
+
|
137 |
+
def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
|
138 |
+
device = img1.device
|
139 |
+
weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
|
140 |
+
levels = weights.size()[0]
|
141 |
+
mssim = []
|
142 |
+
mcs = []
|
143 |
+
for _ in range(levels):
|
144 |
+
sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
|
145 |
+
mssim.append(sim)
|
146 |
+
mcs.append(cs)
|
147 |
+
|
148 |
+
img1 = F.avg_pool2d(img1, (2, 2))
|
149 |
+
img2 = F.avg_pool2d(img2, (2, 2))
|
150 |
+
|
151 |
+
mssim = torch.stack(mssim)
|
152 |
+
mcs = torch.stack(mcs)
|
153 |
+
|
154 |
+
# Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
|
155 |
+
if normalize:
|
156 |
+
mssim = (mssim + 1) / 2
|
157 |
+
mcs = (mcs + 1) / 2
|
158 |
+
|
159 |
+
pow1 = mcs ** weights
|
160 |
+
pow2 = mssim ** weights
|
161 |
+
# From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
|
162 |
+
output = torch.prod(pow1[:-1] * pow2[-1])
|
163 |
+
return output
|
164 |
+
|
165 |
+
|
166 |
+
# Classes to re-use window
|
167 |
+
class SSIM(torch.nn.Module):
|
168 |
+
def __init__(self, window_size=11, size_average=True, val_range=None):
|
169 |
+
super(SSIM, self).__init__()
|
170 |
+
self.window_size = window_size
|
171 |
+
self.size_average = size_average
|
172 |
+
self.val_range = val_range
|
173 |
+
|
174 |
+
# Assume 3 channel for SSIM
|
175 |
+
self.channel = 3
|
176 |
+
self.window = create_window(window_size, channel=self.channel)
|
177 |
+
|
178 |
+
def forward(self, img1, img2):
|
179 |
+
(_, channel, _, _) = img1.size()
|
180 |
+
|
181 |
+
if channel == self.channel and self.window.dtype == img1.dtype:
|
182 |
+
window = self.window
|
183 |
+
else:
|
184 |
+
window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
|
185 |
+
self.window = window
|
186 |
+
self.channel = channel
|
187 |
+
|
188 |
+
_ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
|
189 |
+
dssim = (1 - _ssim) / 2
|
190 |
+
return dssim
|
191 |
+
|
192 |
+
class MSSSIM(torch.nn.Module):
|
193 |
+
def __init__(self, window_size=11, size_average=True, channel=3):
|
194 |
+
super(MSSSIM, self).__init__()
|
195 |
+
self.window_size = window_size
|
196 |
+
self.size_average = size_average
|
197 |
+
self.channel = channel
|
198 |
+
|
199 |
+
def forward(self, img1, img2):
|
200 |
+
return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
|
model/pytorch_msssim/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (10.3 kB). View file
|
|
model/warplayer.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
5 |
+
backwarp_tenGrid = {}
|
6 |
+
|
7 |
+
|
8 |
+
def warp(tenInput, tenFlow):
|
9 |
+
k = (str(tenFlow.device), str(tenFlow.size()))
|
10 |
+
if k not in backwarp_tenGrid:
|
11 |
+
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
|
12 |
+
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
|
13 |
+
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
|
14 |
+
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
|
15 |
+
backwarp_tenGrid[k] = torch.cat(
|
16 |
+
[tenHorizontal, tenVertical], 1).to(device)
|
17 |
+
|
18 |
+
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
|
19 |
+
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
|
20 |
+
|
21 |
+
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
|
22 |
+
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy>=1.16, <=1.23.5
|
2 |
+
tqdm>=4.35.0
|
3 |
+
sk-video>=1.1.10
|
4 |
+
torch>=1.3.0
|
5 |
+
opencv-python>=4.1.2
|
6 |
+
moviepy>=1.0.3
|
7 |
+
torchvision>=0.7.0
|
video_processing.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
from torch.nn import functional as F
|
7 |
+
import warnings
|
8 |
+
import _thread
|
9 |
+
import skvideo.io
|
10 |
+
from queue import Queue, Empty
|
11 |
+
from model.pytorch_msssim import ssim_matlab
|
12 |
+
|
13 |
+
warnings.filterwarnings("ignore")
|
14 |
+
|
15 |
+
def transferAudio(sourceVideo, targetVideo):
|
16 |
+
import shutil
|
17 |
+
import moviepy.editor
|
18 |
+
tempAudioFileName = "./temp/audio.mkv"
|
19 |
+
|
20 |
+
# split audio from original video file and store in "temp" directory
|
21 |
+
if True:
|
22 |
+
|
23 |
+
# clear old "temp" directory if it exits
|
24 |
+
if os.path.isdir("temp"):
|
25 |
+
# remove temp directory
|
26 |
+
shutil.rmtree("temp")
|
27 |
+
# create new "temp" directory
|
28 |
+
os.makedirs("temp")
|
29 |
+
# extract audio from video
|
30 |
+
os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName))
|
31 |
+
|
32 |
+
targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
|
33 |
+
os.rename(targetVideo, targetNoAudio)
|
34 |
+
# combine audio file and new video file
|
35 |
+
os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
|
36 |
+
|
37 |
+
if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac
|
38 |
+
tempAudioFileName = "./temp/audio.m4a"
|
39 |
+
os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName))
|
40 |
+
os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
|
41 |
+
if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format
|
42 |
+
os.rename(targetNoAudio, targetVideo)
|
43 |
+
print("Audio transfer failed. Interpolated video will have no audio")
|
44 |
+
else:
|
45 |
+
print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
|
46 |
+
|
47 |
+
# remove audio-less video
|
48 |
+
os.remove(targetNoAudio)
|
49 |
+
else:
|
50 |
+
os.remove(targetNoAudio)
|
51 |
+
|
52 |
+
# remove temp directory
|
53 |
+
shutil.rmtree("temp")
|
54 |
+
|
55 |
+
def process_video(video, output, modelDir, fp16, UHD, scale, skip, fps, png, ext, exp, multi):
|
56 |
+
if exp != 1:
|
57 |
+
multi = (2 ** exp)
|
58 |
+
assert (not video is None)
|
59 |
+
if skip:
|
60 |
+
print("skip flag is abandoned, please refer to issue #207.")
|
61 |
+
if UHD and scale==1.0:
|
62 |
+
scale = 0.5
|
63 |
+
assert scale in [0.25, 0.5, 1.0, 2.0, 4.0]
|
64 |
+
|
65 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
66 |
+
torch.set_grad_enabled(False)
|
67 |
+
if torch.cuda.is_available():
|
68 |
+
torch.backends.cudnn.enabled = True
|
69 |
+
torch.backends.cudnn.benchmark = True
|
70 |
+
if(fp16):
|
71 |
+
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
72 |
+
|
73 |
+
from rife.train_log.RIFE_HDv3 import Model
|
74 |
+
model = Model()
|
75 |
+
if not hasattr(model, 'version'):
|
76 |
+
model.version = 0
|
77 |
+
model.load_model(modelDir, -1)
|
78 |
+
print("Loaded 3.x/4.x HD model.")
|
79 |
+
model.eval()
|
80 |
+
model.device()
|
81 |
+
|
82 |
+
videoCapture = cv2.VideoCapture(video)
|
83 |
+
fps_in = videoCapture.get(cv2.CAP_PROP_FPS)
|
84 |
+
tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
|
85 |
+
videoCapture.release()
|
86 |
+
if fps is None:
|
87 |
+
fpsNotAssigned = True
|
88 |
+
fps_out = fps_in * multi
|
89 |
+
else:
|
90 |
+
fpsNotAssigned = False
|
91 |
+
fps_out = fps
|
92 |
+
videogen = skvideo.io.vreader(video)
|
93 |
+
lastframe = next(videogen)
|
94 |
+
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
|
95 |
+
video_path_wo_ext, video_ext = os.path.splitext(video)
|
96 |
+
print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, ext, tot_frame, fps_in, fps_out))
|
97 |
+
if png == False and fpsNotAssigned == True:
|
98 |
+
print("The audio will be merged after interpolation process")
|
99 |
+
else:
|
100 |
+
print("Will not merge audio because using png or fps flag!")
|
101 |
+
|
102 |
+
h, w, _ = lastframe.shape
|
103 |
+
vid_out_name = None
|
104 |
+
vid_out = None
|
105 |
+
if png:
|
106 |
+
if not os.path.exists('vid_out'):
|
107 |
+
os.mkdir('vid_out')
|
108 |
+
else:
|
109 |
+
if output is not None:
|
110 |
+
vid_out_name = output
|
111 |
+
else:
|
112 |
+
vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, multi, int(np.round(fps_out)), ext)
|
113 |
+
vid_out = cv2.VideoWriter(vid_out_name, fourcc, fps_out, (w, h))
|
114 |
+
|
115 |
+
def clear_write_buffer(user_args, write_buffer):
|
116 |
+
cnt = 0
|
117 |
+
while True:
|
118 |
+
item = write_buffer.get()
|
119 |
+
if item is None:
|
120 |
+
break
|
121 |
+
if user_args.png:
|
122 |
+
cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
|
123 |
+
cnt += 1
|
124 |
+
else:
|
125 |
+
vid_out.write(item[:, :, ::-1])
|
126 |
+
|
127 |
+
def build_read_buffer(user_args, read_buffer, videogen):
|
128 |
+
try:
|
129 |
+
for frame in videogen:
|
130 |
+
read_buffer.put(frame)
|
131 |
+
except:
|
132 |
+
pass
|
133 |
+
read_buffer.put(None)
|
134 |
+
|
135 |
+
def make_inference(I0, I1, n):
|
136 |
+
if model.version >= 3.9:
|
137 |
+
res = []
|
138 |
+
for i in range(n):
|
139 |
+
res.append(model.inference(I0, I1, (i+1) * 1. / (n+1), scale))
|
140 |
+
return res
|
141 |
+
else:
|
142 |
+
middle = model.inference(I0, I1, scale)
|
143 |
+
if n == 1:
|
144 |
+
return [middle]
|
145 |
+
first_half = make_inference(I0, middle, n=n//2)
|
146 |
+
second_half = make_inference(middle, I1, n=n//2)
|
147 |
+
if n%2:
|
148 |
+
return [*first_half, middle, *second_half]
|
149 |
+
else:
|
150 |
+
return [*first_half, *second_half]
|
151 |
+
|
152 |
+
def pad_image(img):
|
153 |
+
if(fp16):
|
154 |
+
return F.pad(img, padding).half()
|
155 |
+
else:
|
156 |
+
return F.pad(img, padding)
|
157 |
+
|
158 |
+
tmp = max(128, int(128 / scale))
|
159 |
+
ph = ((h - 1) // tmp + 1) * tmp
|
160 |
+
pw = ((w - 1) // tmp + 1) * tmp
|
161 |
+
padding = (0, pw - w, 0, ph - h)
|
162 |
+
pbar = tqdm(total=tot_frame)
|
163 |
+
write_buffer = Queue(maxsize=500)
|
164 |
+
read_buffer = Queue(maxsize=500)
|
165 |
+
_thread.start_new_thread(build_read_buffer, ((), read_buffer, videogen))
|
166 |
+
_thread.start_new_thread(clear_write_buffer, ((), write_buffer))
|
167 |
+
|
168 |
+
I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
|
169 |
+
I1 = pad_image(I1)
|
170 |
+
temp = None # save lastframe when processing static frame
|
171 |
+
|
172 |
+
while True:
|
173 |
+
if temp is not None:
|
174 |
+
frame = temp
|
175 |
+
temp = None
|
176 |
+
else:
|
177 |
+
frame = read_buffer.get()
|
178 |
+
if frame is None:
|
179 |
+
break
|
180 |
+
I0 = I1
|
181 |
+
I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
|
182 |
+
I1 = pad_image(I1)
|
183 |
+
I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
|
184 |
+
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
|
185 |
+
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
186 |
+
|
187 |
+
break_flag = False
|
188 |
+
if ssim > 0.996:
|
189 |
+
frame = read_buffer.get() # read a new frame
|
190 |
+
if frame is None:
|
191 |
+
break_flag = True
|
192 |
+
frame = lastframe
|
193 |
+
else:
|
194 |
+
temp = frame
|
195 |
+
I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
|
196 |
+
I1 = pad_image(I1)
|
197 |
+
I1 = model.inference(I0, I1, scale=scale)
|
198 |
+
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
|
199 |
+
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
200 |
+
frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
|
201 |
+
|
202 |
+
if ssim < 0.2:
|
203 |
+
output = []
|
204 |
+
for i in range(multi - 1):
|
205 |
+
output.append(I0)
|
206 |
+
else:
|
207 |
+
output = make_inference(I0, I1, multi - 1)
|
208 |
+
|
209 |
+
write_buffer.put(lastframe)
|
210 |
+
for mid in output:
|
211 |
+
mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
|
212 |
+
write_buffer.put(mid[:h, :w])
|
213 |
+
pbar.update(1)
|
214 |
+
lastframe = frame
|
215 |
+
if break_flag:
|
216 |
+
break
|
217 |
+
|
218 |
+
write_buffer.put(lastframe)
|
219 |
+
write_buffer.put(None)
|
220 |
+
|
221 |
+
import time
|
222 |
+
while(not write_buffer.empty()):
|
223 |
+
time.sleep(0.1)
|
224 |
+
pbar.close()
|
225 |
+
if not vid_out is None:
|
226 |
+
vid_out.release()
|
227 |
+
|
228 |
+
if png == False and fpsNotAssigned == True and not video is None:
|
229 |
+
try:
|
230 |
+
transferAudio(video, vid_out_name)
|
231 |
+
except:
|
232 |
+
print("Audio transfer failed. Interpolated video will have no audio")
|
233 |
+
targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
|
234 |
+
os.rename(targetNoAudio, vid_out_name)
|
235 |
+
return vid_out_name
|