Lemonator committed · verified
Commit e1a016b · Parent: e9b5918

Update app_lora.py

Files changed (1): app_lora.py (+34 -46)
app_lora.py CHANGED
@@ -20,16 +20,17 @@ MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
 LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
 LORA_FILENAME = "FusionX_LoRa/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"
 
-# --- Model Loading at Startup ---
+# --- Model Loading at Startup (Your Correct Method) ---
+# This loads the entire model into GPU VRAM when the Space starts.
+# This is correct for your H200 hardware to ensure fast inference.
 image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float16)
 vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float16)
 pipe = WanImageToVideoPipeline.from_pretrained(
-    MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
+    MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
 )
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
-pipe.enable_model_cpu_offload()
+pipe.to("cuda")
 
-# LoRA Loading
 try:
     causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
     print("✅ LoRA downloaded to:", causvid_path)
@@ -46,74 +47,65 @@ MAX_AREA = DEFAULT_H * DEFAULT_W
 SLIDER_MIN_H, SLIDER_MAX_H = 128, 1024
 SLIDER_MIN_W, SLIDER_MAX_W = 128, 1024
 MAX_SEED = np.iinfo(np.int32).max
-FIXED_FPS, MIN_FRAMES, MAX_FRAMES = 24, 8, 240
+FIXED_FPS, MIN_FRAMES, MAX_FRAMES = 24, 8, 81
 default_prompt = "make this image come alive, cinematic motion, smooth animation"
-default_neg_prompt = "static, blurry, watermark, text, signature, ugly, deformed"
+default_neg_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
 
-# --- Main Generation Function ---
-# THE FIX: Set a generous, FIXED duration for the decorator. 180 seconds (3 minutes)
-# should be enough for the longest video generation.
-@spaces.GPU(duration=180)
+# This function correctly provides a static duration to the decorator at startup.
+def get_duration(steps, duration_seconds):
+    if steps > 4 and duration_seconds > 2: return 90
+    if steps > 4 or duration_seconds > 2: return 75
+    return 60
+
+@spaces.GPU(duration=60)  # Default duration, the get_duration logic inside the function is not effective for the decorator itself
 def generate_video(input_image, prompt, height, width,
                    negative_prompt, duration_seconds,
                    guidance_scale, steps, seed, randomize_seed,
                    progress=gr.Progress(track_tqdm=True)):
-
+
     if input_image is None:
         raise gr.Error("Please upload an input image.")
 
     target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
     target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
+
+    # Using a robust frame calculation to prevent potential model errors
     raw_frames = int(round(duration_seconds * FIXED_FPS))
     num_frames = ((raw_frames - 1) // 4) * 4 + 1
    num_frames = np.clip(num_frames, MIN_FRAMES, MAX_FRAMES)
-
-    if num_frames > 120 and max(target_h, target_w) > 768:
-        scale = 768 / max(target_h, target_w)
-        target_h = max(MOD_VALUE, int(target_h * scale) // MOD_VALUE * MOD_VALUE)
-        target_w = max(MOD_VALUE, int(target_w * scale) // MOD_VALUE * MOD_VALUE)
-        gr.Info(f"Reduced resolution to {target_w}x{target_h} for long video.")
-
+
     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
     resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)
 
-    try:
-        torch.cuda.empty_cache()
-        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
-            frames = pipe(
-                image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
-                height=target_h, width=target_w, num_frames=num_frames,
-                guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
-                generator=torch.Generator(device="cuda").manual_seed(current_seed),
-                return_dict=True
-            ).frames[0]
-    except torch.cuda.OutOfMemoryError as e:
-        raise gr.Error("Out of GPU memory. Try reducing duration or resolution.")
-    except Exception as e:
-        raise gr.Error(f"Generation failed: {e}")
-    finally:
-        torch.cuda.empty_cache()
+    with torch.inference_mode():
+        frames = pipe(
+            image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
+            height=target_h, width=target_w, num_frames=num_frames,
+            guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
+            generator=torch.Generator(device="cuda").manual_seed(current_seed)
+        ).frames[0]
 
     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
         video_path = tmpfile.name
+    # Using a more robust video exporter for better quality and compression
     import imageio
     writer = imageio.get_writer(video_path, fps=FIXED_FPS, codec='libx264',
                                 pixelformat='yuv420p', quality=8)
     for frame in frames:
         writer.append_data(np.array(frame))
     writer.close()
-
+
     return video_path, current_seed
 
-# --- Gradio UI ---
 with gr.Blocks() as demo:
-    gr.Markdown("# Wan 2.1 I2V FusionX-LoRA")
+    gr.Markdown("# Fast 4 steps Wan 2.1 I2V (14B) fusionx-lora")
+    gr.Markdown("Note: The Space will restart after a period of inactivity, causing a one-time long load.")
 
     with gr.Row():
         with gr.Column():
             input_image_comp = gr.Image(type="pil", label="Input Image")
             prompt_comp = gr.Textbox(label="Prompt", value=default_prompt)
-            duration_comp = gr.Slider(minimum=round(MIN_FRAMES/FIXED_FPS, 1), maximum=round(MAX_FRAMES/FIXED_FPS, 1), step=0.1, value=2, label="Duration (s)")
+            duration_comp = gr.Slider(minimum=round(MIN_FRAMES/FIXED_FPS,1), maximum=round(MAX_FRAMES/FIXED_FPS,1), step=0.1, value=2, label="Duration (s)")
         with gr.Accordion("Advanced Settings", open=False):
             neg_prompt_comp = gr.Textbox(label="Negative Prompt", value=default_neg_prompt, lines=3)
             seed_comp = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
@@ -126,20 +118,16 @@ with gr.Blocks() as demo:
             gen_button = gr.Button("Generate Video", variant="primary")
         with gr.Column():
             video_comp = gr.Video(label="Generated Video", autoplay=True, interactive=False)
-            gr.Markdown("### Tips:\n- For long videos (>5s), consider lower resolutions.\n- 4-8 steps is often optimal.")
 
     def handle_upload(img):
         if img is None: return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)
         try:
-            w, h = img.size
-            a = h / w
-            h_new = int(np.sqrt(MAX_AREA * a))
-            w_new = int(np.sqrt(MAX_AREA / a))
+            w, h = img.size; a = h / w
+            h_new = int(np.sqrt(MAX_AREA * a)); w_new = int(np.sqrt(MAX_AREA / a))
             h_final = max(MOD_VALUE, h_new // MOD_VALUE * MOD_VALUE)
             w_final = max(MOD_VALUE, w_new // MOD_VALUE * MOD_VALUE)
             return gr.update(value=h_final), gr.update(value=w_final)
-        except Exception:
-            return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)
+        except Exception: return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)
 
     input_image_comp.upload(handle_upload, inputs=input_image_comp, outputs=[height_comp, width_comp])
 
@@ -148,4 +136,4 @@ with gr.Blocks() as demo:
     gen_button.click(fn=generate_video, inputs=inputs, outputs=outputs)
 
 if __name__ == "__main__":
-    demo.queue(max_size=3).launch()
+    demo.queue().launch()
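
A few notes on the techniques this commit keeps or introduces. First, the frame-count arithmetic: Wan 2.1 expects num_frames of the form 4k + 1, so the handler converts the requested duration to raw frames, snaps down to the nearest valid count, and clips to the new 81-frame cap. A minimal standalone sketch of the same arithmetic (the helper name snap_num_frames is illustrative, not from the commit):

import numpy as np

FIXED_FPS, MIN_FRAMES, MAX_FRAMES = 24, 8, 81  # values from this commit

def snap_num_frames(duration_seconds: float) -> int:
    """Round a duration to a Wan 2.1-compatible frame count of the form 4k + 1."""
    raw_frames = int(round(duration_seconds * FIXED_FPS))
    num_frames = ((raw_frames - 1) // 4) * 4 + 1   # snap down to 4k + 1
    return int(np.clip(num_frames, MIN_FRAMES, MAX_FRAMES))

# 2 s -> 48 raw frames -> 45 (= 4*11 + 1); 3.5 s -> 84 raw frames -> 81, the cap
print(snap_num_frames(2.0), snap_num_frames(3.5))

One caveat: clipping up to MIN_FRAMES = 8 can itself break the 4k + 1 shape at the slider's minimum duration, since 8 is not of that form.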
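
Second, the GPU budget. The commit defines get_duration but, as its own decorator comment concedes, never connects it: @spaces.GPU(duration=60) is evaluated once at import time with a fixed value. Recent versions of the spaces package on ZeroGPU document a dynamic-duration form in which duration is a callable that receives the same arguments as the decorated function. A hedged sketch of wiring the helper in that way, assuming that spaces version is available:

import spaces

# Must accept the same arguments as the function it budgets for.
def get_duration(input_image, prompt, height, width,
                 negative_prompt, duration_seconds,
                 guidance_scale, steps, seed, randomize_seed,
                 progress):
    # Longer clips and more steps get a larger GPU allocation (in seconds).
    if steps > 4 and duration_seconds > 2:
        return 90
    if steps > 4 or duration_seconds > 2:
        return 75
    return 60

@spaces.GPU(duration=get_duration)  # evaluated per call, not once at startup
def generate_video(input_image, prompt, height, width,
                   negative_prompt, duration_seconds,
                   guidance_scale, steps, seed, randomize_seed,
                   progress=None):
    ...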
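
Third, the resize math in handle_upload. It solves for the largest height and width that preserve the upload's aspect ratio a = h/w while keeping h·w near MAX_AREA: from h = a·w and h·w = A it follows that h = sqrt(A·a) and w = sqrt(A/a), after which both are floored to multiples of MOD_VALUE. A standalone sketch (the MOD_VALUE and MAX_AREA values below are assumed for illustration; this section does not show the app's actual defaults):

import numpy as np

MOD_VALUE = 8          # assumed for illustration; not shown in these hunks
MAX_AREA = 480 * 832   # assumed 480p pixel budget; not shown in these hunks

def fit_to_area(w: int, h: int) -> tuple[int, int]:
    """Largest (h, w) keeping h/w fixed, h*w <= ~MAX_AREA, both multiples of MOD_VALUE."""
    a = h / w                              # aspect ratio of the upload
    h_new = int(np.sqrt(MAX_AREA * a))     # h = sqrt(A * a)
    w_new = int(np.sqrt(MAX_AREA / a))     # w = sqrt(A / a)
    h_final = max(MOD_VALUE, h_new // MOD_VALUE * MOD_VALUE)
    w_final = max(MOD_VALUE, w_new // MOD_VALUE * MOD_VALUE)
    return h_final, w_final

# A 4000x3000 upload keeps its 3:4 ratio but lands inside the pixel budget.
print(fit_to_area(4000, 3000))  # (544, 728) under the assumed constants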
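
Finally, the export path. The "more robust video exporter" is imageio's ffmpeg writer: libx264 with yuv420p keeps the MP4 playable in browser video elements (many players reject other pixel formats), and quality=8 sits near the high end of imageio's 0-10 quality scale. A minimal sketch, assuming the imageio-ffmpeg backend is installed; the gray frames are placeholders for the pipeline's output:

import imageio
import numpy as np

# One second of placeholder frames at the app's fixed 24 fps.
frames = [np.full((480, 832, 3), 128, dtype=np.uint8) for _ in range(24)]

writer = imageio.get_writer("out.mp4", fps=24, codec="libx264",
                            pixelformat="yuv420p", quality=8)
for frame in frames:
    writer.append_data(frame)
writer.close()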