1inkusFace committed
Commit eca6bdb · verified · 1 Parent(s): 7e82a5e

Update app.py

Files changed (1): app.py (+93, -74)

app.py CHANGED
@@ -1,19 +1,20 @@
 import spaces
 
 import gradio as gr
-import argparse
+import argparse  # Kept for now, though Spaces passes no CLI arguments
 import sys
 import time
 import os
 import random
-#sys.path.append("..")
+# IMPORTANT: add the SkyReels-V1 root directory to the Python path.
+# This assumes app.py sits at the root of the cloned/forked repo.
+sys.path.append(".")  # Correct path for a Hugging Face Space
 from skyreelsinfer import TaskType
 from skyreelsinfer.offload import OffloadConfig
 from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
 from diffusers.utils import export_to_video
 from diffusers.utils import load_image
-
-import torch
+import torch
 
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
@@ -24,44 +25,57 @@ torch.backends.cudnn.benchmark = False
 torch.set_float32_matmul_precision("highest")
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-predictor = None
-task_type = None
+# --- Model Loading ---
+predictor = None  # Global predictor; loaded lazily inside init_predictor
 
-def get_transformer_model_id(task_type:str) -> str:
+def get_transformer_model_id(task_type: str) -> str:
     return "Skywork/SkyReels-V1-Hunyuan-I2V" if task_type == "i2v" else "Skywork/SkyReels-V1-Hunyuan-T2V"
 
-@spaces.GPU()
-def init_predictor(task_type:str, gpu_num:int=1):
+@spaces.GPU(duration=90)
+def init_predictor(task_type: str):
     global predictor
-    predictor = SkyReelsVideoInfer(
-        task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
-        model_id=get_transformer_model_id(task_type),
-        quant_model=True,
-        world_size=gpu_num,
-        is_offload=True,
-        offload_config=OffloadConfig(
-            high_cpu_memory=True,
-            parameters_level=True,
-            compiler_transformer=False,
+    try:
+        predictor = SkyReelsVideoInfer(
+            task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
+            model_id=get_transformer_model_id(task_type),
+            quant_model=True,  # Keep quantization for a smaller memory footprint
+            world_size=1,      # Single-process inference; no multi-GPU in this Space
+            is_offload=True,   # Offload weights to CPU between GPU calls
+            offload_config=OffloadConfig(
+                high_cpu_memory=True,
+                parameters_level=True,
+                compiler_transformer=False,  # Consider setting to True if compatible
+            )
         )
-    )
-
+        # Move the pipeline to CPU after loading; the GPU is only attached
+        # while a @spaces.GPU-decorated call is running.
+        if hasattr(predictor, "pipe") and hasattr(predictor.pipe, "to"):
+            predictor.pipe.to("cpu")
+        return "Model loaded successfully!"
+    except Exception as e:
+        return f"Error loading model: {e}"
+
 @spaces.GPU(duration=90)
-def generate_video(prompt, seed, image=None):
-    global task_type
-    print(f"image:{type(image)}")
+def generate_video(prompt, seed, image=None, task_type=None):
+    global predictor
+
+    # Input type validation
+    if task_type == "i2v" and not isinstance(image, str):
+        return "Error: For i2v, please provide a valid image file path.", "{}"
+    if not isinstance(prompt, str) or not isinstance(seed, (int, float)):
+        return "Error: Invalid input types for prompt or seed.", "{}"
+
 
     if seed == -1:
         random.seed(time.time())
         seed = int(random.randrange(4294967294))
-
+
     kwargs = {
         "prompt": prompt,
-        "height": 512,
-        "width": 512,
-        "num_frames": 97,
-        "num_inference_steps": 30,
-        "seed": seed,
+        "height": 512,              # Consider reducing for faster processing
+        "width": 512,               # Consider reducing for faster processing
+        "num_frames": 97,           # Consider reducing for faster processing
+        "num_inference_steps": 30,  # Consider reducing for faster processing
+        "seed": int(seed),          # Make sure the seed is an int
        "guidance_scale": 6.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
@@ -69,57 +83,62 @@ def generate_video(prompt, seed, image=None):
     }
 
     if task_type == "i2v":
-        assert image is not None, "please input image"
-        kwargs["image"] = load_image(image=image)
-    global predictor
-    output = predictor.inference(kwargs)
-    save_dir = f"./result/{task_type}"
-    os.makedirs(save_dir, exist_ok=True)
-    video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{seed}.mp4"
-    print(f"generate video, local path: {video_out_file}")
-    export_to_video(output, video_out_file, fps=24)
-    return video_out_file, kwargs
+        if image is None or not os.path.exists(image):
+            return "Error: Image not provided or not found.", "{}"
+        try:
+            kwargs["image"] = load_image(image=image)
+        except Exception as e:
+            return f"Error loading image: {e}", "{}"
 
+    try:
+        # Ensure the predictor is loaded
+        if predictor is None:
+            return "Error: Model not initialized. Please reload the Space.", "{}"
 
-def create_gradio_interface(task_type):
-    """Create a Gradio interface based on the task type."""
-    if task_type == "i2v":
-        with gr.Blocks() as demo:
-            with gr.Row():
-                image = gr.Image(label="Upload Image", type="filepath")
-                prompt = gr.Textbox(label="Input Prompt")
-                seed = gr.Number(label="Random Seed", value=-1)
-                submit_button = gr.Button("Generate Video")
-                output_video = gr.Video(label="Generated Video")
-                output_params = gr.Textbox(label="Output Parameters")
+        output = predictor.inference(kwargs)
+        save_dir = f"./result/{task_type}"
+        os.makedirs(save_dir, exist_ok=True)
+        video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{int(seed)}.mp4"  # Ensure the seed is an integer
+        print(f"Generating video, local path: {video_out_file}")
+        export_to_video(output, video_out_file, fps=24)
+        return video_out_file, str(kwargs)  # Return kwargs as a string
 
-            # Submit button logic
-            submit_button.click(
-                fn=generate_video,
-                inputs=[prompt, seed, image],
-                outputs=[output_video, output_params],
-            )
+    except Exception as e:
+        return f"Error during video generation: {e}", "{}"
 
-    elif task_type == "t2v":
-        with gr.Blocks() as demo:
-            with gr.Row():
-                prompt = gr.Textbox(label="Input Prompt")
-                seed = gr.Number(label="Random Seed", value=-1)
+# --- Gradio Interface ---
+# A single interface handles BOTH i2v and t2v.
+with gr.Blocks() as demo:
+    with gr.Row():
+        task_type_dropdown = gr.Dropdown(
+            choices=["i2v", "t2v"], label="Task Type", value="t2v"
+        )  # Default to t2v
+        load_model_button = gr.Button("Load Model")
+        model_status = gr.Textbox(label="Model Status")
+    with gr.Row():
+        with gr.Column():  # Use columns for a clearer layout
+            prompt = gr.Textbox(label="Input Prompt")
+            seed = gr.Number(label="Random Seed", value=-1)
+            image = gr.Image(label="Upload Image (for i2v)", type="filepath")
             submit_button = gr.Button("Generate Video")
+        with gr.Column():
             output_video = gr.Video(label="Generated Video")
             output_params = gr.Textbox(label="Output Parameters")
 
-            # Submit button logic
-            submit_button.click(
-                fn=generate_video,
-                inputs=[prompt, seed],
-                outputs=[output_video, output_params],  # Pass task_type as additional input
-            )
+    # Load-model button logic
+    load_model_button.click(
+        fn=init_predictor,
+        inputs=[task_type_dropdown],
+        outputs=[model_status],
+    )
 
-    return demo
+    # Submit button logic (handles both i2v and t2v)
+    submit_button.click(
+        fn=generate_video,
+        inputs=[prompt, seed, image, task_type_dropdown],  # Include task_type
+        outputs=[output_video, output_params],
+    )
 
-if __name__ == "__main__":
-    # Parse command-line arguments
-    init_predictor(task_type="i2v", gpu_num=1)
-    demo = create_gradio_interface("i2v")
-    demo.launch()
+# --- Launch the App ---
+# argparse is not needed in app.py on Hugging Face Spaces.
+# demo.launch()  # Not called here; the Spaces runtime serves the top-level `demo`.
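Note on launching: per the final comment in the diff, this commit relies on the Gradio Spaces runtime serving the top-level `demo` object rather than an explicit launch call. For running the same app.py locally, a guard like the following sketch is the usual pattern; it is illustrative only and not part of this commit:

# Local-testing sketch (not in this commit): on a Hugging Face Space the
# runtime serves `demo` directly, so this guard only matters when running
# `python app.py` yourself.
if __name__ == "__main__":
    demo.launch()

Keeping the launch behind the __main__ guard lets the Spaces import path and a local run share one file without double-launching the app.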