1inkusFace commited on
Commit
47d7323
·
verified ·
1 Parent(s): 2ad4023

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -53
app.py CHANGED
@@ -10,15 +10,7 @@ import random
10
  # Create the gr.State component *outside* the gr.Blocks context
11
  predictor_state = gr.State(None)
12
 
13
- def get_transformer_model_id(task_type: str) -> str:
14
- if task_type == "i2v":
15
- return "Skywork/skyreels-v1-Hunyuan-i2v"
16
- else:
17
- return "Skywork/skyreels-v1-Hunyuan-t2v"
18
-
19
- # @spaces.GPU(duration=120)
20
- def init_predictor(task_type: str):
21
- # ALL IMPORTS NOW INSIDE THIS FUNCTION
22
  import torch
23
  from skyreelsinfer import TaskType
24
  from skyreelsinfer.offload import OffloadConfig
@@ -28,7 +20,7 @@ def init_predictor(task_type: str):
28
  try:
29
  predictor = SkyReelsVideoInfer(
30
  task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
31
- model_id=get_transformer_model_id(task_type),
32
  quant_model=True,
33
  is_offload=True,
34
  offload_config=OffloadConfig(
@@ -47,20 +39,19 @@ def init_predictor(task_type: str):
47
  init_predictor('i2v')
48
 
49
  @spaces.GPU(duration=80)
50
- def generate_video(prompt, seed, image, task_type, predictor): # predictor as argument
51
- # IMPORTS INSIDE THIS FUNCTION TOO
52
  from diffusers.utils import export_to_video
53
  from diffusers.utils import load_image
54
  import os
55
 
56
- if task_type == "i2v" and not isinstance(image, str):
57
  return "Error: For i2v, provide image path.", "{}"
58
- if not isinstance(prompt, str) or not isinstance(seed, (int, float)):
59
- return "Error: Invalid inputs.", "{}"
60
 
61
- if seed == -1:
62
- random.seed(time.time())
63
- seed = int(random.randrange(4294967294))
64
 
65
  kwargs = {
66
  "prompt": prompt,
@@ -75,51 +66,41 @@ def generate_video(prompt, seed, image, task_type, predictor): # predictor as ar
75
  "cfg_for": False,
76
  }
77
 
78
- if task_type == "i2v":
79
- if image is None or not os.path.exists(image):
80
- return "Error: Image not found.", "{}"
81
- try:
82
- kwargs["image"] = load_image(image=image)
83
- except Exception as e:
84
- return f"Error loading image: {e}", "{}"
85
-
86
- try:
87
- if predictor is None:
88
- return "Error: Model not init.", "{}"
89
-
90
- output = predictor.inference(kwargs)
91
- frames = output
92
-
93
- save_dir = f"./result/{task_type}"
94
- os.makedirs(save_dir, exist_ok=True)
95
- video_out_file = f"{save_dir}/{prompt[:100]}_{int(seed)}.mp4"
96
- print(f"Generating video: {video_out_file}")
97
- export_to_video(frames, video_out_file, fps=24)
98
- return video_out_file
99
-
100
- except Exception as e:
101
- return f"Error: {e}", "{}"
102
-
103
  # --- Minimal Gradio Interface ---
104
  with gr.Blocks() as demo:
105
- task_type_dropdown = gr.Dropdown(
106
- choices=["i2v", "t2v"], label="Task", value="t2v", elem_id="task_type"
107
- )
108
- load_model_button = gr.Button("Load Model")
109
  prompt_textbox = gr.Textbox(label="Prompt")
110
  generate_button = gr.Button("Generate")
111
- output_textbox = gr.Textbox(label="Output") # Just a textbox
112
  output_video = gr.Video(label="Output Video") # Just a textbox
113
 
114
- load_model_button.click(
115
- fn=init_predictor,
116
- inputs=[task_type_dropdown],
117
- outputs=[output_textbox, predictor_state], # Correct order of outputs
118
  )
119
 
120
  generate_button.click(
121
  fn=generate_video,
122
- inputs=[prompt_textbox, task_type_dropdown, predictor_state],
123
  outputs=[output_video],
124
  )
125
 
 
10
  # Create the gr.State component *outside* the gr.Blocks context
11
  predictor_state = gr.State(None)
12
 
13
+ def init_predictor(task_type: str):
 
 
 
 
 
 
 
 
14
  import torch
15
  from skyreelsinfer import TaskType
16
  from skyreelsinfer.offload import OffloadConfig
 
20
  try:
21
  predictor = SkyReelsVideoInfer(
22
  task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
23
+ model_id="Skywork/skyreels-v1-Hunyuan-i2v",
24
  quant_model=True,
25
  is_offload=True,
26
  offload_config=OffloadConfig(
 
39
  init_predictor('i2v')
40
 
41
  @spaces.GPU(duration=80)
42
+ def generate_video(prompt, image, predictor):
 
43
  from diffusers.utils import export_to_video
44
  from diffusers.utils import load_image
45
  import os
46
 
47
+ if image == None:
48
  return "Error: For i2v, provide image path.", "{}"
49
+ if not isinstance(prompt, str):
50
+ return "Error: No prompt.", "{}"
51
 
52
+ #if seed == -1:
53
+ random.seed(time.time())
54
+ seed = int(random.randrange(4294967294))
55
 
56
  kwargs = {
57
  "prompt": prompt,
 
66
  "cfg_for": False,
67
  }
68
 
69
+ kwargs["image"] = load_image(image=image)
70
+ output = predictor.inference(kwargs)
71
+ frames = output
72
+ save_dir = f"./result/{task_type}"
73
+ os.makedirs(save_dir, exist_ok=True)
74
+ video_out_file = f"{save_dir}/{prompt[:100]}_{int(seed)}.mp4"
75
+ print(f"Generating video: {video_out_file}")
76
+ export_to_video(frames, video_out_file, fps=24)
77
+ return video_out_file
78
+
79
+
80
+ def display_image(file):
81
+ if file is not None:
82
+ return Image.open(file.name)
83
+ else:
84
+ return None
85
+
 
 
 
 
 
 
 
 
86
  # --- Minimal Gradio Interface ---
87
  with gr.Blocks() as demo:
88
+
89
+ image_file = gr.File(label="Image Prompt (Required)", file_types=["image"])
90
+ image_file_preview = gr.Image(label="Image Prompt Preview", interactive=False)
 
91
  prompt_textbox = gr.Textbox(label="Prompt")
92
  generate_button = gr.Button("Generate")
 
93
  output_video = gr.Video(label="Output Video") # Just a textbox
94
 
95
+ image_file.change(
96
+ display_image,
97
+ inputs=[image_file],
98
+ outputs=[image_file_preview]
99
  )
100
 
101
  generate_button.click(
102
  fn=generate_video,
103
+ inputs=[prompt_textbox, image, predictor_state],
104
  outputs=[output_video],
105
  )
106