File size: 4,586 Bytes
748160a
227bc73
27b9ec6
a5c228f
27b9ec6
227bc73
a4a2927
7e3737d
a4a2927
df19679
 
 
 
 
 
a4a2927
 
 
 
 
7834fa8
eb08525
a4a2927
 
5fddbe1
7e3737d
a1c2882
2e5533e
 
 
a4a2927
a1c2882
42d98b4
 
 
 
 
a1c2882
2e5533e
88264d5
2e5533e
a4a2927
 
2e5533e
a4a2927
 
 
 
 
 
 
 
 
 
 
47d7323
 
ecea5f9
 
a5c228f
 
 
 
89b0689
a5c228f
ecea5f9
a5c228f
ecea5f9
 
a4a2927
7834fa8
 
 
 
a4a2927
7834fa8
a4a2927
 
 
 
7834fa8
a4a2927
47d7323
a4a2927
47d7323
a4a2927
 
 
7834fa8
a4a2927
 
47d7323
 
7e3737d
47d7323
 
a4a2927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import spaces
import gradio as gr
import sys
import time
import os
import random
from PIL import Image
import torch
import asyncio  # Import asyncio
from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError
from diffusers.utils import export_to_video
from diffusers.utils import load_image

# os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Uncomment to force CPU-only execution
os.environ["SAFETENSORS_FAST_GPU"] = "1"
# NOTE: os.putenv changes the process environment but does NOT update
# os.environ, so Python code reading os.environ will not see this value.
os.putenv("HF_HUB_ENABLE_HF_TRANSFER", "1")

# No longer needed here: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Use gr.State to hold the predictor.  Initialize it to None.
predictor_state = gr.State(None)
# NOTE(review): `device` is computed here but never referenced in this file --
# presumably intended to be passed to the predictor constructor; verify.
device="cuda:0" if torch.cuda.is_available() else "cpu" # Pass device to the constructor

def init_predictor(task_type: str):
    """Construct a SkyReelsVideoInfer for the given task ("i2v" or "t2v").

    Returns the predictor instance, or None when the model repository
    cannot be found or loading fails for any other reason.
    """
    chosen_task = TaskType.I2V if task_type == "i2v" else TaskType.T2V
    try:
        return SkyReelsVideoInfer(
            task_type=chosen_task,
            model_id="Skywork/skyreels-v1-Hunyuan-i2v",  # Adjust model ID as needed
            quant_model=True,
            is_offload=False,  # Consider removing if you have enough GPU memory
            # Offloading disabled; re-enable with e.g.
            # OffloadConfig(high_cpu_memory=True, parameters_level=True)
            offload_config=None,
            use_multiprocessing=False,
        )
    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as err:
        print(f"Error: Model not found. Details: {err}")
        return None
    except Exception as err:
        print(f"Error loading model: {err}")
        return None

# Make generate_video async
async def generate_video(prompt, image_file, predictor):
    """Run i2v inference for `prompt` + `image_file` and export an MP4.

    Args:
        prompt: Text prompt describing the desired video.
        image_file: Uploaded file object (must expose a `.name` path).
        predictor: Loaded SkyReelsVideoInfer instance (from predictor_state).

    Returns:
        A 2-tuple `(result, predictor)` matching the click handler's two
        declared outputs: `result` is the saved video path on success, or
        a gr.Error describing the failure.
    """
    # BUG FIX: the original early returns yielded a single value while the
    # Gradio handler declares two outputs ([output_video, predictor_state]);
    # every exit path now returns a consistent 2-tuple.
    if image_file is None:
        return gr.Error("Error: For i2v, provide an image."), predictor
    if not isinstance(prompt, str) or not prompt.strip():
        return gr.Error("Error: Please provide a prompt."), predictor
    if predictor is None:
        return gr.Error("Error: Model not loaded."), predictor
    random.seed(time.time())
    seed = int(random.randrange(4294967294))
    kwargs = {
        "prompt": prompt,
        "height": 256,
        "width": 256,
        "num_frames": 24,
        "num_inference_steps": 30,
        "seed": int(seed),
        "guidance_scale": 7.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "bad quality, blur",
        "cfg_for": False,
    }
    try:
        # Load the image; SkyReelsVideoInfer is expected to move it to the
        # correct device itself -- TODO confirm against its implementation.
        kwargs["image"] = load_image(image=image_file.name)
    except Exception as e:
        return gr.Error(f"Image loading error: {e}"), predictor
    try:
        frames = predictor.inference(kwargs)
    except Exception as e:
        return gr.Error(f"Inference error: {e}"), predictor
    save_dir = "./result/i2v"  # Consistent directory
    os.makedirs(save_dir, exist_ok=True)
    # BUG FIX: the raw prompt went straight into the filename; characters
    # such as "/" would break the save path. Keep only filesystem-safe
    # characters, falling back to "video" for an all-unsafe prompt.
    safe_prompt = "".join(
        c if (c.isalnum() or c in " -_") else "_" for c in prompt[:100]
    ).strip() or "video"
    video_out_file = os.path.join(save_dir, f"{safe_prompt}_{int(seed)}.mp4")
    print(f"Generating video: {video_out_file}")
    try:
        export_to_video(frames, video_out_file, fps=24)
    except Exception as e:
        return gr.Error(f"Video export error: {e}"), predictor
    return video_out_file, predictor

def display_image(file):
    """Return a PIL image loaded from the uploaded file, or None if absent."""
    if file is None:
        return None
    return Image.open(file.name)

async def load_model():
    """Asynchronously build the i2v predictor (thin wrapper over init_predictor)."""
    return init_predictor('i2v')

async def main():
    """Build the Gradio UI, preload the i2v model, and launch the app."""
    with gr.Blocks() as demo:
        image_file = gr.File(label="Image Prompt (Required)", file_types=["image"])
        image_file_preview = gr.Image(label="Image Prompt Preview", interactive=False)
        prompt_textbox = gr.Text(label="Prompt")
        generate_button = gr.Button("Generate")
        output_video = gr.Video(label="Output Video")
        # Show a preview whenever a new image file is selected.
        image_file.change(
            display_image,
            inputs=[image_file],
            outputs=[image_file_preview],
        )
        # generate_video returns (video_path, predictor); wiring predictor_state
        # as both input and output keeps the loaded model in session state.
        generate_button.click(
            fn=generate_video,
            inputs=[prompt_textbox, image_file, predictor_state],
            outputs=[output_video, predictor_state],
        )
        predictor_state.value = await load_model()  # load and set predictor
    # BUG FIX: Blocks.launch() is a regular (blocking) function that returns a
    # tuple, not a coroutine -- "await demo.launch()" raised a TypeError.
    demo.launch()

if __name__ == "__main__":
    # Script entry point: build the UI and start the app on the asyncio loop.
    asyncio.run(main())