File size: 3,364 Bytes
748160a
227bc73
27b9ec6
a5c228f
27b9ec6
227bc73
7e3737d
 
eb08525
7e3737d
 
eb08525
 
7e3737d
 
abdf73f
88264d5
a1c2882
 
 
 
 
7e3737d
2e5533e
 
 
7e3737d
a1c2882
 
 
 
 
 
 
2e5533e
88264d5
2e5533e
7e3737d
2e5533e
7e3737d
 
 
a5c228f
c2f6dae
7e3737d
a5c228f
 
7e3737d
 
 
 
 
47d7323
 
ecea5f9
 
a5c228f
 
 
 
89b0689
a5c228f
ecea5f9
a5c228f
ecea5f9
 
 
7e3737d
 
47d7323
7e3737d
47d7323
7e3737d
47d7323
7e3737d
 
 
47d7323
 
7e3737d
47d7323
 
7e3737d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import spaces
import gradio as gr
import sys
import time
import os
import random
from PIL import Image 
 # os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["SAFETENSORS_FAST_GPU"] = "1"
os.putenv("HF_HUB_ENABLE_HF_TRANSFER","1")
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create the gr.State component *outside* the gr.Blocks context

global predictor

def init_predictor(task_type: str):
    """Create and return the global SkyReels video predictor.

    Parameters
    ----------
    task_type : str
        "i2v" selects image-to-video; any other value selects text-to-video.

    Returns
    -------
    SkyReelsVideoInfer | None
        The initialized predictor, or None when loading fails (the error
        is printed so module import does not crash on a missing model).
        Bug fix: the original returned a ``(message, None)`` tuple on
        failure but the bare predictor on success, and callers assigned
        either shape to the same variable.
    """
    # Heavy project imports are kept function-local so importing this
    # module stays cheap and load failures surface here.
    from skyreelsinfer import TaskType
    from skyreelsinfer.offload import OffloadConfig
    from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
    from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError

    global predictor
    try:
        predictor = SkyReelsVideoInfer(
            task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
            model_id="Skywork/skyreels-v1-Hunyuan-i2v",
            quant_model=True,
            is_offload=True,
            offload_config=OffloadConfig(
                high_cpu_memory=True,
                parameters_level=True,
            ),
            use_multiprocessing=False,
        )
        return predictor
    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
        print(f"Error: Model not found. Details: {e}", file=sys.stderr)
        return None
    except Exception as e:
        print(f"Error loading model: {e}", file=sys.stderr)
        return None
        
# Eagerly load the i2v predictor at import time so the first UI request
# does not pay the full model-loading cost.
predictor = init_predictor('i2v')

@spaces.GPU(duration=80)
def generate_video(prompt, image, predictor):
    """Run image-to-video inference and export the result as an MP4.

    Parameters
    ----------
    prompt : str
        Text description guiding the generation.
    image : str | PIL.Image.Image
        Source image (path, URL, or PIL image) for the i2v task; required.
    predictor :
        Object exposing ``inference(kwargs) -> frames`` (SkyReelsVideoInfer).

    Returns
    -------
    str
        Path of the exported video file.

    Raises
    ------
    gr.Error
        When the image or prompt is missing. (Bug fix: the original
        returned a 2-tuple here, which does not match the handler's
        single Video output component; gr.Error surfaces in the UI.)
    """
    from diffusers.utils import export_to_video, load_image

    if image is None:  # `is None`, not `== None`
        raise gr.Error("Error: For i2v, provide image path.")
    if not isinstance(prompt, str) or not prompt:
        raise gr.Error("Error: No prompt.")

    # Draw a fresh random seed for every generation.
    random.seed(time.time())
    seed = int(random.randrange(4294967294))

    kwargs = {
        "prompt": prompt,
        "height": 256,
        "width": 256,
        "num_frames": 24,
        "num_inference_steps": 30,
        "seed": seed,
        "guidance_scale": 7.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "bad quality, blur",
        "cfg_for": False,
        "image": load_image(image=image),
    }

    frames = predictor.inference(kwargs)

    # Bug fix: the original referenced an undefined `task_type` here
    # (NameError at call time); this app only serves the i2v task.
    save_dir = "./result/i2v"
    os.makedirs(save_dir, exist_ok=True)
    # Sanitize the prompt so path separators or odd characters in user
    # text cannot break (or escape) the output path.
    safe_prompt = "".join(
        c if (c.isalnum() or c in " _-") else "_" for c in prompt[:100]
    )
    video_out_file = f"{save_dir}/{safe_prompt}_{seed}.mp4"
    print(f"Generating video: {video_out_file}")
    export_to_video(frames, video_out_file, fps=24)
    return video_out_file
 
def display_image(file):
    """Load the uploaded gr.File as a PIL image for the preview pane.

    Returns None when the upload is cleared so the preview is emptied.
    """
    if file is None:
        return None
    return Image.open(file.name)
        
with gr.Blocks() as demo:
    # Image-to-video demo: upload an image, preview it, enter a prompt,
    # generate a short video.
    image_file = gr.File(label="Image Prompt (Required)", file_types=["image"])
    image_file_preview = gr.Image(label="Image Prompt Preview", interactive=False)
    prompt_textbox = gr.Text(label="Prompt")
    generate_button = gr.Button("Generate")
    output_video = gr.Video(label="Output Video")

    # Bug fix: the original passed the raw `predictor` object in `inputs=`,
    # but Gradio event inputs must be components. Wrap it in gr.State so it
    # is forwarded as generate_video's third argument.
    # NOTE(review): gr.State deep-copies its value per session — confirm the
    # predictor object tolerates this; otherwise drop the third input and
    # read the module-level global inside generate_video instead.
    predictor_state = gr.State(predictor)

    # Refresh (or clear) the preview whenever the upload changes.
    image_file.change(
        display_image,
        inputs=[image_file],
        outputs=[image_file_preview],
    )

    generate_button.click(
        fn=generate_video,
        inputs=[prompt_textbox, image_file, predictor_state],
        outputs=[output_video],
    )

demo.launch()