File size: 4,767 Bytes
7b1bf28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52e114d
 
7b1bf28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52e114d
 
 
7b1bf28
 
 
 
52e114d
7b1bf28
 
 
 
 
 
 
 
 
52e114d
 
7b1bf28
 
 
 
52e114d
7b1bf28
 
 
 
 
 
 
 
 
52e114d
 
7b1bf28
 
 
 
52e114d
7b1bf28
 
52e114d
7b1bf28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import os
import pathlib
import shutil
import subprocess
import sys

import gradio as gr
import torch
from PIL import Image

# One-time setup at import: fetch the upstream Thin-Plate-Spline-Motion-Model
# code and its pretrained checkpoint, then work from inside the cloned repo
# (demo.py, config/, assets/ and checkpoints/ are referenced relative to it).
# Each step is idempotent so restarting the app does not fail on re-clone or
# re-mkdir. Like the original os.system calls, the clone/download are
# best-effort (exit codes are not checked); a missing repo still fails loudly
# at os.chdir.
_REPO_DIR = "Thin-Plate-Spline-Motion-Model"
if not os.path.isdir(_REPO_DIR):
    subprocess.run(
        ["git", "clone", "https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model", _REPO_DIR],
        check=False,
    )
os.chdir(_REPO_DIR)
os.makedirs("checkpoints", exist_ok=True)
# Passed as an argument list (no shell), so the '?' in the URL is never
# subject to glob expansion. '-c' resumes/skips an already-downloaded file.
subprocess.run(
    [
        "wget",
        "-c",
        "https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1",
        "-O",
        "checkpoints/vox.pth.tar",
    ],
    check=False,
)



# Page heading ("Image Animation").
title = "# 图片动画"
# Intro block: a one-line summary ("Gradio implementation of image animation",
# CVPR 2022) with paper and code links, followed by the upstream overview GIF.
# The stray closing </b> tag (which had no opening tag) has been removed.
DESCRIPTION = '''### 图片动画的Gradio实现, CVPR 2022. <a href='https://arxiv.org/abs/2203.14367'>[Paper]</a><a href='https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model'>[Github Code]</a>

<img id="overview" alt="overview" src="https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model/raw/main/assets/vox.gif" />
'''
# Visitor-counter badge shown at the bottom of the page.
FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.Image-Animation-using-Thin-Plate-Spline-Motion-Model" />'


def get_style_image_path(style_name: str) -> str:
    """Map a style key to its asset path under ``assets/``.

    Args:
        style_name: Either ``'source'`` or ``'driving'``.

    Returns:
        ``'assets/source.png'`` or ``'assets/driving.mp4'`` respectively.

    Raises:
        KeyError: If *style_name* is not one of the known keys.
    """
    asset_for = {'source': 'source.png', 'driving': 'driving.mp4'}
    return 'assets' + '/' + asset_for[style_name]


def get_style_image_markdown_text(style_name: str) -> str:
    """Return an HTML ``<img>`` tag (id ``style-image``) for *style_name*'s asset.

    The ``src`` attribute is the path produced by ``get_style_image_path``;
    unknown style names propagate its ``KeyError``.
    """
    return f'<img id="style-image" src="{get_style_image_path(style_name)}" alt="style image">'


def update_style_image(style_name: str) -> dict:
    """Build a Gradio Markdown update whose value is the ``<img>`` tag for *style_name*."""
    return gr.Markdown.update(value=get_style_image_markdown_text(style_name))


def set_example_image(example: list) -> dict:
    """Load a clicked example sample into the image component.

    Args:
        example: The dataset sample; its first element is used as the image value.
    """
    chosen = example[0]
    return gr.Image.update(value=chosen)

def set_example_video(example: list) -> dict:
    """Load a clicked example sample into the video component.

    Args:
        example: The dataset sample; its first element is used as the video value.
    """
    chosen = example[0]
    return gr.Video.update(value=chosen)

