Lemonator committed
Commit e9b5918 · verified · 1 Parent(s): 3ee8e25

Update app_lora.py

Files changed (1)
  1. app_lora.py +58 -100
app_lora.py CHANGED
@@ -14,180 +14,138 @@ from PIL import Image
  import random

  import warnings
- warnings.filterwarnings("ignore", message=".*Attempting to use legacy OpenCV backend.*")
- warnings.filterwarnings("ignore", message=".*num_frames - 1.*")
+ warnings.filterwarnings("ignore")

  MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
-
  LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
  LORA_FILENAME = "FusionX_LoRa/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"

  # --- Model Loading at Startup ---
- # This is the correct pattern for your environment. The model is loaded once
- # when the Space starts, leading to a longer build but a fast experience for users.
  image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float16)
  vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float16)
  pipe = WanImageToVideoPipeline.from_pretrained(
      MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
  )
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
-
- # Enable memory efficient attention and CPU offloading
  pipe.enable_model_cpu_offload()

- # THE FIX: These two lines caused the original error and have been removed.
- # pipe.enable_vae_slicing()
- # pipe.enable_vae_tiling()
-
+ # LoRA Loading
  try:
      causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
      print("✅ LoRA downloaded to:", causvid_path)
-
      pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
      pipe.set_adapters(["causvid_lora"], adapter_weights=[0.75])
      pipe.fuse_lora()
-
  except Exception as e:
-     import traceback
-     print("❌ Error during LoRA loading:")
-     traceback.print_exc()
+     print(f"❌ Error during LoRA loading: {e}")

- # --- Constants and Helper Functions ---
+ # --- Constants ---
  MOD_VALUE = 32
- DEFAULT_H_SLIDER_VALUE = 640
- DEFAULT_W_SLIDER_VALUE = 1024
- NEW_FORMULA_MAX_AREA = 640.0 * 1024.0
-
+ DEFAULT_H, DEFAULT_W = 640, 1024
+ MAX_AREA = DEFAULT_H * DEFAULT_W
  SLIDER_MIN_H, SLIDER_MAX_H = 128, 1024
  SLIDER_MIN_W, SLIDER_MAX_W = 128, 1024
  MAX_SEED = np.iinfo(np.int32).max
-
- FIXED_FPS = 24
- MIN_FRAMES_MODEL = 8
- MAX_FRAMES_MODEL = 240
-
- default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
- default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
-
-
- def get_duration(duration_seconds):
-     if duration_seconds > 7: return 180
-     if duration_seconds > 5: return 120
-     if duration_seconds > 3: return 90
-     return 60
-
- # --- The Main Generation Function ---
- # The @spaces.GPU decorator is correctly placed here.
- @spaces.GPU(duration=60)
+ FIXED_FPS, MIN_FRAMES, MAX_FRAMES = 24, 8, 240
+ default_prompt = "make this image come alive, cinematic motion, smooth animation"
+ default_neg_prompt = "static, blurry, watermark, text, signature, ugly, deformed"
+
+ # --- Main Generation Function ---
+ # THE FIX: Set a generous, FIXED duration for the decorator. 180 seconds (3 minutes)
+ # should be enough for the longest video generation.
+ @spaces.GPU(duration=180)
  def generate_video(input_image, prompt, height, width,
-                    negative_prompt=default_negative_prompt, duration_seconds=2,
-                    guidance_scale=1, steps=4,
-                    seed=42, randomize_seed=False,
+                    negative_prompt, duration_seconds,
+                    guidance_scale, steps, seed, randomize_seed,
                     progress=gr.Progress(track_tqdm=True)):

-     spaces.set_timeout(get_duration(duration_seconds))
-
      if input_image is None:
          raise gr.Error("Please upload an input image.")

      target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
      target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
-
      raw_frames = int(round(duration_seconds * FIXED_FPS))
      num_frames = ((raw_frames - 1) // 4) * 4 + 1
-     num_frames = np.clip(num_frames, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+     num_frames = np.clip(num_frames, MIN_FRAMES, MAX_FRAMES)

      if num_frames > 120 and max(target_h, target_w) > 768:
-         scale_factor = 768 / max(target_h, target_w)
-         target_h = max(MOD_VALUE, int(target_h * scale_factor) // MOD_VALUE * MOD_VALUE)
-         target_w = max(MOD_VALUE, int(target_w * scale_factor) // MOD_VALUE * MOD_VALUE)
-         gr.Info(f"Reduced resolution to {target_w}x{target_h} for long video generation")
+         scale = 768 / max(target_h, target_w)
+         target_h = max(MOD_VALUE, int(target_h * scale) // MOD_VALUE * MOD_VALUE)
+         target_w = max(MOD_VALUE, int(target_w * scale) // MOD_VALUE * MOD_VALUE)
+         gr.Info(f"Reduced resolution to {target_w}x{target_h} for long video.")

      current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
      resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)

-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-
      try:
+         torch.cuda.empty_cache()
          with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
-             output_frames_list = pipe(
+             frames = pipe(
                  image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
                  height=target_h, width=target_w, num_frames=num_frames,
                  guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
                  generator=torch.Generator(device="cuda").manual_seed(current_seed),
                  return_dict=True
              ).frames[0]
-     except torch.cuda.OutOfMemoryError:
-         raise gr.Error("Out of GPU memory. Try reducing the duration or resolution.")
+     except torch.cuda.OutOfMemoryError as e:
+         raise gr.Error("Out of GPU memory. Try reducing duration or resolution.")
      except Exception as e:
-         raise gr.Error(f"Generation failed: {str(e)}")
+         raise gr.Error(f"Generation failed: {e}")
      finally:
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
+         torch.cuda.empty_cache()

      with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
          video_path = tmpfile.name
      import imageio
      writer = imageio.get_writer(video_path, fps=FIXED_FPS, codec='libx264',
                                  pixelformat='yuv420p', quality=8)
-     for frame in output_frames_list:
+     for frame in frames:
          writer.append_data(np.array(frame))
      writer.close()
-
+
      return video_path, current_seed

  # --- Gradio UI ---
  with gr.Blocks() as demo:
-     gr.Markdown("# Fast 4 steps Wan 2.1 I2V (14B) FusionX-LoRA")
+     gr.Markdown("# Wan 2.1 I2V FusionX-LoRA")

      with gr.Row():
          with gr.Column():
-             input_image_component = gr.Image(type="pil", label="Input Image")
-             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
-             duration_seconds_input = gr.Slider(minimum=0.3, maximum=10.0, step=0.1, value=2, label="Duration (seconds)")
+             input_image_comp = gr.Image(type="pil", label="Input Image")
+             prompt_comp = gr.Textbox(label="Prompt", value=default_prompt)
+             duration_comp = gr.Slider(minimum=round(MIN_FRAMES/FIXED_FPS, 1), maximum=round(MAX_FRAMES/FIXED_FPS, 1), step=0.1, value=2, label="Duration (s)")
              with gr.Accordion("Advanced Settings", open=False):
-                 negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
-                 seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
-                 randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
+                 neg_prompt_comp = gr.Textbox(label="Negative Prompt", value=default_neg_prompt, lines=3)
+                 seed_comp = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
+                 rand_seed_comp = gr.Checkbox(label="Randomize seed", value=True)
                  with gr.Row():
-                     height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label="Height")
-                     width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label="Width")
-                 steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
-                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale", visible=False)
-             generate_button = gr.Button("Generate Video", variant="primary")
+                     height_comp = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H, label="Height")
+                     width_comp = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W, label="Width")
+                 steps_comp = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Steps")
+                 guidance_comp = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="CFG Scale", visible=False)
+             gen_button = gr.Button("Generate Video", variant="primary")
          with gr.Column():
-             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
-             gr.Markdown("### Tips:\n- For videos > 5s, consider lower resolutions.\n- 4-8 steps is often optimal.")
+             video_comp = gr.Video(label="Generated Video", autoplay=True, interactive=False)
+             gr.Markdown("### Tips:\n- For long videos (>5s), consider lower resolutions.\n- 4-8 steps is often optimal.")

-     def handle_image_upload(img):
-         if img is None: return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+     def handle_upload(img):
+         if img is None: return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)
          try:
              w, h = img.size
-             aspect = h / w
-             calc_h = round(np.sqrt(NEW_FORMULA_MAX_AREA * aspect))
-             calc_w = round(np.sqrt(NEW_FORMULA_MAX_AREA / aspect))
-             new_h = int(np.clip((calc_h // MOD_VALUE) * MOD_VALUE, SLIDER_MIN_H, SLIDER_MAX_H))
-             new_w = int(np.clip((calc_w // MOD_VALUE) * MOD_VALUE, SLIDER_MIN_W, SLIDER_MAX_W))
-             return gr.update(value=new_h), gr.update(value=new_w)
-         except: return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-
-     input_image_component.upload(handle_image_upload, inputs=input_image_component, outputs=[height_input, width_input])
+             a = h / w
+             h_new = int(np.sqrt(MAX_AREA * a))
+             w_new = int(np.sqrt(MAX_AREA / a))
+             h_final = max(MOD_VALUE, h_new // MOD_VALUE * MOD_VALUE)
+             w_final = max(MOD_VALUE, w_new // MOD_VALUE * MOD_VALUE)
+             return gr.update(value=h_final), gr.update(value=w_final)
+         except Exception:
+             return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)
+
+     input_image_comp.upload(handle_upload, inputs=input_image_comp, outputs=[height_comp, width_comp])

-     ui_inputs = [input_image_component, prompt_input, height_input, width_input, negative_prompt_input, duration_seconds_input, guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox]
-     generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
-
-     # The gr.Examples requires the files to be in your repo. Commenting out to prevent errors.
-     # gr.Examples(
-     #     examples=[
-     #         ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
-     #         ["forg.jpg", "the frog jumps around", 448, 832],
-     #     ],
-     #     inputs=[input_image_component, prompt_input, height_input, width_input],
-     #     outputs=[video_output, seed_input],
-     #     fn=generate_video,
-     #     cache_examples="lazy"
-     # )
+     inputs = [input_image_comp, prompt_comp, height_comp, width_comp, neg_prompt_comp, duration_comp, guidance_comp, steps_comp, seed_comp, rand_seed_comp]
+     outputs = [video_comp, seed_comp]
+     gen_button.click(fn=generate_video, inputs=inputs, outputs=outputs)

  if __name__ == "__main__":
      demo.queue(max_size=3).launch()
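
Note on the frame-count arithmetic shared by both versions of generate_video: the requested duration is converted to raw frames at FIXED_FPS, snapped down to the nearest value of the form 4k + 1, and clipped to the supported range (8 to 240 frames with the new constants). A minimal standalone sketch of that calculation follows; frames_for_duration is an illustrative helper name and is not defined in app_lora.py.

import numpy as np

FIXED_FPS, MIN_FRAMES, MAX_FRAMES = 24, 8, 240  # constants from the committed file

def frames_for_duration(duration_seconds: float) -> int:
    # Hypothetical helper mirroring the rounding done inside generate_video.
    raw_frames = int(round(duration_seconds * FIXED_FPS))
    # Snap down to the nearest frame count of the form 4k + 1, as the app does.
    num_frames = ((raw_frames - 1) // 4) * 4 + 1
    return int(np.clip(num_frames, MIN_FRAMES, MAX_FRAMES))

# The default 2 s request: 48 raw frames -> 45 frames (11 * 4 + 1).
assert frames_for_duration(2) == 45
# The 10 s slider maximum: 240 raw frames -> 237 frames.
assert frames_for_duration(10) == 237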