inoculatemedia committed
Commit e950119 · verified · 1 Parent(s): b7ae1f6

Upload 23 files

.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ demo/I0_0.png filter=lfs diff=lfs merge=lfs -text
+ demo/I0_1.png filter=lfs diff=lfs merge=lfs -text
+ demo/I0_slomo_clipped.gif filter=lfs diff=lfs merge=lfs -text
+ demo/i0.png filter=lfs diff=lfs merge=lfs -text
+ demo/i1.png filter=lfs diff=lfs merge=lfs -text
+ demo/I2_0.png filter=lfs diff=lfs merge=lfs -text
+ demo/I2_1.png filter=lfs diff=lfs merge=lfs -text
+ demo/I2_slomo_clipped.gif filter=lfs diff=lfs merge=lfs -text
Colab_demo.ipynb ADDED
@@ -0,0 +1,127 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "Untitled0.ipynb",
+ "provenance": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "<a href=\"https://colab.research.google.com/github/hzwer/Practical-RIFE/blob/main/Colab_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "FypCcZkNNt2p"
+ },
+ "source": [
+ "%cd /content\n",
+ "!git clone https://github.com/hzwer/Practical-RIFE"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1wysVHxoN54f"
+ },
+ "source": [
+ "!gdown --id 1O5KfS3KzZCY3imeCr2LCsntLhutKuAqj\n",
+ "!7z e Practical-RIFE/RIFE_trained_model_v3.8.zip"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "AhbHfRBJRAUt"
+ },
+ "source": [
+ "!mkdir /content/Practical-RIFE/train_log\n",
+ "!mv *.py /content/Practical-RIFE/train_log/\n",
+ "!mv *.pkl /content/Practical-RIFE/train_log/\n",
+ "%cd /content/Practical-RIFE/\n",
+ "!gdown --id 1i3xlKb7ax7Y70khcTcuePi6E7crO_dFc\n",
+ "!pip3 install -r requirements.txt"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rirngW5uRMdg"
+ },
+ "source": [
+ "Please upload your video to content/Practical-RIFE/video.mp4, or use our demo video."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "dnLn4aHHPzN3"
+ },
+ "source": [
+ "!nvidia-smi\n",
+ "!python3 inference_video.py --exp=1 --video=demo.mp4 --montage --skip"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "77KK6lxHgJhf"
+ },
+ "source": [
+ "Our demo.mp4 is 25FPS. You can adjust the parameters for your own preference.\n",
+ "For example: \n",
+ "--fps=60 --exp=1 --video=mydemo.avi --png"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "0zIBbVE3UfUD",
+ "cellView": "code"
+ },
+ "source": [
+ "from IPython.display import display, Image\n",
+ "import moviepy.editor as mpy\n",
+ "display(mpy.ipython_display('demo_4X_100fps.mp4', height=256, max_duration=100.))"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "tWkJCNgP3zXA"
+ },
+ "source": [
+ "!python3 inference_img.py --img demo/I0_0.png demo/I0_1.png\n",
+ "ffmpeg -r 10 -f image2 -i output/img%d.png -s 448x256 -vf \"split[s0][s1];[s0]palettegen=stats_mode=single[p];[s1][p]paletteuse=new=1\" output/slomo.gif\n",
+ "# Image interpolation"
+ ],
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 hzwer
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,129 @@
- ---
- title: Zerogpu Upscaler Interpolation
- emoji: 📚
- colorFrom: gray
- colorTo: green
- sdk: gradio
- sdk_version: 5.38.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Practical-RIFE
+ **[V4.0 Anime Demo Video](https://www.bilibili.com/video/BV1J3411t7qT?p=1&share_medium=iphone&share_plat=ios&share_session_id=7AE3DA72-D05C-43A0-9838-E2A80885BD4E&share_source=QQ&share_tag=s_i&timestamp=1639643780&unique_k=rjqO0EK)** | **[Iteration Experience (Chinese)](https://zhuanlan.zhihu.com/p/721430631)** | **[Iteration Q&A](https://github.com/hzwer/Practical-RIFE/issues/124)** | **[Colab](https://colab.research.google.com/drive/1BZmGSq15O4ZU5vPfzkv7jFNYahTm6qwT?usp=sharing)**
+
+ This project is based on [RIFE](https://github.com/hzwer/arXiv2020-RIFE) and [SAFA](https://github.com/megvii-research/WACV2024-SAFA). We aim to enhance their practicality for users by incorporating various features and designing new models. Since improving the PSNR index is not consistent with subjective perception, this project is intended for engineers and developers. For general users, we recommend the following software:
+
+ **[SVFI (Chinese)](https://github.com/YiWeiHuang-stack/Squirrel-Video-Frame-Interpolation) | [RIFE-App](https://grisk.itch.io/rife-app) | [FlowFrames](https://nmkd.itch.io/flowframes)**
+
+ Thanks to the [SVFI team](https://github.com/Justin62628/Squirrel-RIFE) for supporting model testing on animation.
+
+ [VapourSynth-RIFE](https://github.com/HolyWu/vs-rife) | [RIFE-ncnn-vulkan](https://github.com/nihui/rife-ncnn-vulkan) | [VapourSynth-RIFE-ncnn-Vulkan](https://github.com/styler00dollar/VapourSynth-RIFE-ncnn-Vulkan) | [vs-mlrt](https://github.com/AmusementClub/vs-mlrt) | [Drop frame fixer and FPS converter](https://github.com/may-son/RIFE-FixDropFrames-and-ConvertFPS)
+
+ ## Frame Interpolation
+ 2024.08 - We find that 4.24+ is quite suitable for post-processing of [some diffusion model generated videos](https://drive.google.com/drive/folders/1hSzUn10Era3JCaVz0Z5Eg4wT9R6eJ9U9?usp=sharing).
+
+ ### Trained Model
+ The content of these links is under the same MIT license as this project. **lite** means a model trained with a similar framework but at a lower computational cost.
+ Currently, it is recommended to choose 4.25 by default for most scenes.
+
+ 4.26 - 2024.09.21 | [Google Drive](https://drive.google.com/file/d/1gViYvvQrtETBgU1w8axZSsr7YUuw31uy/view?usp=sharing) [Baidu Netdisk](https://pan.baidu.com/s/1EZsG3IFO8C1e2uRVb_Npgg?pwd=smw8) || [4.25.lite - 2024.10.20](https://drive.google.com/file/d/1zlKblGuKNatulJNFf5jdB-emp9AqGK05/view?usp=share_link)
+
+ 4.25 - 2024.09.19 | [Google Drive](https://drive.google.com/file/d/1ZKjcbmt1hypiFprJPIKW0Tt0lr_2i7bg/view?usp=sharing) [Baidu Netdisk](https://pan.baidu.com/s/1rpUX5uawusz2uwEdXtjRbw?pwd=mo6k) | I am trying to use more flow blocks, so the scale_list will change accordingly. It seems that anime scenes have been significantly improved.
+
+ 4.22 - 2024.08.08 | [Google Drive](https://drive.google.com/file/d/1qh2DSA9a1eZUTtZG9U9RQKO7N7OaUJ0_/view?usp=share_link) [Baidu Netdisk](https://pan.baidu.com/s/1EA5BIHqOu35Rj4meg00G4g?pwd=hwym) || [4.22.lite](https://drive.google.com/file/d/1Smy6gY7BkS_RzCjPCbMEy-TsX8Ma5B0R/view?usp=sharing) || 4.21 - 2024.08.04 | [Google Drive](https://drive.google.com/file/d/1l5u6G8vEkPAT7cYYWwzB6OG8vwBYrxiS/view?usp=sharing) [Baidu Netdisk](https://pan.baidu.com/s/1TMjRFOwdLgsShKdGbTKW_g?pwd=4q6d)
+
+ 4.20 - 2024.07.24 | [Google Drive](https://drive.google.com/file/d/11n3YR7-qCRZm9RDdwtqOTsgCJUHPuexA/view?usp=sharing) [Baidu Netdisk](https://pan.baidu.com/s/1v0b7ZTSj_VvLOfW-hQ_NZQ?pwd=ykkv)
+ || 4.18 - 2024.07.03 | [Google Drive](https://drive.google.com/file/d/1octn-UVuEjXa_HlsIUbNeLTTvYCKbC_s/view?usp=sharing) [Baidu Netdisk](https://pan.baidu.com/s/1fqtxJyXSgUx-gE3rieuKxg?pwd=udr1)
+
+ 4.17 - 2024.05.24 | [Google Drive](https://drive.google.com/file/d/1962p_lEWo_kLTEynarNaRYRNVdaiQG2k/view?usp=share_link) [Baidu Netdisk](https://pan.baidu.com/s/1bMzTYoJKZXsoxuSBmzj6VQ?pwd=as37) : Add gram loss from [FILM](https://github.com/google-research/frame-interpolation/blob/69f8708f08e62c2edf46a27616a4bfcf083e2076/losses/vgg19_loss.py) || [4.17.lite](https://drive.google.com/file/d/1e9Qb4rm20UAsO7h9VILDwrpvTSHWWW8b/view?usp=share_link)
+
+ 4.15 - 2024.03.11 | [Google Drive](https://drive.google.com/file/d/1xlem7cfKoMaiLzjoeum8KIQTYO-9iqG5/view?usp=sharing) [Baidu Netdisk](https://pan.baidu.com/s/1IGNIX7JXGUwI_tfoafYHqA?pwd=bg0b) || [4.15.lite](https://drive.google.com/file/d/1BoOF-qSEnTPDjpKG1sBTa6k7Sv5_-k7z/view?usp=sharing) || 4.14 - 2024.01.08 | [Google Drive](https://drive.google.com/file/d/1BjuEY7CHZv1wzmwXSQP9ZTj0mLWu_4xy/view?usp=share_link) [Baidu Netdisk](https://pan.baidu.com/s/1d-W64lRsJTqNsgWoXYiaWQ?pwd=xawa) || [4.14.lite](https://drive.google.com/file/d/1eULia_onOtRXHMAW9VeDL8N2_7z8J1ba/view?usp=share_link)
+
+ v4.9.2 - 2023.11.01 | [Google Drive](https://drive.google.com/file/d/1UssCvbL8N-ty0xIKM5G5ZTEgp9o4w3hp/view?usp=sharing) [Baidu Netdisk](https://pan.baidu.com/s/18cbx3EP4HWgSa1vkcXvvyw?pwd=swr9) || v4.3 - 2022.8.17 | [Google Drive](https://drive.google.com/file/d/1xrNofTGMHdt9sQv7-EOG0EChl8hZW_cU/view?usp=sharing) [Baidu Netdisk](https://pan.baidu.com/s/12AUAeZLZf5E1_Zx6WkS3xw?pwd=q83a)
+
+ v3.8 - 2021.6.17 | [Google Drive](https://drive.google.com/file/d/1O5KfS3KzZCY3imeCr2LCsntLhutKuAqj/view?usp=sharing) [Baidu Netdisk](https://pan.baidu.com/s/1X-jpWBZWe-IQBoNAsxo2mA?pwd=kxr3) || v3.1 - 2021.5.17 | [Google Drive](https://drive.google.com/file/d/1xn4R3TQyFhtMXN2pa3lRB8cd4E1zckQe/view?usp=sharing) [Baidu Netdisk](https://pan.baidu.com/s/1W4p_Ni04HLI_jTy45sVodA?pwd=64bz)
+
+ [More older versions](https://github.com/megvii-research/ECCV2022-RIFE/issues/41)
+
+ ### Installation
+ python <= 3.11
+ ```
+ git clone git@github.com:hzwer/Practical-RIFE.git
+ cd Practical-RIFE
+ pip3 install -r requirements.txt
+ ```
+ Download a model from the model list and put *.py and flownet.pkl into train_log/
+ ### Run
+
+ You can use our [demo video](https://drive.google.com/file/d/1i3xlKb7ax7Y70khcTcuePi6E7crO_dFc/view?usp=sharing) or your own video.
+ ```
+ python3 inference_video.py --multi=2 --video=video.mp4
+ ```
+ (generates video_2X_xxfps.mp4)
+ ```
+ python3 inference_video.py --multi=4 --video=video.mp4
+ ```
+ (for 4X interpolation)
+ ```
+ python3 inference_video.py --multi=2 --video=video.mp4 --scale=0.5
+ ```
+ (If your video has a high resolution, such as 4K, we recommend setting --scale=0.5 (default 1.0).)
+ ```
+ python3 inference_video.py --multi=4 --img=input/
+ ```
+ (to read the video from PNGs, like input/0.png ... input/612.png; make sure the PNG names are numbers)
+
+ Parameter descriptions:
+
+ --img / --video: The input file address
+
+ --output: Output video name 'xxx.mp4'
+
+ --model: Directory with trained model files
+
+ --UHD: It is equivalent to setting scale=0.5
+
+ --montage: Splice the generated video with the original video, like [this demo](https://www.youtube.com/watch?v=kUQ7KK6MhHw)
+
+ --fps: Set output FPS manually
+
+ --ext: Set output video format, default: mp4
+
+ --multi: Interpolation frame-rate multiplier
+
+ --exp: Set --multi to 2^(--exp)
+
+ --skip: No longer useful; refer to [issue 207](https://github.com/hzwer/ECCV2022-RIFE/issues/207)
+
+
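To make the flag interactions above concrete, here is a small hypothetical helper (not part of the repository) that builds an inference_video.py command line; the only behavior it assumes is the documented rule that --exp overrides --multi with 2^exp:

```python
# Sketch: compose an inference_video.py invocation from the flags described above.
def build_command(video, exp=None, multi=2, scale=1.0, fps=None):
    if exp is not None:
        multi = 2 ** exp  # --exp sets --multi to 2^exp, as documented
    parts = ["python3", "inference_video.py", f"--multi={multi}", f"--video={video}"]
    if scale != 1.0:
        parts.append(f"--scale={scale}")  # e.g. 0.5 for 4K input
    if fps is not None:
        parts.append(f"--fps={fps}")      # otherwise output FPS = input FPS * multi
    return " ".join(parts)

print(build_command("video.mp4", exp=2, scale=0.5))
# python3 inference_video.py --multi=4 --video=video.mp4 --scale=0.5
```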
+ ### Model training
+ The whole repo can be downloaded from [v4.0](https://drive.google.com/file/d/1zoSz7b8c6kUsnd4gYZ_6TrKxa7ghHJWW/view?usp=sharing), [v4.12](https://drive.google.com/file/d/1IHB35zhO4rr-JSMnpRvHhU9U65Z4giWv/view?usp=sharing), [v4.15](https://drive.google.com/file/d/19sUMZ-6H7g_hYDjTcqxYu9kE7TqnfS3k/view?usp=sharing), [v4.18](https://drive.google.com/file/d/1g8D2foww7DhGLIxtaDLr9fU3y-ByOw4B/view?usp=share_link), [v4.25](https://drive.google.com/file/d/1_l4OgBp3GrrHOcQB87xXCI7OtTzyeXZL/view?usp=share_link). However, we currently do not have the time to organize them well; they are for reference only.
+
+ ## Video Enhancement
+
+ <img width="710" alt="image" src="https://github.com/hzwer/Practical-RIFE/assets/10103856/5bae134c-0747-4084-bbab-37b1595352f1">
+
+ We are developing a practical model of [SAFA](https://github.com/megvii-research/WACV2024-SAFA). You are welcome to check its [demo](https://www.youtube.com/watch?v=QII2KQSBBwk) ([BiliBili](https://www.bilibili.com/video/BV1Up4y1d7kF/)) and give feedback.
+
+ v0.5 - 2023.12.26 | [Google Drive](https://drive.google.com/file/d/1OLO9hLV97ZQ4uRV2-aQqgnwhbKMMt6TX/view?usp=sharing)
+
+ ```
+ python3 inference_video_enhance.py --video=demo.mp4
+ ```
+
+ ## Citation
+
+ ```
+ @inproceedings{huang2022rife,
+   title={Real-Time Intermediate Flow Estimation for Video Frame Interpolation},
+   author={Huang, Zhewei and Zhang, Tianyuan and Heng, Wen and Shi, Boxin and Zhou, Shuchang},
+   booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
+   year={2022}
+ }
+ ```
+ ```
+ @inproceedings{huang2024safa,
+   title={Scale-Adaptive Feature Aggregation for Efficient Space-Time Video Super-Resolution},
+   author={Huang, Zhewei and Huang, Ailin and Hu, Xiaotao and Hu, Chen and Xu, Jun and Zhou, Shuchang},
+   booktitle={Winter Conference on Applications of Computer Vision (WACV)},
+   year={2024}
+ }
+ ```
+
+ ## Reference
+
+ Optical Flow:
+ [ARFlow](https://github.com/lliuz/ARFlow) [pytorch-liteflownet](https://github.com/sniklaus/pytorch-liteflownet) [RAFT](https://github.com/princeton-vl/RAFT) [pytorch-PWCNet](https://github.com/sniklaus/pytorch-pwc)
+
+ Video Interpolation:
+ [DVF](https://github.com/lxx1991/pytorch-voxel-flow) [TOflow](https://github.com/Coldog2333/pytoflow) [SepConv](https://github.com/sniklaus/sepconv-slomo) [DAIN](https://github.com/baowenbo/DAIN) [CAIN](https://github.com/myungsub/CAIN) [MEMC-Net](https://github.com/baowenbo/MEMC-Net) [SoftSplat](https://github.com/sniklaus/softmax-splatting) [BMBC](https://github.com/JunHeum/BMBC) [EDSC](https://github.com/Xianhang/EDSC-pytorch) [EQVI](https://github.com/lyh-18/EQVI) [RIFE](https://github.com/hzwer/arXiv2020-RIFE)
__pycache__/video_processing.cpython-312.pyc ADDED
Binary file (12.1 kB).
 
app.py ADDED
@@ -0,0 +1,294 @@
+ import gradio as gr
+ import numpy as np
+ from PIL import Image, ImageFilter
+ import cv2
+ import os
+ import torch
+ import torch.nn.functional as F
+ from torchvision import transforms
+ import warnings
+ from video_processing import process_video
+
+ warnings.filterwarnings("ignore")
+
+ # ZeroGPU decorator (if available)
+ try:
+     import spaces
+     HAS_ZEROGPU = True
+ except ImportError:
+     HAS_ZEROGPU = False
+     # Create a dummy decorator if spaces is not available
+     def spaces_gpu(func):
+         return func
+     spaces = type('spaces', (), {'GPU': spaces_gpu})()
+
+ # VAAPI acceleration check
+ def check_vaapi_support():
+     """Check if VAAPI is available for hardware acceleration"""
+     try:
+         # Check if VAAPI devices are available
+         vaapi_devices = [f for f in os.listdir('/dev/dri') if f.startswith('render')]
+         return len(vaapi_devices) > 0
+     except:
+         return False
+
+ HAS_VAAPI = check_vaapi_support()
+
+ class TorchUpscaler:
+     """PyTorch-based upscaler that can use GPU acceleration"""
+
+     def __init__(self, device='auto'):
+         if device == 'auto':
+             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         else:
+             self.device = torch.device(device)
+
+         print(f"Using device: {self.device}")
+
+     def bicubic_torch(self, image_tensor, scale_factor):
+         """GPU-accelerated bicubic upscaling using PyTorch"""
+         return F.interpolate(
+             image_tensor,
+             scale_factor=scale_factor,
+             mode='bicubic',
+             align_corners=False,
+             antialias=True
+         )
+
+     def lanczos_torch(self, image_tensor, scale_factor):
+         """GPU-accelerated Lanczos-style upscaling"""
+         return F.interpolate(
+             image_tensor,
+             scale_factor=scale_factor,
+             mode='bicubic',
+             align_corners=False,
+             antialias=True
+         )
+
+     def esrgan_style_upscale(self, image_tensor, scale_factor):
+         """Simple ESRGAN-style upscaling using convolutions"""
+         b, c, h, w = image_tensor.shape
+         upscaled = F.interpolate(image_tensor, scale_factor=scale_factor, mode='bicubic', align_corners=False)
+         kernel = torch.tensor([[[[-1, -1, -1],
+                                  [-1, 9, -1],
+                                  [-1, -1, -1]]]], dtype=torch.float32, device=self.device)
+         kernel = kernel.repeat(c, 1, 1, 1)
+         sharpened = F.conv2d(upscaled, kernel, padding=1, groups=c)
+         result = 0.8 * upscaled + 0.2 * sharpened
+         return torch.clamp(result, 0, 1)
+
+ class VAAPIUpscaler:
+     """VAAPI hardware-accelerated upscaler"""
+
+     def __init__(self):
+         self.vaapi_available = HAS_VAAPI
+         if self.vaapi_available:
+             print("VAAPI hardware acceleration available")
+         else:
+             print("VAAPI hardware acceleration not available")
+
+     def upscale_vaapi(self, image_array, scale_factor, method):
+         """Use VAAPI for hardware-accelerated upscaling"""
+         if not self.vaapi_available:
+             return None
+         try:
+             h, w = image_array.shape[:2]
+             new_h, new_w = int(h * scale_factor), int(w * scale_factor)
+             if method == "VAAPI_BICUBIC":
+                 return cv2.resize(image_array, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
+             elif method == "VAAPI_LANCZOS":
+                 return cv2.resize(image_array, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
+         except Exception as e:
+             print(f"VAAPI upscaling failed: {e}")
+         return None
+
+ torch_upscaler = TorchUpscaler()
+ vaapi_upscaler = VAAPIUpscaler()
+
+ @spaces.GPU if HAS_ZEROGPU else lambda x: x
+ def upscale_image_accelerated(image, scale_factor, method, enhance_quality, use_gpu_acceleration):
+     if image is None:
+         return None
+
+     original_width, original_height = image.size
+     new_width = int(original_width * scale_factor)
+     new_height = int(original_height * scale_factor)
+
+     try:
+         if use_gpu_acceleration and torch.cuda.is_available():
+             print("Using GPU acceleration")
+             transform = transforms.Compose([transforms.ToTensor()])
+             image_tensor = transform(image).unsqueeze(0).to(torch_upscaler.device)
+
+             if method == "GPU_Bicubic":
+                 upscaled_tensor = torch_upscaler.bicubic_torch(image_tensor, scale_factor)
+             elif method == "GPU_Lanczos":
+                 upscaled_tensor = torch_upscaler.lanczos_torch(image_tensor, scale_factor)
+             elif method == "GPU_ESRGAN_Style":
+                 upscaled_tensor = torch_upscaler.esrgan_style_upscale(image_tensor, scale_factor)
+             else:
+                 upscaled_tensor = torch_upscaler.bicubic_torch(image_tensor, scale_factor)
+
+             upscaled = transforms.ToPILImage()(upscaled_tensor.squeeze(0).cpu())
+
+         elif method.startswith("VAAPI_") and HAS_VAAPI:
+             print("Using VAAPI acceleration")
+             img_array = np.array(image)
+             upscaled_array = vaapi_upscaler.upscale_vaapi(img_array, scale_factor, method)
+             upscaled = Image.fromarray(upscaled_array) if upscaled_array is not None else image.resize((new_width, new_height), Image.BICUBIC)
+
+         else:
+             print("Using CPU methods")
+             if method == "Bicubic":
+                 upscaled = image.resize((new_width, new_height), Image.BICUBIC)
+             elif method == "Lanczos":
+                 upscaled = image.resize((new_width, new_height), Image.LANCZOS)
+             else:
+                 upscaled = image.resize((new_width, new_height), Image.BICUBIC)
+
+         if enhance_quality:
+             upscaled = upscaled.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))
+
+         return upscaled
+
+     except Exception as e:
+         print(f"Error during upscaling: {e}")
+         return image
+
+ def get_available_methods():
+     methods = ["Bicubic", "Lanczos"]
+     if torch.cuda.is_available():
+         methods.extend(["GPU_Bicubic", "GPU_Lanczos", "GPU_ESRGAN_Style"])
+     if HAS_VAAPI:
+         methods.extend(["VAAPI_BICUBIC", "VAAPI_LANCZOS"])
+     return methods
+
+ def get_system_info():
+     info = []
+     if torch.cuda.is_available():
+         gpu_name = torch.cuda.get_device_name(0)
+         gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
+         info.append(f"🚀 CUDA GPU: {gpu_name} ({gpu_memory:.1f} GB)")
+     else:
+         info.append("❌ CUDA not available")
+     if HAS_ZEROGPU:
+         info.append("✅ ZeroGPU support enabled")
+     if HAS_VAAPI:
+         info.append("✅ VAAPI hardware acceleration available")
+     return "\n".join(info)
+
+ def process_and_info_accelerated(image, scale_factor, method, enhance_quality, use_gpu_acceleration):
+     if image is None:
+         return None, "Please upload an image first"
+
+     original_info = f"Original: {image.size[0]} × {image.size[1]} pixels"
+     result = upscale_image_accelerated(image, scale_factor, method, enhance_quality, use_gpu_acceleration)
+     if result is None:
+         return None, "Error processing image"
+
+     result_info = f"Upscaled: {result.size[0]} × {result.size[1]} pixels"
+     accel_info = "GPU/Hardware" if use_gpu_acceleration else "CPU"
+
+     combined_info = f"""
+ ## Processing Details
+ {original_info}
+ {result_info}
+ **Scale Factor:** {scale_factor}x
+ **Method:** {method}
+ **Acceleration:** {accel_info}
+ **Quality Enhancement:** {'✅' if enhance_quality else '❌'}
+
+ ## System Status
+ {get_system_info()}
+ """
+     return result, combined_info
+
+ def create_accelerated_upscaler_ui():
+     available_methods = get_available_methods()
+
+     gr.Markdown("## 🚀 Accelerated Image Upscaler")
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_image = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
+             scale_factor = gr.Slider(minimum=1.5, maximum=4.0, step=0.5, value=2.0, label="Scale Factor")
+             method = gr.Dropdown(choices=available_methods, value=available_methods[0], label="Upscaling Method")
+             use_gpu_acceleration = gr.Checkbox(label="Use GPU Acceleration", value=torch.cuda.is_available())
+             enhance_quality = gr.Checkbox(label="Apply Quality Enhancement", value=True)
+             process_btn = gr.Button("🚀 Upscale Image", variant="primary")
+
+         with gr.Column(scale=2):
+             output_image = gr.Image(label="Upscaled Image", type="pil")
+             image_info = gr.Markdown(value=f"## System Status\n{get_system_info()}", label="Processing Information")
+
+     process_btn.click(
+         fn=process_and_info_accelerated,
+         inputs=[input_image, scale_factor, method, enhance_quality, use_gpu_acceleration],
+         outputs=[output_image, image_info]
+     )
+
+ def create_video_interface_ui():
+     gr.Markdown("## 🚀 Video Upscaler and Frame Interpolator")
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_video = gr.Video(label="Upload Video", sources=["upload"])
+             scale_factor = gr.Slider(minimum=1.5, maximum=4.0, step=0.5, value=2.0, label="Scale Factor")
+             multi = gr.Slider(minimum=2, maximum=8, step=1, value=2, label="Frame Multiplier")
+             use_gpu_acceleration = gr.Checkbox(label="Use GPU Acceleration", value=torch.cuda.is_available())
+             process_btn = gr.Button("🚀 Process Video", variant="primary")
+
+         with gr.Column(scale=2):
+             output_video = gr.Video(label="Processed Video")
+             processing_info = gr.Markdown(value=f"## System Status\n{get_system_info()}", label="Processing Information")
+
+     process_btn.click(
+         fn=process_video_wrapper,
+         inputs=[input_video, scale_factor, multi, use_gpu_acceleration],
+         outputs=[output_video, processing_info]
+     )
+
+ def process_video_wrapper(video_path, scale_factor, multi, use_gpu):
+     if video_path is None:
+         return None, "Please upload a video first"
+
+     output_path = "temp_output.mp4"
+     modelDir = 'rife/train_log'
+
+     processed_video_path = process_video(
+         video=video_path,
+         output=output_path,
+         modelDir=modelDir,
+         fp16=use_gpu,
+         UHD=False,
+         scale=scale_factor,
+         skip=False,
+         fps=None,
+         png=False,
+         ext='mp4',
+         exp=1,
+         multi=multi
+     )
+
+     info = f"""
+ ## Processing Details
+ **Scale Factor:** {scale_factor}x
+ **Frame Multiplier:** {multi}x
+ **Acceleration:** {'GPU' if use_gpu else 'CPU'}
+
+ ## System Status
+ {get_system_info()}
+ """
+     return processed_video_path, info
+
+ with gr.Blocks(title="Accelerated Media Processor", theme=gr.themes.Soft()) as demo:
+     with gr.Tab("Image Upscaler"):
+         create_accelerated_upscaler_ui()
+     with gr.Tab("Video Processing"):
+         create_video_interface_ui()
+
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         debug=True
+     )
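For reference, the upscaling entry point in app.py can also be exercised without the Gradio UI. A minimal sketch, assuming the module imports cleanly in the Space environment (it pulls in video_processing at import time) and that a hypothetical input.png exists locally:

```python
# Sketch: call app.py's upscaler directly (no web UI), under the assumptions above.
from PIL import Image
import app  # the module added in this commit

img = Image.open("input.png").convert("RGB")  # hypothetical input file
methods = app.get_available_methods()
out = app.upscale_image_accelerated(
    img,
    scale_factor=2.0,
    method="GPU_Bicubic" if "GPU_Bicubic" in methods else "Bicubic",
    enhance_quality=True,
    use_gpu_acceleration=True,  # the function falls back to CPU paths when CUDA is absent
)
out.save("output_2x.png")
```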
demo/I0_0.png ADDED

Git LFS Details

  • SHA256: dffc62ff6436b3d1c02aec36a8b3c65603bd9d83e0001e942c4b99d1e509a6c3
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB
demo/I0_1.png ADDED

Git LFS Details

  • SHA256: fcf4eb07e7c63de8f1508fb93854ce464dd4bc9fd6e85be0e5922b9f700c7c6a
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
demo/I0_slomo_clipped.gif ADDED

Git LFS Details

  • SHA256: 5ad98421883a509d66916a8cf87fcd1a4268f23ab85cf09910e5709a466f9aa9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
demo/I2_0.png ADDED

Git LFS Details

  • SHA256: d8598c23c159508b348ecd7cbab5bf2a00530f6bd3316bf3ac5c09ddec4c14c1
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
demo/I2_1.png ADDED

Git LFS Details

  • SHA256: cc0c17b74bff42d21793ede192a7ec028df569a9bc466c5b17ab0cbaf76d53fc
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
demo/I2_slomo_clipped.gif ADDED

Git LFS Details

  • SHA256: a15a9a33ce9d87173ea7a4e9c4722b14f81a52633292a1e4e7babeac62fbc623
  • Pointer size: 131 Bytes
  • Size of remote file: 967 kB
demo/i0.png ADDED

Git LFS Details

  • SHA256: a01f79bf3c485c6d59284f36c9ec7933598313ae847b0394acbc0c4573491687
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
demo/i1.png ADDED

Git LFS Details

  • SHA256: 7ba4a0f8eae5a62ce567ef2611349e56907e29b968b9da5d2e57a8f4525d51ea
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
inference_img.py ADDED
@@ -0,0 +1,118 @@
+ import os
+ import cv2
+ import torch
+ import argparse
+ from torch.nn import functional as F
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch.set_grad_enabled(False)
+ if torch.cuda.is_available():
+     torch.backends.cudnn.enabled = True
+     torch.backends.cudnn.benchmark = True
+
+ parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
+ parser.add_argument('--img', dest='img', nargs=2, required=True)
+ parser.add_argument('--exp', default=4, type=int)
+ parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
+ parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold')
+ parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
+ parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
+
+ args = parser.parse_args()
+
+ try:
+     try:
+         from model.RIFE_HDv2 import Model
+         model = Model()
+         model.load_model(args.modelDir, -1)
+         print("Loaded v2.x HD model.")
+     except:
+         from train_log.RIFE_HDv3 import Model
+         model = Model()
+         model.load_model(args.modelDir, -1)
+         print("Loaded v3.x HD model.")
+ except:
+     from model.RIFE_HD import Model
+     model = Model()
+     model.load_model(args.modelDir, -1)
+     print("Loaded v1.x HD model")
+ if not hasattr(model, 'version'):
+     model.version = 0
+ model.eval()
+ model.device()
+
+ if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
+     img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
+     img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
+     img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
+     img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
+ else:
+     img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED)
+     img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED)
+     img0 = cv2.resize(img0, (448, 256))
+     img1 = cv2.resize(img1, (448, 256))
+     img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
+     img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
+
+ # pad both frames to a multiple of 64 before inference
+ n, c, h, w = img0.shape
+ ph = ((h - 1) // 64 + 1) * 64
+ pw = ((w - 1) // 64 + 1) * 64
+ padding = (0, pw - w, 0, ph - h)
+ img0 = F.pad(img0, padding)
+ img1 = F.pad(img1, padding)
+
+ if args.ratio:
+     if model.version >= 3.9:
+         img_list = [img0, model.inference(img0, img1, args.ratio), img1]
+     else:
+         img_list = [img0]  # start the output sequence with the first input frame
+         img0_ratio = 0.0
+         img1_ratio = 1.0
+         if args.ratio <= img0_ratio + args.rthreshold / 2:
+             middle = img0
+         elif args.ratio >= img1_ratio - args.rthreshold / 2:
+             middle = img1
+         else:
+             tmp_img0 = img0
+             tmp_img1 = img1
+             for inference_cycle in range(args.rmaxcycles):
+                 middle = model.inference(tmp_img0, tmp_img1)
+                 middle_ratio = (img0_ratio + img1_ratio) / 2
+                 if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2):
+                     break
+                 if args.ratio > middle_ratio:
+                     tmp_img0 = middle
+                     img0_ratio = middle_ratio
+                 else:
+                     tmp_img1 = middle
+                     img1_ratio = middle_ratio
+         img_list.append(middle)
+         img_list.append(img1)
+ else:
+     if model.version >= 3.9:
+         img_list = [img0]
+         n = 2 ** args.exp
+         for i in range(n-1):
+             img_list.append(model.inference(img0, img1, (i+1) * 1. / n))
+         img_list.append(img1)
+     else:
+         img_list = [img0, img1]
+         for i in range(args.exp):
+             tmp = []
+             for j in range(len(img_list) - 1):
+                 mid = model.inference(img_list[j], img_list[j + 1])
+                 tmp.append(img_list[j])
+                 tmp.append(mid)
+             tmp.append(img1)
+             img_list = tmp
+
+ if not os.path.exists('output'):
+     os.mkdir('output')
+ for i in range(len(img_list)):
+     if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
+         cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+     else:
+         cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
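The interpolated frames land in output/ as img0.png, img1.png, and so on; they are usually assembled into a GIF or video afterwards. A sketch of that step driven from Python, mirroring the ffmpeg command used in the Colab notebook above and assuming ffmpeg is on PATH:

```python
# Sketch: turn output/img0.png, img1.png, ... into a short GIF with ffmpeg (assumed installed).
import subprocess

subprocess.run([
    "ffmpeg", "-y", "-r", "10", "-f", "image2", "-i", "output/img%d.png",
    "-s", "448x256",
    "-vf", "split[s0][s1];[s0]palettegen=stats_mode=single[p];[s1][p]paletteuse=new=1",
    "output/slomo.gif",
], check=True)
```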
inference_img_SR.py ADDED
@@ -0,0 +1,69 @@
+ import os
+ import cv2
+ import torch
+ import argparse
+ from torch.nn import functional as F
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch.set_grad_enabled(False)
+ if torch.cuda.is_available():
+     torch.backends.cudnn.enabled = True
+     torch.backends.cudnn.benchmark = True
+
+ parser = argparse.ArgumentParser(description='STVSR for a pair of images')
+ parser.add_argument('--img', dest='img', nargs=2, required=True)
+ parser.add_argument('--exp', default=2, type=int)
+ parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
+ parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
+
+ args = parser.parse_args()
+
+ from train_log.model import Model
+ model = Model()
+ model.device()
+ model.load_model('train_log')
+ model.eval()
+
+ if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
+     img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
+     img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
+     img0 = cv2.resize(img0, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
+     img1 = cv2.resize(img1, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
+     img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
+     img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
+ else:
+     img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED)
+     img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED)
+     img0 = cv2.resize(img0, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
+     img1 = cv2.resize(img1, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
+     img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
+     img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
+
+ n, c, h, w = img0.shape
+ ph = ((h - 1) // 32 + 1) * 32
+ pw = ((w - 1) // 32 + 1) * 32
+ padding = (0, pw - w, 0, ph - h)
+ img0 = F.pad(img0, padding)
+ img1 = F.pad(img1, padding)
+
+ if args.ratio:
+     print('ratio={}'.format(args.ratio))
+     img_list = model.inference(img0, img1, timestep=args.ratio)
+ else:
+     n = 2 ** args.exp - 1
+     time_list = [0]
+     for i in range(n):
+         time_list.append((i+1) * 1. / (n+1))
+     time_list.append(1)
+     print(time_list)
+     img_list = model.inference(img0, img1, timestep=time_list)
+
+ if not os.path.exists('output'):
+     os.mkdir('output')
+ for i in range(len(img_list)):
+     if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
+         cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+     else:
+         cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
inference_video.py ADDED
@@ -0,0 +1,290 @@
+ import os
+ import cv2
+ import torch
+ import argparse
+ import numpy as np
+ from tqdm import tqdm
+ from torch.nn import functional as F
+ import warnings
+ import _thread
+ import skvideo.io
+ from queue import Queue, Empty
+ from model.pytorch_msssim import ssim_matlab
+
+ warnings.filterwarnings("ignore")
+
+ def transferAudio(sourceVideo, targetVideo):
+     import shutil
+     import moviepy.editor
+     tempAudioFileName = "./temp/audio.mkv"
+
+     # split audio from original video file and store in "temp" directory
+     if True:
+
+         # clear old "temp" directory if it exits
+         if os.path.isdir("temp"):
+             # remove temp directory
+             shutil.rmtree("temp")
+         # create new "temp" directory
+         os.makedirs("temp")
+         # extract audio from video
+         os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName))
+
+     targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
+     os.rename(targetVideo, targetNoAudio)
+     # combine audio file and new video file
+     os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
+
+     if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac
+         tempAudioFileName = "./temp/audio.m4a"
+         os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName))
+         os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
+         if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format
+             os.rename(targetNoAudio, targetVideo)
+             print("Audio transfer failed. Interpolated video will have no audio")
+         else:
+             print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
+
+             # remove audio-less video
+             os.remove(targetNoAudio)
+     else:
+         os.remove(targetNoAudio)
+
+     # remove temp directory
+     shutil.rmtree("temp")
+
+ parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
+ parser.add_argument('--video', dest='video', type=str, default=None)
+ parser.add_argument('--output', dest='output', type=str, default=None)
+ parser.add_argument('--img', dest='img', type=str, default=None)
+ parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
+ parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
+ parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
+ parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
+ parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video')
+ parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
+ parser.add_argument('--fps', dest='fps', type=int, default=None)
+ parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs')
+ parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension')
+ parser.add_argument('--exp', dest='exp', type=int, default=1)
+ parser.add_argument('--multi', dest='multi', type=int, default=2)
+
+ args = parser.parse_args()
+ if args.exp != 1:
+     args.multi = (2 ** args.exp)
+ assert (not args.video is None or not args.img is None)
+ if args.skip:
+     print("skip flag is abandoned, please refer to issue #207.")
+ if args.UHD and args.scale==1.0:
+     args.scale = 0.5
+ assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
+ if not args.img is None:
+     args.png = True
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch.set_grad_enabled(False)
+ if torch.cuda.is_available():
+     torch.backends.cudnn.enabled = True
+     torch.backends.cudnn.benchmark = True
+     if(args.fp16):
+         torch.set_default_tensor_type(torch.cuda.HalfTensor)
+
+ from train_log.RIFE_HDv3 import Model
+ model = Model()
+ if not hasattr(model, 'version'):
+     model.version = 0
+ model.load_model(args.modelDir, -1)
+ print("Loaded 3.x/4.x HD model.")
+ model.eval()
+ model.device()
+
+ if not args.video is None:
+     videoCapture = cv2.VideoCapture(args.video)
+     fps = videoCapture.get(cv2.CAP_PROP_FPS)
+     tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
+     videoCapture.release()
+     if args.fps is None:
+         fpsNotAssigned = True
+         args.fps = fps * args.multi
+     else:
+         fpsNotAssigned = False
+     videogen = skvideo.io.vreader(args.video)
+     lastframe = next(videogen)
+     fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
+     video_path_wo_ext, ext = os.path.splitext(args.video)
+     print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))
+     if args.png == False and fpsNotAssigned == True:
+         print("The audio will be merged after interpolation process")
+     else:
+         print("Will not merge audio because using png or fps flag!")
+ else:
+     videogen = []
+     for f in os.listdir(args.img):
+         if 'png' in f:
+             videogen.append(f)
+     tot_frame = len(videogen)
+     videogen.sort(key= lambda x:int(x[:-4]))
+     lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
+     videogen = videogen[1:]
+ h, w, _ = lastframe.shape
+ vid_out_name = None
+ vid_out = None
+ if args.png:
+     if not os.path.exists('vid_out'):
+         os.mkdir('vid_out')
+ else:
+     if args.output is not None:
+         vid_out_name = args.output
+     else:
+         vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.multi, int(np.round(args.fps)), args.ext)
+     vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))
+
+ def clear_write_buffer(user_args, write_buffer):
+     cnt = 0
+     while True:
+         item = write_buffer.get()
+         if item is None:
+             break
+         if user_args.png:
+             cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
+             cnt += 1
+         else:
+             vid_out.write(item[:, :, ::-1])
+
+ def build_read_buffer(user_args, read_buffer, videogen):
+     try:
+         for frame in videogen:
+             if not user_args.img is None:
+                 frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
+             if user_args.montage:
+                 frame = frame[:, left: left + w]
+             read_buffer.put(frame)
+     except:
+         pass
+     read_buffer.put(None)
+
+ def make_inference(I0, I1, n):
+     global model
+     if model.version >= 3.9:
+         res = []
+         for i in range(n):
+             res.append(model.inference(I0, I1, (i+1) * 1. / (n+1), args.scale))
+         return res
+     else:
+         middle = model.inference(I0, I1, args.scale)
+         if n == 1:
+             return [middle]
+         first_half = make_inference(I0, middle, n=n//2)
+         second_half = make_inference(middle, I1, n=n//2)
+         if n%2:
+             return [*first_half, middle, *second_half]
+         else:
+             return [*first_half, *second_half]
+
+ def pad_image(img):
+     if(args.fp16):
+         return F.pad(img, padding).half()
+     else:
+         return F.pad(img, padding)
+
+ if args.montage:
+     left = w // 4
+     w = w // 2
+ tmp = max(128, int(128 / args.scale))
+ ph = ((h - 1) // tmp + 1) * tmp
+ pw = ((w - 1) // tmp + 1) * tmp
+ padding = (0, pw - w, 0, ph - h)
+ pbar = tqdm(total=tot_frame)
+ if args.montage:
+     lastframe = lastframe[:, left: left + w]
+ write_buffer = Queue(maxsize=500)
+ read_buffer = Queue(maxsize=500)
+ _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
+ _thread.start_new_thread(clear_write_buffer, (args, write_buffer))
+
+ I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+ I1 = pad_image(I1)
+ temp = None # save lastframe when processing static frame
+
+ while True:
+     if temp is not None:
+         frame = temp
+         temp = None
+     else:
+         frame = read_buffer.get()
+     if frame is None:
+         break
+     I0 = I1
+     I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+     I1 = pad_image(I1)
+     I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
+     I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+     ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+
+     break_flag = False
+     if ssim > 0.996:
+         frame = read_buffer.get() # read a new frame
+         if frame is None:
+             break_flag = True
+             frame = lastframe
+         else:
+             temp = frame
+         I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+         I1 = pad_image(I1)
+         I1 = model.inference(I0, I1, scale=args.scale)
+         I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+         ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+         frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
+
+     if ssim < 0.2:
+         output = []
+         for i in range(args.multi - 1):
+             output.append(I0)
+         '''
+         output = []
+         step = 1 / args.multi
+         alpha = 0
+         for i in range(args.multi - 1):
+             alpha += step
+             beta = 1-alpha
+             output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
+         '''
+     else:
+         output = make_inference(I0, I1, args.multi - 1)
+
+     if args.montage:
+         write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+         for mid in output:
+             mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
+             write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
+     else:
+         write_buffer.put(lastframe)
+         for mid in output:
+             mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
+             write_buffer.put(mid[:h, :w])
+     pbar.update(1)
+     lastframe = frame
+     if break_flag:
+         break
+
+ if args.montage:
+     write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+ else:
+     write_buffer.put(lastframe)
+ write_buffer.put(None)
+
+ import time
+ while(not write_buffer.empty()):
+     time.sleep(0.1)
+ pbar.close()
+ if not vid_out is None:
+     vid_out.release()
+
+ # move audio to new video file if appropriate
+ if args.png == False and fpsNotAssigned == True and not args.video is None:
+     try:
+         transferAudio(args.video, vid_out_name)
+     except:
+         print("Audio transfer failed. Interpolated video will have no audio")
+         targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
+         os.rename(targetNoAudio, vid_out_name)
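inference_video.py decides what to do with each consecutive frame pair from a low-resolution SSIM check: near-duplicate frames (SSIM > 0.996) trigger a look-ahead to the next frame, and likely scene cuts (SSIM < 0.2) are filled by repeating the first frame rather than interpolating across the cut. A simplified sketch of the per-pair write policy, with the thresholds and timesteps taken from the script (the helper name is hypothetical):

```python
# Sketch: the frame-pair policy used in the main loop above (thresholds from the script).
def plan_pair(ssim, multi):
    """Describe what gets written between two consecutive frames I0 and I1."""
    if ssim < 0.2:
        # likely scene cut: repeat I0 instead of hallucinating frames across the cut
        return ["I0"] * (multi - 1)
    # normal case: request multi-1 intermediate timesteps from the model
    return [f"interp(t={(i + 1) / multi:.2f})" for i in range(multi - 1)]

print(plan_pair(0.95, 4))  # ['interp(t=0.25)', 'interp(t=0.50)', 'interp(t=0.75)']
print(plan_pair(0.10, 4))  # ['I0', 'I0', 'I0']
```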
inference_video_enhance.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import argparse
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from torch.nn import functional as F
8
+ import warnings
9
+ import _thread
10
+ import skvideo.io
11
+ from queue import Queue, Empty
12
+ from model.pytorch_msssim import ssim_matlab
13
+
14
+ warnings.filterwarnings("ignore")
15
+
16
+ def transferAudio(sourceVideo, targetVideo):
17
+ import shutil
18
+ import moviepy.editor
19
+ tempAudioFileName = "./temp/audio.mkv"
20
+
21
+ # split audio from original video file and store in "temp" directory
22
+ if True:
23
+
24
+ # clear old "temp" directory if it exits
25
+ if os.path.isdir("temp"):
26
+ # remove temp directory
27
+ shutil.rmtree("temp")
28
+ # create new "temp" directory
29
+ os.makedirs("temp")
30
+ # extract audio from video
31
+ os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName))
32
+
33
+ targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
34
+ os.rename(targetVideo, targetNoAudio)
35
+ # combine audio file and new video file
36
+ os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
37
+
38
+ if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac
39
+ tempAudioFileName = "./temp/audio.m4a"
40
+ os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName))
41
+ os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
42
+ if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format
43
+ os.rename(targetNoAudio, targetVideo)
44
+ print("Audio transfer failed. Interpolated video will have no audio")
45
+ else:
46
+ print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
47
+
48
+ # remove audio-less video
49
+ os.remove(targetNoAudio)
50
+ else:
51
+ os.remove(targetNoAudio)
52
+
53
+ # remove temp directory
54
+ shutil.rmtree("temp")
55
+
56
+ parser = argparse.ArgumentParser(description='Video SR')
57
+ parser.add_argument('--video', dest='video', type=str, default=None)
58
+ parser.add_argument('--output', dest='output', type=str, default=None)
59
+ parser.add_argument('--img', dest='img', type=str, default=None)
60
+ parser.add_argument('--model', dest='modelDir', type=str, default='train_log_SAFA', help='directory with trained model files')
61
+ parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
62
+ parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs')
63
+ parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension')
64
+
65
+ args = parser.parse_args()
66
+ assert (not args.video is None or not args.img is None)
67
+ if not args.img is None:
68
+ args.png = True
69
+
70
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+ torch.set_grad_enabled(False)
72
+ if torch.cuda.is_available():
73
+ torch.backends.cudnn.enabled = True
74
+ torch.backends.cudnn.benchmark = True
75
+ if(args.fp16):
76
+ print('set fp16')
77
+ torch.set_default_tensor_type(torch.cuda.HalfTensor)
78
+
79
+ try:
80
+ from train_log_SAFA.model import Model
81
+ except:
82
+ print("Please download our model from model list")
83
+ model = Model()
84
+ model.device()
85
+ model.load_model(args.modelDir)
86
+ print("Loaded SAFA model.")
87
+ model.eval()
88
+
89
+ if not args.video is None:
90
+ videoCapture = cv2.VideoCapture(args.video)
91
+ fps = videoCapture.get(cv2.CAP_PROP_FPS)
92
+ tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
93
+ videoCapture.release()
94
+ fpsNotAssigned = True
95
+ videogen = skvideo.io.vreader(args.video)
96
+ lastframe = next(videogen)
97
+ fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
98
+ video_path_wo_ext, ext = os.path.splitext(args.video)
99
+ if args.png == False and fpsNotAssigned == True:
100
+ print("The audio will be merged after interpolation process")
101
+ else:
102
+ print("Will not merge audio because using png or fps flag!")
103
+ else:
104
+ videogen = []
105
+ for f in os.listdir(args.img):
106
+ if 'png' in f:
107
+ videogen.append(f)
108
+ tot_frame = len(videogen)
109
+ videogen.sort(key= lambda x:int(x[:-4]))
110
+ lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
111
+ videogen = videogen[1:]
112
+
113
+ h, w, _ = lastframe.shape
114
+
115
+ vid_out_name = None
116
+ vid_out = None
117
+ if args.png:
118
+ if not os.path.exists('vid_out'):
119
+ os.mkdir('vid_out')
120
+ else:
121
+ if args.output is not None:
122
+ vid_out_name = args.output
123
+ else:
124
+ vid_out_name = '{}_2X{}'.format(video_path_wo_ext, ext)
125
+ vid_out = cv2.VideoWriter(vid_out_name, fourcc, fps, (w, h))
126
+
127
+ def clear_write_buffer(user_args, write_buffer):
128
+ cnt = 0
129
+ while True:
130
+ item = write_buffer.get()
131
+ if item is None:
132
+ break
133
+ if user_args.png:
134
+ cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
135
+ cnt += 1
136
+ else:
137
+ vid_out.write(item[:, :, ::-1])
138
+
139
+ def build_read_buffer(user_args, read_buffer, videogen):
140
+ for frame in videogen:
141
+ if not user_args.img is None:
142
+ frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
143
+ # if user_args.montage:
144
+ # frame = frame[:, left: left + w]
145
+ read_buffer.put(frame)
146
+ read_buffer.put(None)
147
+
148
+ def pad_image(img):
149
+ if(args.fp16):
150
+ return F.pad(img, padding, mode='reflect').half()
151
+ else:
152
+ return F.pad(img, padding, mode='reflect')
153
+
154
+ tmp = 64
155
+ ph = ((h - 1) // tmp + 1) * tmp
156
+ pw = ((w - 1) // tmp + 1) * tmp
157
+ padding = (0, pw - w, 0, ph - h)
158
+ pbar = tqdm(total=tot_frame)
159
+ write_buffer = Queue(maxsize=500)
160
+ read_buffer = Queue(maxsize=500)
161
+ _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
162
+ _thread.start_new_thread(clear_write_buffer, (args, write_buffer))
163
+
164
+ while True:
165
+ frame = read_buffer.get()
166
+ if frame is None:
167
+ break
168
+ # lastframe_2x = cv2.resize(lastframe, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
169
+ # frame_2x = cv2.resize(frame, (0, 0), fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
170
+ I0 = pad_image(torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
171
+ I1 = pad_image(torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
172
+ I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
173
+ I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
174
+ ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
175
+ if ssim < 0.2:
176
+ out = [model.inference(I0, I0, [0])[0], model.inference(I1, I1, [0])[0]]
177
+ else:
178
+ out = model.inference(I0, I1, [0, 1])
179
+ assert(len(out) == 2)
180
+ write_buffer.put((out[0][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
181
+ write_buffer.put((out[1][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
182
+ lastframe = read_buffer.get()
183
+ if lastframe is None:
184
+ break
185
+ pbar.update(2)
186
+
187
+ import time
188
+ while(not write_buffer.empty()):
189
+ time.sleep(0.1)
190
+ pbar.close()
191
+ if vid_out is not None:
192
+ vid_out.release()
193
+
194
+ # move audio to new video file if appropriate
195
+ if not args.png and fpsNotAssigned and args.video is not None:
196
+ try:
197
+ transferAudio(args.video, vid_out_name)
198
+ except:
199
+ print("Audio transfer failed. Interpolated video will have no audio")
200
+ targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
201
+ os.rename(targetNoAudio, vid_out_name)
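A minimal sketch of the pad-and-crop arithmetic used by the loop above, with a hypothetical 1280x720 frame; the multiple-of-64 rounding and the [:h, :w] crop mirror the code, the rest is illustrative:

import torch
import torch.nn.functional as F

h, w, tmp = 720, 1280, 64
ph = ((h - 1) // tmp + 1) * tmp       # 768, the next multiple of 64 >= h
pw = ((w - 1) // tmp + 1) * tmp       # 1280, already a multiple of 64
padding = (0, pw - w, 0, ph - h)      # pad only on the right and bottom

frame = torch.rand(1, 3, h, w)                  # stand-in for a normalized video frame
padded = F.pad(frame, padding, mode='reflect')  # shape (1, 3, 768, 1280)
restored = padded[:, :, :h, :w]                 # crop back to the original size after inference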
model/loss.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchvision.models as models
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+
10
+ class EPE(nn.Module):
11
+ def __init__(self):
12
+ super(EPE, self).__init__()
13
+
14
+ def forward(self, flow, gt, loss_mask):
15
+ loss_map = (flow - gt.detach()) ** 2
16
+ loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
17
+ return (loss_map * loss_mask)
18
+
19
+
20
+ class Ternary(nn.Module):
21
+ def __init__(self):
22
+ super(Ternary, self).__init__()
23
+ patch_size = 7
24
+ out_channels = patch_size * patch_size
25
+ self.w = np.eye(out_channels).reshape(
26
+ (patch_size, patch_size, 1, out_channels))
27
+ self.w = np.transpose(self.w, (3, 2, 0, 1))
28
+ self.w = torch.tensor(self.w).float().to(device)
29
+
30
+ def transform(self, img):
31
+ patches = F.conv2d(img, self.w, padding=3, bias=None)
32
+ transf = patches - img
33
+ transf_norm = transf / torch.sqrt(0.81 + transf**2)
34
+ return transf_norm
35
+
36
+ def rgb2gray(self, rgb):
37
+ r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
38
+ gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
39
+ return gray
40
+
41
+ def hamming(self, t1, t2):
42
+ dist = (t1 - t2) ** 2
43
+ dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
44
+ return dist_norm
45
+
46
+ def valid_mask(self, t, padding):
47
+ n, _, h, w = t.size()
48
+ inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
49
+ mask = F.pad(inner, [padding] * 4)
50
+ return mask
51
+
52
+ def forward(self, img0, img1):
53
+ img0 = self.transform(self.rgb2gray(img0))
54
+ img1 = self.transform(self.rgb2gray(img1))
55
+ return self.hamming(img0, img1) * self.valid_mask(img0, 1)
56
+
57
+
58
+ class SOBEL(nn.Module):
59
+ def __init__(self):
60
+ super(SOBEL, self).__init__()
61
+ self.kernelX = torch.tensor([
62
+ [1, 0, -1],
63
+ [2, 0, -2],
64
+ [1, 0, -1],
65
+ ]).float()
66
+ self.kernelY = self.kernelX.clone().T
67
+ self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
68
+ self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)
69
+
70
+ def forward(self, pred, gt):
71
+ N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
72
+ img_stack = torch.cat(
73
+ [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
74
+ sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
75
+ sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
76
+ pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:]
77
+ pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:]
78
+
79
+ L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y)
80
+ loss = (L1X+L1Y)
81
+ return loss
82
+
83
+ class MeanShift(nn.Conv2d):
84
+ def __init__(self, data_mean, data_std, data_range=1, norm=True):
85
+ c = len(data_mean)
86
+ super(MeanShift, self).__init__(c, c, kernel_size=1)
87
+ std = torch.Tensor(data_std)
88
+ self.weight.data = torch.eye(c).view(c, c, 1, 1)
89
+ if norm:
90
+ self.weight.data.div_(std.view(c, 1, 1, 1))
91
+ self.bias.data = -1 * data_range * torch.Tensor(data_mean)
92
+ self.bias.data.div_(std)
93
+ else:
94
+ self.weight.data.mul_(std.view(c, 1, 1, 1))
95
+ self.bias.data = data_range * torch.Tensor(data_mean)
96
+ self.requires_grad = False
97
+
98
+ class VGGPerceptualLoss(torch.nn.Module):
99
+ def __init__(self, rank=0):
100
+ super(VGGPerceptualLoss, self).__init__()
101
+ blocks = []
102
+ pretrained = True
103
+ self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
104
+ self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
105
+ for param in self.parameters():
106
+ param.requires_grad = False
107
+
108
+ def forward(self, X, Y, indices=None):
109
+ X = self.normalize(X)
110
+ Y = self.normalize(Y)
111
+ indices = [2, 7, 12, 21, 30]
112
+ weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
113
+ k = 0
114
+ loss = 0
115
+ for i in range(indices[-1]):
116
+ X = self.vgg_pretrained_features[i](X)
117
+ Y = self.vgg_pretrained_features[i](Y)
118
+ if (i+1) in indices:
119
+ loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
120
+ k += 1
121
+ return loss
122
+
123
+ if __name__ == '__main__':
124
+ img0 = torch.zeros(3, 3, 256, 256).float().to(device)
125
+ img1 = torch.tensor(np.random.normal(
126
+ 0, 1, (3, 3, 256, 256))).float().to(device)
127
+ ternary_loss = Ternary()
128
+ print(ternary_loss(img0, img1).shape)
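A minimal usage sketch of the losses defined above, assuming the repository layout in this commit (model/loss.py importable as model.loss) and random tensors in place of real frames:

import torch
from model.loss import Ternary, SOBEL, device

pred = torch.rand(2, 3, 128, 128).to(device)
gt = torch.rand(2, 3, 128, 128).to(device)

census = Ternary()   # census-transform (ternary) descriptor distance
edge = SOBEL()       # L1 distance between Sobel edge maps

print(census(pred, gt).shape)  # per-pixel map of shape (2, 1, 128, 128)
print(edge(pred, gt).mean())   # reduce the edge map to a scalar loss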
model/pytorch_msssim/__init__.py ADDED
@@ -0,0 +1,200 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from math import exp
4
+ import numpy as np
5
+
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ def gaussian(window_size, sigma):
9
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10
+ return gauss/gauss.sum()
11
+
12
+
13
+ def create_window(window_size, channel=1):
14
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
15
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
16
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
17
+ return window
18
+
19
+ def create_window_3d(window_size, channel=1):
20
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
21
+ _2D_window = _1D_window.mm(_1D_window.t())
22
+ _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
23
+ window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
24
+ return window
25
+
26
+
27
+ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
28
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
29
+ if val_range is None:
30
+ if torch.max(img1) > 128:
31
+ max_val = 255
32
+ else:
33
+ max_val = 1
34
+
35
+ if torch.min(img1) < -0.5:
36
+ min_val = -1
37
+ else:
38
+ min_val = 0
39
+ L = max_val - min_val
40
+ else:
41
+ L = val_range
42
+
43
+ padd = 0
44
+ (_, channel, height, width) = img1.size()
45
+ if window is None:
46
+ real_size = min(window_size, height, width)
47
+ window = create_window(real_size, channel=channel).to(img1.device)
48
+
49
+ # mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
50
+ # mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
51
+ mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
52
+ mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
53
+
54
+ mu1_sq = mu1.pow(2)
55
+ mu2_sq = mu2.pow(2)
56
+ mu1_mu2 = mu1 * mu2
57
+
58
+ sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
59
+ sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
60
+ sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
61
+
62
+ C1 = (0.01 * L) ** 2
63
+ C2 = (0.03 * L) ** 2
64
+
65
+ v1 = 2.0 * sigma12 + C2
66
+ v2 = sigma1_sq + sigma2_sq + C2
67
+ cs = torch.mean(v1 / v2) # contrast sensitivity
68
+
69
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
70
+
71
+ if size_average:
72
+ ret = ssim_map.mean()
73
+ else:
74
+ ret = ssim_map.mean(1).mean(1).mean(1)
75
+
76
+ if full:
77
+ return ret, cs
78
+ return ret
79
+
80
+
81
+ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
82
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
83
+ if val_range is None:
84
+ if torch.max(img1) > 128:
85
+ max_val = 255
86
+ else:
87
+ max_val = 1
88
+
89
+ if torch.min(img1) < -0.5:
90
+ min_val = -1
91
+ else:
92
+ min_val = 0
93
+ L = max_val - min_val
94
+ else:
95
+ L = val_range
96
+
97
+ padd = 0
98
+ (_, _, height, width) = img1.size()
99
+ if window is None:
100
+ real_size = min(window_size, height, width)
101
+ window = create_window_3d(real_size, channel=1).to(img1.device)
102
+ # Channel is set to 1 since we consider color images as volumetric images
103
+
104
+ img1 = img1.unsqueeze(1)
105
+ img2 = img2.unsqueeze(1)
106
+
107
+ mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
108
+ mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
109
+
110
+ mu1_sq = mu1.pow(2)
111
+ mu2_sq = mu2.pow(2)
112
+ mu1_mu2 = mu1 * mu2
113
+
114
+ sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
115
+ sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
116
+ sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
117
+
118
+ C1 = (0.01 * L) ** 2
119
+ C2 = (0.03 * L) ** 2
120
+
121
+ v1 = 2.0 * sigma12 + C2
122
+ v2 = sigma1_sq + sigma2_sq + C2
123
+ cs = torch.mean(v1 / v2) # contrast sensitivity
124
+
125
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
126
+
127
+ if size_average:
128
+ ret = ssim_map.mean()
129
+ else:
130
+ ret = ssim_map.mean(1).mean(1).mean(1)
131
+
132
+ if full:
133
+ return ret, cs
134
+ return ret
135
+
136
+
137
+ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
138
+ device = img1.device
139
+ weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
140
+ levels = weights.size()[0]
141
+ mssim = []
142
+ mcs = []
143
+ for _ in range(levels):
144
+ sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
145
+ mssim.append(sim)
146
+ mcs.append(cs)
147
+
148
+ img1 = F.avg_pool2d(img1, (2, 2))
149
+ img2 = F.avg_pool2d(img2, (2, 2))
150
+
151
+ mssim = torch.stack(mssim)
152
+ mcs = torch.stack(mcs)
153
+
154
+ # Normalize (to avoid NaNs when training unstable models; not compliant with the original definition)
155
+ if normalize:
156
+ mssim = (mssim + 1) / 2
157
+ mcs = (mcs + 1) / 2
158
+
159
+ pow1 = mcs ** weights
160
+ pow2 = mssim ** weights
161
+ # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
162
+ output = torch.prod(pow1[:-1] * pow2[-1])
163
+ return output
164
+
165
+
166
+ # Classes to re-use window
167
+ class SSIM(torch.nn.Module):
168
+ def __init__(self, window_size=11, size_average=True, val_range=None):
169
+ super(SSIM, self).__init__()
170
+ self.window_size = window_size
171
+ self.size_average = size_average
172
+ self.val_range = val_range
173
+
174
+ # Assume 3 channel for SSIM
175
+ self.channel = 3
176
+ self.window = create_window(window_size, channel=self.channel)
177
+
178
+ def forward(self, img1, img2):
179
+ (_, channel, _, _) = img1.size()
180
+
181
+ if channel == self.channel and self.window.dtype == img1.dtype:
182
+ window = self.window
183
+ else:
184
+ window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
185
+ self.window = window
186
+ self.channel = channel
187
+
188
+ _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
189
+ dssim = (1 - _ssim) / 2
190
+ return dssim
191
+
192
+ class MSSSIM(torch.nn.Module):
193
+ def __init__(self, window_size=11, size_average=True, channel=3):
194
+ super(MSSSIM, self).__init__()
195
+ self.window_size = window_size
196
+ self.size_average = size_average
197
+ self.channel = channel
198
+
199
+ def forward(self, img1, img2):
200
+ return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
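A minimal sketch of how the interpolation scripts in this commit use ssim_matlab: both frames are downscaled to 32x32 and the score gates the interpolation (below 0.2 is treated as a scene change, above 0.996 as a near-duplicate frame). The tensors here are random stand-ins:

import torch
import torch.nn.functional as F
from model.pytorch_msssim import ssim_matlab

I0 = torch.rand(1, 3, 256, 448)   # consecutive frames with values in [0, 1]
I1 = torch.rand(1, 3, 256, 448)

I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
score = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

if score < 0.2:
    print("likely scene change: do not interpolate across it")
elif score > 0.996:
    print("near-duplicate frames: handled as a static segment")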
model/pytorch_msssim/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (10.3 kB).
 
model/warplayer.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+ backwarp_tenGrid = {}
6
+
7
+
8
+ def warp(tenInput, tenFlow):
9
+ k = (str(tenFlow.device), str(tenFlow.size()))
10
+ if k not in backwarp_tenGrid:
11
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
12
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
13
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
14
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
15
+ backwarp_tenGrid[k] = torch.cat(
16
+ [tenHorizontal, tenVertical], 1).to(device)
17
+
18
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
19
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
20
+
21
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
22
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
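A minimal usage sketch of the backward-warping helper above, assuming the same repository layout; a zero flow field should return the input essentially unchanged:

import torch
from model.warplayer import warp, device

img = torch.rand(1, 3, 64, 64).to(device)
flow = torch.zeros(1, 2, 64, 64).to(device)  # per-pixel (dx, dy) displacement

out = warp(img, flow)           # backward-warp img along flow
print((out - img).abs().max())  # close to 0 for zero flow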
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ numpy>=1.16, <=1.23.5
2
+ tqdm>=4.35.0
3
+ sk-video>=1.1.10
4
+ torch>=1.3.0
5
+ opencv-python>=4.1.2
6
+ moviepy>=1.0.3
7
+ torchvision>=0.7.0
video_processing.py ADDED
@@ -0,0 +1,235 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from torch.nn import functional as F
7
+ import warnings
8
+ import _thread
9
+ import skvideo.io
10
+ from queue import Queue, Empty
11
+ from model.pytorch_msssim import ssim_matlab
12
+
13
+ warnings.filterwarnings("ignore")
14
+
15
+ def transferAudio(sourceVideo, targetVideo):
16
+ import shutil
17
+ import moviepy.editor
18
+ tempAudioFileName = "./temp/audio.mkv"
19
+
20
+ # split audio from original video file and store in "temp" directory
21
+ if True:
22
+
23
+ # clear the old "temp" directory if it exists
24
+ if os.path.isdir("temp"):
25
+ # remove temp directory
26
+ shutil.rmtree("temp")
27
+ # create new "temp" directory
28
+ os.makedirs("temp")
29
+ # extract audio from video
30
+ os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName))
31
+
32
+ targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
33
+ os.rename(targetVideo, targetNoAudio)
34
+ # combine audio file and new video file
35
+ os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
36
+
37
+ if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio, try converting the audio to AAC
38
+ tempAudioFileName = "./temp/audio.m4a"
39
+ os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName))
40
+ os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
41
+ if (os.path.getsize(targetVideo) == 0): # if AAC is not supported by the selected format
42
+ os.rename(targetNoAudio, targetVideo)
43
+ print("Audio transfer failed. Interpolated video will have no audio")
44
+ else:
45
+ print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
46
+
47
+ # remove audio-less video
48
+ os.remove(targetNoAudio)
49
+ else:
50
+ os.remove(targetNoAudio)
51
+
52
+ # remove temp directory
53
+ shutil.rmtree("temp")
54
+
55
+ def process_video(video, output, modelDir, fp16, UHD, scale, skip, fps, png, ext, exp, multi):
56
+ if exp != 1:
57
+ multi = (2 ** exp)
58
+ assert video is not None
59
+ if skip:
60
+ print("skip flag is abandoned, please refer to issue #207.")
61
+ if UHD and scale==1.0:
62
+ scale = 0.5
63
+ assert scale in [0.25, 0.5, 1.0, 2.0, 4.0]
64
+
65
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+ torch.set_grad_enabled(False)
67
+ if torch.cuda.is_available():
68
+ torch.backends.cudnn.enabled = True
69
+ torch.backends.cudnn.benchmark = True
70
+ if(fp16):
71
+ torch.set_default_tensor_type(torch.cuda.HalfTensor)
72
+
73
+ from rife.train_log.RIFE_HDv3 import Model
74
+ model = Model()
75
+ if not hasattr(model, 'version'):
76
+ model.version = 0
77
+ model.load_model(modelDir, -1)
78
+ print("Loaded 3.x/4.x HD model.")
79
+ model.eval()
80
+ model.device()
81
+
82
+ videoCapture = cv2.VideoCapture(video)
83
+ fps_in = videoCapture.get(cv2.CAP_PROP_FPS)
84
+ tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
85
+ videoCapture.release()
86
+ if fps is None:
87
+ fpsNotAssigned = True
88
+ fps_out = fps_in * multi
89
+ else:
90
+ fpsNotAssigned = False
91
+ fps_out = fps
92
+ videogen = skvideo.io.vreader(video)
93
+ lastframe = next(videogen)
94
+ fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
95
+ video_path_wo_ext, video_ext = os.path.splitext(video)
96
+ print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, ext, tot_frame, fps_in, fps_out))
97
+ if not png and fpsNotAssigned:
98
+ print("The audio will be merged after interpolation process")
99
+ else:
100
+ print("Will not merge audio because using png or fps flag!")
101
+
102
+ h, w, _ = lastframe.shape
103
+ vid_out_name = None
104
+ vid_out = None
105
+ if png:
106
+ if not os.path.exists('vid_out'):
107
+ os.mkdir('vid_out')
108
+ else:
109
+ if output is not None:
110
+ vid_out_name = output
111
+ else:
112
+ vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, multi, int(np.round(fps_out)), ext)
113
+ vid_out = cv2.VideoWriter(vid_out_name, fourcc, fps_out, (w, h))
114
+
115
+ def clear_write_buffer(user_args, write_buffer):
116
+ cnt = 0
117
+ while True:
118
+ item = write_buffer.get()
119
+ if item is None:
120
+ break
121
+ if png: # user_args is passed an empty tuple here, so use the enclosing png flag instead
122
+ cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
123
+ cnt += 1
124
+ else:
125
+ vid_out.write(item[:, :, ::-1])
126
+
127
+ def build_read_buffer(user_args, read_buffer, videogen):
128
+ try:
129
+ for frame in videogen:
130
+ read_buffer.put(frame)
131
+ except:
132
+ pass
133
+ read_buffer.put(None)
134
+
135
+ def make_inference(I0, I1, n):
136
+ if model.version >= 3.9:
137
+ res = []
138
+ for i in range(n):
139
+ res.append(model.inference(I0, I1, (i+1) * 1. / (n+1), scale))
140
+ return res
141
+ else:
142
+ middle = model.inference(I0, I1, scale)
143
+ if n == 1:
144
+ return [middle]
145
+ first_half = make_inference(I0, middle, n=n//2)
146
+ second_half = make_inference(middle, I1, n=n//2)
147
+ if n%2:
148
+ return [*first_half, middle, *second_half]
149
+ else:
150
+ return [*first_half, *second_half]
151
+
152
+ def pad_image(img):
153
+ if(fp16):
154
+ return F.pad(img, padding).half()
155
+ else:
156
+ return F.pad(img, padding)
157
+
158
+ tmp = max(128, int(128 / scale))
159
+ ph = ((h - 1) // tmp + 1) * tmp
160
+ pw = ((w - 1) // tmp + 1) * tmp
161
+ padding = (0, pw - w, 0, ph - h)
162
+ pbar = tqdm(total=tot_frame)
163
+ write_buffer = Queue(maxsize=500)
164
+ read_buffer = Queue(maxsize=500)
165
+ _thread.start_new_thread(build_read_buffer, ((), read_buffer, videogen))
166
+ _thread.start_new_thread(clear_write_buffer, ((), write_buffer))
167
+
168
+ I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
169
+ I1 = pad_image(I1)
170
+ temp = None # save lastframe when processing static frame
171
+
172
+ while True:
173
+ if temp is not None:
174
+ frame = temp
175
+ temp = None
176
+ else:
177
+ frame = read_buffer.get()
178
+ if frame is None:
179
+ break
180
+ I0 = I1
181
+ I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
182
+ I1 = pad_image(I1)
183
+ I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
184
+ I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
185
+ ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
186
+
187
+ break_flag = False
188
+ if ssim > 0.996:
189
+ frame = read_buffer.get() # read a new frame
190
+ if frame is None:
191
+ break_flag = True
192
+ frame = lastframe
193
+ else:
194
+ temp = frame
195
+ I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
196
+ I1 = pad_image(I1)
197
+ I1 = model.inference(I0, I1, scale=scale)
198
+ I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
199
+ ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
200
+ frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
201
+
202
+ if ssim < 0.2:
203
+ output = []
204
+ for i in range(multi - 1):
205
+ output.append(I0)
206
+ else:
207
+ output = make_inference(I0, I1, multi - 1)
208
+
209
+ write_buffer.put(lastframe)
210
+ for mid in output:
211
+ mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
212
+ write_buffer.put(mid[:h, :w])
213
+ pbar.update(1)
214
+ lastframe = frame
215
+ if break_flag:
216
+ break
217
+
218
+ write_buffer.put(lastframe)
219
+ write_buffer.put(None)
220
+
221
+ import time
222
+ while(not write_buffer.empty()):
223
+ time.sleep(0.1)
224
+ pbar.close()
225
+ if vid_out is not None:
226
+ vid_out.release()
227
+
228
+ if not png and fpsNotAssigned and video is not None:
229
+ try:
230
+ transferAudio(video, vid_out_name)
231
+ except:
232
+ print("Audio transfer failed. Interpolated video will have no audio")
233
+ targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
234
+ os.rename(targetNoAudio, vid_out_name)
235
+ return vid_out_name
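A hypothetical invocation of process_video above; the paths and values are placeholders, and it assumes the rife package imported inside the function is available alongside the downloaded model weights:

from video_processing import process_video

out_path = process_video(
    video='input.mp4',     # source clip (placeholder path)
    output=None,           # None -> auto-named '<video>_<multi>X_<fps>fps.<ext>'
    modelDir='train_log',  # directory containing the RIFE weights
    fp16=False,
    UHD=False,
    scale=1.0,
    skip=False,
    fps=None,              # None keeps fps_in * multi and allows the audio merge
    png=False,
    ext='mp4',
    exp=1,                 # with exp == 1, the multi argument is used as-is
    multi=2,               # 2x interpolation
)
print(out_path)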