1inkusFace committed
Commit 1a780e6 · verified · 1 Parent(s): fe45247

Update app.py

Files changed (1)
  1. app.py +62 -101
app.py CHANGED
@@ -1,123 +1,84 @@
 import spaces
 import gradio as gr
+import argparse
 import sys
 import time
 import os
 import random
-from PIL import Image
-import torch
-import asyncio # Import asyncio
-from skyreelsinfer import TaskType
 from skyreelsinfer.offload import OffloadConfig
-from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
-from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError
+from skyreelsinfer import TaskType
+from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
 from diffusers.utils import export_to_video
 from diffusers.utils import load_image
 
-# os.environ["CUDA_VISIBLE_DEVICES"] = "" # Uncomment if needed
-os.environ["SAFETENSORS_FAST_GPU"] = "1"
-os.putenv("HF_HUB_ENABLE_HF_TRANSFER", "1")
-
-# No longer needed here: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+#predictor = None
+#task_type = None
 
-# Use gr.State to hold the predictor. Initialize it to None.
-predictor_state = gr.State(None)
-device="cuda:0" if torch.cuda.is_available() else "cpu" # Pass device to the constructor
-
-def init_predictor(task_type: str):
-    try:
-        predictor = SkyReelsVideoInfer(
-            task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
-            model_id="Skywork/skyreels-v1-Hunyuan-i2v", # Adjust model ID as needed
-            quant_model=True,
-            is_offload=False, # Consider removing if you have enough GPU memory
-            offload_config=None, #OffloadConfig(
-            # high_cpu_memory=True,
-            # parameters_level=True,
-            #),
-            use_multiprocessing=False,
+#@spaces.GPU(duration=120)
+def init_predictor():
+    global predictor
+    predictor = SkyReelsVideoSingleGpuInfer(
+        task_type= TaskType.I2V,
+        model_id="Skywork/SkyReels-V1-Hunyuan-I2V",
+        quant_model=False,
+        is_offload=False,
+        offload_config=OffloadConfig(
+            high_cpu_memory=True,
+            parameters_level=True,
+            compiler_transformer=False,
         )
-        return predictor
-    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
-        print(f"Error: Model not found. Details: {e}")
-        return None
-    except Exception as e:
-        print(f"Error loading model: {e}")
-        return None
-
-# Make generate_video async
-async def generate_video(prompt, image_file, predictor):
-    if image_file is None:
-        return gr.Error("Error: For i2v, provide an image.")
-    if not isinstance(prompt, str) or not prompt.strip():
-        return gr.Error("Error: Please provide a prompt.")
-    if predictor is None:
-        return gr.Error("Error: Model not loaded.")
-    random.seed(time.time())
-    seed = int(random.randrange(4294967294))
+    )
+
+@spaces.GPU(duration=80)
+def generate_video(prompt, seed, image=None):
+    print(f"image:{type(image)}")
+    if seed == -1:
+        random.seed(time.time())
+        seed = int(random.randrange(4294967294))
     kwargs = {
         "prompt": prompt,
-        "height": 256,
-        "width": 256,
-        "num_frames": 24,
+        "height": 512,
+        "width": 512,
+        "num_frames": 97,
         "num_inference_steps": 30,
-        "seed": int(seed),
-        "guidance_scale": 7.0,
+        "seed": seed,
+        "guidance_scale": 6.0,
         "embedded_guidance_scale": 1.0,
-        "negative_prompt": "bad quality, blur",
+        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
         "cfg_for": False,
     }
-    try:
-        # Load the image and move it to the correct device *before* inference
-        image = load_image(image=image_file.name)
-        # No need to manually move to device. SkyReelsVideoInfer should handle it.
-        kwargs["image"] = image
-    except Exception as e:
-        return gr.Error(f"Image loading error: {e}")
-    try:
-        output = predictor.inference(kwargs)
-        frames = output
-    except Exception as e:
-        return gr.Error(f"Inference error: {e}"), None # Return None for predictor on error
-    save_dir = "./result/i2v" # Consistent directory
+    assert image is not None, "please input image"
+    kwargs["image"] = load_image(image=image)
+    #global predictor
+    output = predictor.inference(kwargs)
+    save_dir = f"./result/{task_type}"
     os.makedirs(save_dir, exist_ok=True)
-    video_out_file = os.path.join(save_dir, f"{prompt[:100]}_{int(seed)}.mp4")
-    print(f"Generating video: {video_out_file}")
-    try:
-        export_to_video(frames, video_out_file, fps=24)
-    except Exception as e:
-        return gr.Error(f"Video export error: {e}"), None # Return None for predictor
-    return video_out_file, predictor # Return updated predictor
+    video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
+    print(f"generate video, local path: {video_out_file}")
+    export_to_video(output, video_out_file, fps=24)
+    return video_out_file, kwargs
 
-def display_image(file):
-    if file is not None:
-        return Image.open(file.name)
-    else:
-        return None
-
-async def load_model():
-    predictor = init_predictor('i2v')
-    return predictor
-
-async def main():
-    with gr.Blocks() as demo:
-        image_file = gr.File(label="Image Prompt (Required)", file_types=["image"])
-        image_file_preview = gr.Image(label="Image Prompt Preview", interactive=False)
-        prompt_textbox = gr.Text(label="Prompt")
-        generate_button = gr.Button("Generate")
-        output_video = gr.Video(label="Output Video")
-        image_file.change(
-            display_image,
-            inputs=[image_file],
-            outputs=[image_file_preview]
-        )
-        generate_button.click(
-            fn=generate_video,
-            inputs=[prompt_textbox, image_file, predictor_state],
-            outputs=[output_video, predictor_state], # Output predictor_state
-        )
-        predictor_state.value = await load_model() # load and set predictor
-        await demo.launch()
+def create_gradio_interface():
+    with gr.Blocks() as demo:
+        with gr.Row():
+            image = gr.Image(label="Upload Image", type="filepath")
+            prompt = gr.Textbox(label="Input Prompt")
+            seed = gr.Number(label="Random Seed", value=-1)
+        submit_button = gr.Button("Generate Video")
+        output_video = gr.Video(label="Generated Video")
+        output_params = gr.Textbox(label="Output Parameters")
+        submit_button.click(
+            fn=generate_video,
+            inputs=[prompt, seed, image],
+            outputs=[output_video, output_params],
+        )
+    return demo
+
+#init_predictor()
 
 if __name__ == "__main__":
-    asyncio.run(main())
+    #import multiprocessing
+    #multiprocessing.freeze_support()
+    init_predictor()
+    demo = create_gradio_interface()
+    demo.launch()
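
Note on the new version: generate_video builds the output path with save_dir = f"./result/{task_type}", but task_type only appears as the commented-out line "#task_type = None", so that lookup would raise a NameError the first time a video is generated. A minimal sketch of one way to keep the path expression working, assuming the Space stays image-to-video only; the hard-coded "i2v" label is an assumption mirroring the old ./result/i2v directory and is not part of this commit:

    # Hypothetical module-level definition (assumed, not in the commit),
    # so the f-string inside generate_video has a value to resolve.
    task_type = "i2v"                    # assumed label, consistent with TaskType.I2V above
    save_dir = f"./result/{task_type}"   # same expression as the committed code -> "./result/i2v"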