def inference(img, vid):
    """Animate the source image with the driving video's motion.

    Saves *img* to ``temp/image.jpg`` and runs the upstream ``demo.py`` script
    on CPU with the ``config/vox-256.yaml`` config and
    ``./checkpoints/vox.pth.tar`` checkpoint.

    Args:
        img: Source face image as a PIL image (the input component uses
            ``type="pil"``).
        vid: Filesystem path of the driving video.

    Returns:
        The path of the generated video, ``'./temp/result.mp4'``.

    Raises:
        ValueError: If the image or the video is missing.
        subprocess.CalledProcessError: If ``demo.py`` exits with a non-zero
            status. Previously the failure was ignored and whatever result
            file was left from an earlier run got returned.
    """
    if img is None or vid is None:
        raise ValueError("Both a source image and a driving video are required.")

    os.makedirs('temp', exist_ok=True)
    source_path = "temp/image.jpg"
    result_path = './temp/result.mp4'
    # JPEG cannot store an alpha channel; normalise to RGB before saving.
    img.convert("RGB").save(source_path, "JPEG")

    # Argument list (no shell): the video path may contain spaces or shell
    # metacharacters. sys.executable runs demo.py with this same interpreter.
    subprocess.run(
        [
            sys.executable, "demo.py",
            "--config", "config/vox-256.yaml",
            "--checkpoint", "./checkpoints/vox.pth.tar",
            "--source_image", source_path,
            "--driving_video", str(vid),
            "--result_video", result_path,
            "--cpu",
        ],
        check=True,
    )
    return result_path
  


def main() -> None:
    """Build the three-step Gradio Blocks UI and launch it.

    Step 1 uploads a source face image (with clickable examples from
    ``assets/*.png``), step 2 picks a driving video (examples from
    ``assets/*.mp4``), and step 3 runs ``inference`` on both and shows the
    generated video. Blocks until the server exits.
    """
    # NOTE(review): 'style.css' and 'assets' are resolved relative to the
    # current working directory, which is the cloned repo after the
    # module-level os.chdir -- confirm the stylesheet actually lives there.
    with gr.Blocks(theme="huggingface", css='style.css') as demo:
        gr.Markdown(title)
        gr.Markdown(DESCRIPTION)

        # Step 1: source image upload.
        with gr.Box():
            gr.Markdown('''## 第1步 (上传人脸图片)
- 拖一张含人脸的图片到 **输入图片**.
    - 如果图片中有多张人脸, 使用右上角的编辑按钮裁剪图片.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        # Delivered to the handler as a PIL image.
                        input_image = gr.Image(label='输入图片',
                                               type="pil")
                        
            # Example images: one single-component sample per PNG in assets/.
            with gr.Row():
                paths = sorted(pathlib.Path('assets').glob('*.png'))
                example_images = gr.Dataset(components=[input_image],
                                            samples=[[path.as_posix()]
                                                     for path in paths])

        # Step 2: driving video selection.
        with gr.Box():
            gr.Markdown('''## 第2步 (选择动态视频)
-  **为人脸图片选择目标视频**.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        driving_video = gr.Video(label='目标视频',
                                               format="mp4")

            # Example videos: one single-component sample per MP4 in assets/.
            with gr.Row():
                paths = sorted(pathlib.Path('assets').glob('*.mp4'))
                example_video = gr.Dataset(components=[driving_video],
                                            samples=[[path.as_posix()]
                                                     for path in paths])

        # Step 3: run the model and show the output video.
        with gr.Box():
            gr.Markdown('''## 第3步 (基于视频生成动态图片)
- 点击 **开始** 按钮. (注意: 由于是在CPU上运行, 生成最终结果需要花费大约3分钟.)
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        generate_button = gr.Button('开始')

                with gr.Column():
                    result = gr.Video(type="file", label="输出")
        gr.Markdown(FOOTER)
        # Wire events: the button runs inference; clicking an example sample
        # copies it into the corresponding input component.
        generate_button.click(fn=inference,
                              inputs=[
                                  input_image,
                                  driving_video
                              ],
                              outputs=result)
        example_images.click(fn=set_example_image,
                             inputs=example_images,
                             outputs=example_images.components)
        example_video.click(fn=set_example_video,
                             inputs=example_video,
                             outputs=example_video.components)

    demo.launch(
        enable_queue=True,
        debug=True
    )

if __name__ == '__main__':
    # Launch the UI only when executed as a script, not when imported.
    main()