File size: 2,771 Bytes
748160a
227bc73
1a780e6
27b9ec6
a5c228f
27b9ec6
227bc73
df19679
1a780e6
 
df19679
 
a4a2927
1a780e6
 
eb08525
1a780e6
 
 
 
 
 
 
 
 
 
 
 
2e5533e
1a780e6
 
 
 
 
 
 
 
ecea5f9
 
1a780e6
 
 
a5c228f
1a780e6
 
ecea5f9
1a780e6
ecea5f9
 
1a780e6
 
 
 
 
47d7323
1a780e6
 
 
 
a4a2927
1a780e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4a2927
 
1a780e6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import spaces
import gradio as gr
import argparse
import sys
import time
import os
import random
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer import TaskType
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from diffusers.utils import export_to_video
from diffusers.utils import load_image

#predictor = None
#task_type = None

#@spaces.GPU(duration=120)
def init_predictor():
    """Build the image-to-video predictor and store it in the global ``predictor``.

    The predictor is a ``SkyReelsVideoSingleGpuInfer`` configured for
    ``TaskType.I2V`` with the ``Skywork/SkyReels-V1-Hunyuan-I2V`` model,
    quantization disabled and offloading disabled. An ``OffloadConfig`` is
    still passed along with the call.

    Returns:
        None. The result is published through the module-level name
        ``predictor``.
    """
    global predictor

    # Offload settings are supplied even though ``is_offload`` is False below.
    offload_settings = OffloadConfig(
        high_cpu_memory=True,
        parameters_level=True,
        compiler_transformer=False,
    )

    predictor = SkyReelsVideoSingleGpuInfer(
        task_type=TaskType.I2V,
        model_id="Skywork/SkyReels-V1-Hunyuan-I2V",
        quant_model=False,
        is_offload=False,
        offload_config=offload_settings,
    )
    
@spaces.GPU(duration=80)
def generate_video(prompt, seed, image=None):
    """Generate a video from an input image and a text prompt.

    Args:
        prompt: Text prompt forwarded to the predictor. Its first 100
            characters (with "/" removed) are also used in the output
            file name.
        seed: Seed for the run. ``-1`` means "draw a random seed".
        image: Image reference handed to ``load_image`` (the UI in this file
            supplies a file path). Required despite the ``None`` default.

    Returns:
        A ``(video_path, kwargs)`` tuple: the path of the exported ``.mp4``
        file and the dict of inference arguments that was used.

    Raises:
        gr.Error: If no image was supplied.
    """
    print(f"image:{type(image)}")

    # Validate up front with an explicit exception: ``assert`` is stripped
    # under ``python -O`` and is not meant for checking user input.
    if image is None:
        raise gr.Error("please input image")

    if seed == -1:
        random.seed(time.time())
        seed = random.randrange(4294967294)

    kwargs = {
        "prompt": prompt,
        "height": 512,
        "width": 512,
        "num_frames": 97,
        "num_inference_steps": 30,
        "seed": seed,
        "guidance_scale": 6.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
        "cfg_for": False,
    }
    kwargs["image"] = load_image(image=image)

    # ``predictor`` is the module-level object created by init_predictor().
    output = predictor.inference(kwargs)

    # The original interpolated an undefined global ``task_type`` here, which
    # raised NameError after inference had already finished. Use the task
    # type the predictor is configured with instead.
    save_dir = f"./result/{TaskType.I2V}"
    os.makedirs(save_dir, exist_ok=True)
    video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
    print(f"generate video, local path: {video_out_file}")
    export_to_video(output, video_out_file, fps=24)
    return video_out_file, kwargs

def create_gradio_interface():
    """Assemble the Gradio Blocks UI for video generation.

    The layout is a row holding the image upload, prompt textbox and seed
    number field, followed by a "Generate Video" button, a video output and
    a textbox for the parameters. Clicking the button calls
    ``generate_video`` with ``[prompt, seed, image]`` and routes its two
    return values to the video and parameter outputs.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    with gr.Blocks() as ui:
        # Input widgets share one row.
        with gr.Row():
            image_input = gr.Image(label="Upload Image", type="filepath")
            prompt_input = gr.Textbox(label="Input Prompt")
            seed_input = gr.Number(label="Random Seed", value=-1)

        generate_btn = gr.Button("Generate Video")
        video_output = gr.Video(label="Generated Video")
        params_output = gr.Textbox(label="Output Parameters")

        # Input order must match generate_video(prompt, seed, image).
        generate_btn.click(
            fn=generate_video,
            inputs=[prompt_input, seed_input, image_input],
            outputs=[video_output, params_output],
        )
    return ui
    
#init_predictor()

if __name__ == "__main__":
    # Script entry point: create the predictor, then build and serve the UI.
    #import multiprocessing
    #multiprocessing.freeze_support()
    # init_predictor() must run before any request is handled, because
    # generate_video reads the module-level `predictor` that it sets.
    init_predictor()
    demo = create_gradio_interface()
    demo.launch()