1inkusFace committed on
Commit
7e3737d
·
verified ·
1 Parent(s): eb08525

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -78
app.py CHANGED
@@ -4,30 +4,27 @@ import sys
4
  import time
5
  import os
6
  import random
7
- from PIL import Image
8
- import torch
9
- import asyncio
10
-
11
- # os.environ["CUDA_VISIBLE_DEVICES"] = "" # Uncomment if needed
12
  os.environ["SAFETENSORS_FAST_GPU"] = "1"
13
- os.putenv("HF_HUB_ENABLE_HF_TRANSFER", "1")
14
-
15
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
 
17
- # Create the gr.State component *outside* the gr.Blocks context. This is CORRECT.
18
- predictor_state = gr.State(None) # Use a consistent name
19
- task_type_state = gr.State("i2v") # Store the task type
20
 
21
  def init_predictor(task_type: str):
22
  from skyreelsinfer import TaskType
23
  from skyreelsinfer.offload import OffloadConfig
24
  from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
25
  from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError
26
-
27
  try:
28
  predictor = SkyReelsVideoInfer(
29
  task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
30
- model_id="Skywork/skyreels-v1-Hunyuan-i2v", # Or a different model for T2V
31
  quant_model=True,
32
  is_offload=True,
33
  offload_config=OffloadConfig(
@@ -38,27 +35,24 @@ def init_predictor(task_type: str):
38
  )
39
  return predictor
40
  except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
41
- print(f"Error: Model not found. Details: {e}") # Log the error
42
- return None # Return None if model loading fails
43
  except Exception as e:
44
- print(f"Error loading model: {e}") # Log the error
45
- return None # Return None if model loading fails
46
-
47
-
48
 
49
  @spaces.GPU(duration=80)
50
- async def generate_video(prompt, image_file, predictor):
51
  from diffusers.utils import export_to_video
52
  from diffusers.utils import load_image
53
- if image_file is None:
54
- return gr.Error("Error: For i2v, provide image path.") , None # Use gr.Error
55
- if not isinstance(prompt, str) or not prompt.strip(): # Check for empty prompt
56
- return gr.Error("Error: No prompt."), None # Use gr.Error
57
-
58
- #if seed == -1: #removed seed parameter
59
  random.seed(time.time())
60
  seed = int(random.randrange(4294967294))
61
-
62
  kwargs = {
63
  "prompt": prompt,
64
  "height": 256,
@@ -71,63 +65,42 @@ async def generate_video(prompt, image_file, predictor):
71
  "negative_prompt": "bad quality, blur",
72
  "cfg_for": False,
73
  }
74
- # Check if predictor is initialized and the task type is i2v.
75
-
76
- kwargs["image"] = load_image(image=image_file)
77
-
78
- if predictor is None:
79
- return gr.Error("Predictor not initialized."), None
80
-
81
- try:
82
- output = predictor.inference(kwargs)
83
- except Exception as e:
84
- return gr.Error(f"Inference error: {e}"), None
85
 
 
 
86
  frames = output
87
- save_dir = f"./result/i2v" # Use a consistent directory
88
  os.makedirs(save_dir, exist_ok=True)
89
- video_out_file = os.path.join(save_dir, f"{prompt[:100]}_{int(seed)}.mp4")
90
  print(f"Generating video: {video_out_file}")
91
- try:
92
- export_to_video(frames, video_out_file, fps=24)
93
- except Exception as e:
94
- return gr.Error(f"Video export error: {e}"), None
95
- return video_out_file, predictor # Return the updated predictor
96
-
97
-
98
def display_image(file):
    """Return a PIL image for the preview pane, or None to clear it."""
    if file is None:
        return None
    return Image.open(file)
103
-
104
async def load_model_async(task_type):
    """Asynchronously build and return the predictor for *task_type* (thin wrapper)."""
    return init_predictor(task_type)
107
-
108
async def main():
    """Build the Gradio UI, load the model into the shared state, and serve."""
    with gr.Blocks() as ui:
        # The predictor lives in the module-level predictor_state, not a local.
        image_file = gr.File(label="Image Prompt (Required)", file_types=["image"])
        image_file_preview = gr.Image(label="Image Prompt Preview", interactive=False)
        prompt_textbox = gr.Text(label="Prompt")
        generate_button = gr.Button("Generate")
        output_video = gr.Video(label="Output Video")

        # Keep the preview pane in sync with the chosen file.
        image_file.change(
            display_image,
            inputs=[image_file],
            outputs=[image_file_preview],
        )

        # generate_video consumes predictor_state and emits the updated value.
        generate_button.click(
            fn=generate_video,
            inputs=[prompt_textbox, image_file, predictor_state],
            outputs=[output_video, predictor_state],
        )

        # Load the model once up front and seed the state with it.
        predictor_state.value = await load_model_async("i2v")

    await ui.launch()


if __name__ == "__main__":
    asyncio.run(main())
 
4
import time
import os
import random
from PIL import Image

# os.environ["CUDA_VISIBLE_DEVICES"] = ""

os.environ["SAFETENSORS_FAST_GPU"] = "1"
# BUGFIX: os.putenv() changes the process environment WITHOUT updating
# os.environ, so in-process readers of os.environ (e.g. huggingface_hub)
# never see the flag.  Assigning through os.environ also calls putenv().
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import torch

# Prefer the first CUDA device; fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): gr.State is normally created *inside* a gr.Blocks context.
# This module-level State is also rebound later to init_predictor()'s return
# value, shadowing the State object — confirm the intended wiring.
predictor = gr.State(None)
17
 
18
  def init_predictor(task_type: str):
19
  from skyreelsinfer import TaskType
20
  from skyreelsinfer.offload import OffloadConfig
21
  from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
22
  from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError
23
+ global predictor
24
  try:
25
  predictor = SkyReelsVideoInfer(
26
  task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
27
+ model_id="Skywork/skyreels-v1-Hunyuan-i2v",
28
  quant_model=True,
29
  is_offload=True,
30
  offload_config=OffloadConfig(
 
35
  )
36
  return predictor
37
  except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
38
+ return f"Error: Model not found. Details: {e}", None
 
39
  except Exception as e:
40
+ return f"Error loading model: {e}", None
41
+
42
# Eagerly build the predictor once at import time so the first request does
# not pay the model-load cost.  The former module-level ``global predictor``
# statement was removed: ``global`` is meaningless at module scope (names
# assigned at module level are already global).
# NOTE(review): this rebinds the module-level `predictor` (previously a
# gr.State) to the inference object — confirm downstream uses expect that.
predictor = init_predictor('i2v')
44
 
45
  @spaces.GPU(duration=80)
46
+ def generate_video(prompt, image, predictor):
47
  from diffusers.utils import export_to_video
48
  from diffusers.utils import load_image
49
+ if image == None:
50
+ return "Error: For i2v, provide image path.", "{}"
51
+ if not isinstance(prompt, str):
52
+ return "Error: No prompt.", "{}"
53
+ #if seed == -1:
 
54
  random.seed(time.time())
55
  seed = int(random.randrange(4294967294))
 
56
  kwargs = {
57
  "prompt": prompt,
58
  "height": 256,
 
65
  "negative_prompt": "bad quality, blur",
66
  "cfg_for": False,
67
  }
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ kwargs["image"] = load_image(image=image)
70
+ output = predictor.inference(kwargs)
71
  frames = output
72
+ save_dir = f"./result/{task_type}"
73
  os.makedirs(save_dir, exist_ok=True)
74
+ video_out_file = f"{save_dir}/{prompt[:100]}_{int(seed)}.mp4"
75
  print(f"Generating video: {video_out_file}")
76
+ export_to_video(frames, video_out_file, fps=24)
77
+ return video_out_file
78
+
 
 
 
 
79
def display_image(file):
    """Preview callback: open the uploaded image, or clear the preview.

    ``gr.File`` hands back a tempfile-like wrapper (with ``.name``) in older
    Gradio releases but a plain filepath string in newer ones; ``file.name``
    alone raises AttributeError on a str, so accept both forms.

    Args:
        file: Uploaded file object or path, or None when the input is cleared.

    Returns:
        A PIL.Image for the preview component, or None when no file is set.
    """
    if file is None:
        return None
    # Fall back to the object itself when it has no ``.name`` attribute.
    return Image.open(getattr(file, "name", file))
84
+
85
with gr.Blocks() as demo:
    # (A per-session gr.State was previously created here; the module-level
    # `predictor` is used instead.)

    image_file = gr.File(label="Image Prompt (Required)", file_types=["image"])
    image_file_preview = gr.Image(label="Image Prompt Preview", interactive=False)
    prompt_textbox = gr.Text(label="Prompt")
    generate_button = gr.Button("Generate")
    output_video = gr.Video(label="Output Video")

    # Mirror the uploaded file into the preview component.
    image_file.change(
        display_image,
        inputs=[image_file],
        outputs=[image_file_preview],
    )

    # NOTE(review): `predictor` here is a module-level object, not a Gradio
    # component — confirm the installed Gradio accepts it as a click input.
    generate_button.click(
        fn=generate_video,
        inputs=[prompt_textbox, image_file, predictor],
        outputs=[output_video],
    )

demo.launch()