Lemonator committed (verified)
Commit 5e15674 · Parent(s): 1f619e0

Update app_lora.py

Files changed (1): app_lora.py (+66 −78)

app_lora.py CHANGED
@@ -7,40 +7,65 @@ import tempfile
import os
import subprocess

+ # The spaces library IS required for ZeroGPU.
+ import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random

import warnings
- warnings.filterwarnings("ignore", message=".*Attempting to use legacy OpenCV backend.*")
- warnings.filterwarnings("ignore", message=".*num_frames - 1.*")
+ warnings.filterwarnings("ignore")

MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
LORA_FILENAME = "FusionX_LoRa/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"

- # Global variable to hold the pipeline. It's initialized to None.
+ # --- Global variable for the pipeline ---
+ # We use a global variable to cache the model between calls.
pipe = None

- def initialize_pipeline():
+ # --- Constants and Helper Functions ---
+ MOD_VALUE = 32
+ DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE = 640, 1024
+ NEW_FORMULA_MAX_AREA = 640.0 * 1024.0
+ SLIDER_MIN_H, SLIDER_MAX_H = 128, 1024
+ SLIDER_MIN_W, SLIDER_MAX_W = 128, 1024
+ MAX_SEED = np.iinfo(np.int32).max
+ FIXED_FPS, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 24, 8, 240
+ default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
+ default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
+
+
+ def get_duration(duration_seconds):
    """
-     Initializes the model pipeline on the first request.
-     This function is designed for serverless GPU environments like ZeroGPU.
+     Dynamically set the timeout for the @spaces.GPU decorator based on video length.
    """
+     if duration_seconds > 7: return 180
+     if duration_seconds > 5: return 120
+     if duration_seconds > 3: return 90
+     return 60
+
+ # --- The Main GPU Function ---
+ # The @spaces.GPU decorator is ESSENTIAL for ZeroGPU.
+ # It tells the platform that this function needs a GPU.
+ @spaces.GPU(duration=60) # Default duration, can be updated dynamically
+ def generate_video(input_image, prompt, height, width,
+                    negative_prompt, duration_seconds,
+                    guidance_scale, steps, seed, randomize_seed,
+                    progress=gr.Progress(track_tqdm=True)):
+
    global pipe
-     # The 'pipe' global variable acts as a flag. If it's not None, we've already initialized.
+
+     # --- LAZY LOADING of the model ---
+     # This block will only run on the very first generation request.
    if pipe is None:
-         print("First time setup: Initializing model pipeline...")
-         gr.Info("Cold start: The first generation will take longer as the model is loaded.")
-
-         if not torch.cuda.is_available():
-             raise gr.Error("GPU not available. This application requires a GPU to run.")
-
+         progress(0, desc="Cold start: Initializing model...")
+         print("Cold start: Initializing model pipeline...")
+
        image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float16)
        vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float16)

-         # All model loading happens here, when a GPU is guaranteed to be active.
        pipe = WanImageToVideoPipeline.from_pretrained(
            MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
        )
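
A note on the `@spaces.GPU(duration=60)` line added above: the `spaces` package also accepts a callable for `duration`, which (per the ZeroGPU docs) is invoked per request with the same arguments as the decorated function, so the GPU budget can be resolved before the GPU is attached rather than updated afterwards. A minimal sketch of that pattern, reusing the commit's own thresholds; the widened `get_duration` signature is only there because the callable must accept the decorated function's full argument list:

import spaces
import gradio as gr

def get_duration(input_image, prompt, height, width,
                 negative_prompt, duration_seconds,
                 guidance_scale, steps, seed, randomize_seed,
                 progress):
    # Same thresholds as the hunk above: longer clips get a bigger GPU budget.
    if duration_seconds > 7: return 180
    if duration_seconds > 5: return 120
    if duration_seconds > 3: return 90
    return 60

@spaces.GPU(duration=get_duration)  # budget resolved per call, before GPU attach
def generate_video(input_image, prompt, height, width,
                   negative_prompt, duration_seconds,
                   guidance_scale, steps, seed, randomize_seed,
                   progress=gr.Progress(track_tqdm=True)):
    ...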
@@ -49,74 +74,17 @@ def initialize_pipeline():

        try:
            causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
-             print("✅ LoRA downloaded to:", causvid_path)
            pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
            pipe.set_adapters(["causvid_lora"], adapter_weights=[0.75])
            pipe.fuse_lora()
+             print("✅ LoRA loaded successfully.")
        except Exception as e:
            raise gr.Error(f"Error loading LoRA: {e}")

        print("✅ Pipeline initialized successfully.")

- # --- Constants and Helper Functions ---
- # (These are unchanged)
- MOD_VALUE = 32
- DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE = 640, 1024
- NEW_FORMULA_MAX_AREA = 640.0 * 1024.0
- SLIDER_MIN_H, SLIDER_MAX_H = 128, 1024
- SLIDER_MIN_W, SLIDER_MAX_W = 128, 1024
- MAX_SEED = np.iinfo(np.int32).max
- FIXED_FPS, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 24, 8, 240
- default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
- default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
-
- def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
-                                   min_slider_h, max_slider_h, min_slider_w, max_slider_w,
-                                   default_h, default_w):
-     orig_w, orig_h = pil_image.size
-     if orig_w <= 0 or orig_h <= 0: return default_h, default_w
-     aspect_ratio = orig_h / orig_w
-     calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
-     calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
-     calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
-     calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
-     new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
-     new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
-     return new_h, new_w
-
- def handle_image_upload_for_dims_wan(uploaded_pil_image):
-     if uploaded_pil_image is None:
-         return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-     try:
-         new_h, new_w = _calculate_new_dimensions_wan(
-             uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
-             SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
-             DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
-         )
-         return gr.update(value=new_h), gr.update(value=new_w)
-     except Exception as e:
-         gr.Warning("Error calculating new dimensions.")
-         return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-
- def export_video_with_ffmpeg(frames, output_path, fps=24):
-     try:
-         import imageio
-         writer = imageio.get_writer(output_path, fps=fps, codec='libx264',
-                                     pixelformat='yuv420p', quality=8)
-         for frame in frames:
-             writer.append_data(np.array(frame))
-         writer.close()
-     except ImportError:
-         export_to_video(frames, output_path, fps=fps)
-
- def generate_video(input_image, prompt, height, width,
-                    negative_prompt, duration_seconds,
-                    guidance_scale, steps, seed, randomize_seed,
-                    progress=gr.Progress(track_tqdm=True)):
-
-     # --- LAZY LOADING TRIGGER ---
-     # This will load the model on the first run, and do nothing on subsequent runs.
-     initialize_pipeline()
+     # Update the GPU duration based on user input for longer videos.
+     spaces.set_timeout(get_duration(duration_seconds))

    if input_image is None:
        raise gr.Error("Please upload an input image.")
@@ -143,7 +111,8 @@ def generate_video(input_image, prompt, height, width,
            image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
            height=target_h, width=target_w, num_frames=num_frames,
            guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
-             generator=torch.Generator(device="cuda").manual_seed(current_seed)
+             generator=torch.Generator(device="cuda").manual_seed(current_seed),
+             callback_on_step_end=lambda p, s, t: progress(s/int(steps))
        ).frames[0]
    except torch.cuda.OutOfMemoryError:
        raise gr.Error("Out of GPU memory. Try reducing duration or resolution.")
@@ -152,11 +121,31 @@

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
-     export_video_with_ffmpeg(output_frames_list, video_path, fps=FIXED_FPS)
+     # (Video export logic is unchanged)
+     import imageio
+     writer = imageio.get_writer(video_path, fps=FIXED_FPS, codec='libx264', pixelformat='yuv420p', quality=8)
+     for frame in output_frames_list:
+         writer.append_data(np.array(frame))
+     writer.close()

    return video_path, current_seed

# --- Gradio UI ---
+ # (Helper functions for UI are unchanged)
+ def handle_image_upload_for_dims_wan(uploaded_pil_image):
+     if uploaded_pil_image is None: return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+     try:
+         orig_w, orig_h = uploaded_pil_image.size
+         aspect_ratio = orig_h / orig_w
+         calc_h = round(np.sqrt(NEW_FORMULA_MAX_AREA * aspect_ratio))
+         calc_w = round(np.sqrt(NEW_FORMULA_MAX_AREA / aspect_ratio))
+         calc_h = max(MOD_VALUE, (calc_h // MOD_VALUE) * MOD_VALUE)
+         calc_w = max(MOD_VALUE, (calc_w // MOD_VALUE) * MOD_VALUE)
+         new_h = int(np.clip(calc_h, SLIDER_MIN_H, SLIDER_MAX_H))
+         new_w = int(np.clip(calc_w, SLIDER_MIN_W, SLIDER_MAX_W))
+         return gr.update(value=new_h), gr.update(value=new_w)
+     except: return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+
with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.1 I2V FusionX-LoRA (ZeroGPU Ready)")
    gr.Markdown("The first generation will be slow due to a 'cold start'. Subsequent generations will be much faster.")
@@ -186,5 +175,4 @@ with gr.Blocks() as demo:
    generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])

if __name__ == "__main__":
-     # We launch the demo unconditionally now. The GPU check is deferred until the first click.
-     demo.queue(max_size=3).launch()
+     demo.queue().launch()
 
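
A final note on the export block in the fourth hunk: `imageio.get_writer(..., codec='libx264', pixelformat='yuv420p')` needs the imageio-ffmpeg backend installed, and the commit drops the old `ImportError` fallback to diffusers' `export_to_video`. A sketch restoring that guard, mirroring the deleted `export_video_with_ffmpeg` helper:

import numpy as np

def export_video(frames, output_path, fps=24):
    """Write frames as an H.264 MP4, falling back to diffusers if imageio is missing."""
    try:
        import imageio  # libx264 writing requires the imageio-ffmpeg plugin
        writer = imageio.get_writer(output_path, fps=fps, codec='libx264',
                                    pixelformat='yuv420p', quality=8)
        for frame in frames:
            writer.append_data(np.array(frame))
        writer.close()
    except ImportError:
        from diffusers.utils import export_to_video
        export_to_video(frames, output_path, fps=fps)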