Spaces:
Running
Running
# Update the WanX-i2v sample_solver choices to include 'vanilla' | |
@@ -1035,7 +1035,7 @@ with gr.Blocks( | |
with gr.Row(): | |
wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
- wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc") | |
+ wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") | |
wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) | |
wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) | |
# Add exclude_single_blocks checkbox for WanX-i2v | |
@@ -1035,6 +1035,7 @@ with gr.Blocks( | |
with gr.Row(): | |
wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") | |
+ wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) | |
wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) | |
# Add LoRA support to WanX-i2v tab | |
@@ -979,7 +979,27 @@ with gr.Blocks( | |
) | |
wanx_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
+ # Add LoRA section for WanX-i2v similar to other tabs | |
+ wanx_refresh_btn = gr.Button("π", elem_classes="refresh-btn") | |
+ wanx_lora_weights = [] | |
+ wanx_lora_multipliers = [] | |
+ for i in range(4): | |
+ with gr.Column(): | |
+ wanx_lora_weights.append(gr.Dropdown( | |
+ label=f"LoRA {i+1}", | |
+ choices=get_lora_options(), | |
+ value="None", | |
+ allow_custom_value=True, | |
+ interactive=True | |
+ )) | |
+ wanx_lora_multipliers.append(gr.Slider( | |
+ label=f"Multiplier", | |
+ minimum=0.0, | |
+ maximum=2.0, | |
+ step=0.05, | |
+ value=1.0 | |
+ )) | |
+ | |
with gr.Row(): | |
wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
wanx_task = gr.Dropdown( | |
@@ -992,6 +1012,7 @@ with gr.Blocks( | |
wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") | |
wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") | |
wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") | |
+ wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
wanx_save_path = gr.Textbox(label="Save Path", value="outputs") | |
# Update WanX-t2v sample solver choices | |
@@ -1099,7 +1099,7 @@ with gr.Blocks( | |
with gr.Row(): | |
wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
- wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++"], label="Sample Solver", value="unipc") | |
+ wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") | |
wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, | |
info="Max 39 for 14B model, 29 for 1.3B model") | |
# Add exclude_single_blocks checkbox for WanX-t2v | |
@@ -1099,6 +1099,7 @@ with gr.Blocks( | |
with gr.Row(): | |
wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") | |
+ wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, | |
info="Max 39 for 14B model, 29 for 1.3B model") | |
# Add LoRA support to WanX-t2v tab | |
@@ -1063,7 +1064,27 @@ with gr.Blocks( | |
) | |
wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
+ # Add LoRA section for WanX-t2v | |
+ wanx_t2v_refresh_btn = gr.Button("π", elem_classes="refresh-btn") | |
+ wanx_t2v_lora_weights = [] | |
+ wanx_t2v_lora_multipliers = [] | |
+ for i in range(4): | |
+ with gr.Column(): | |
+ wanx_t2v_lora_weights.append(gr.Dropdown( | |
+ label=f"LoRA {i+1}", | |
+ choices=get_lora_options(), | |
+ value="None", | |
+ allow_custom_value=True, | |
+ interactive=True | |
+ )) | |
+ wanx_t2v_lora_multipliers.append(gr.Slider( | |
+ label=f"Multiplier", | |
+ minimum=0.0, | |
+ maximum=2.0, | |
+ step=0.05, | |
+ value=1.0 | |
+ )) | |
+ | |
with gr.Row(): | |
wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
wanx_t2v_task = gr.Dropdown( | |
@@ -1077,6 +1098,7 @@ with gr.Blocks( | |
wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") | |
wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") | |
wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") | |
+ wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") | |
# Update wanx_generate_video function to include LoRA and exclude_single_blocks | |
@@ -2051,6 +2073,15 @@ def wanx_generate_video( | |
save_path, | |
output_type, | |
sample_solver, | |
+ exclude_single_blocks, | |
attn_mode, | |
block_swap, | |
fp8, | |
- fp8_t5 | |
+ fp8_t5, | |
+ lora_folder, | |
+ lora1="None", | |
+ lora2="None", | |
+ lora3="None", | |
+ lora4="None", | |
+ lora1_multiplier=1.0, | |
+ lora2_multiplier=1.0, | |
+ lora3_multiplier=1.0, | |
+ lora4_multiplier=1.0 | |
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
"""Generate video with WanX model (supports both i2v and t2v)""" | |
global stop_event | |
@@ -2107,6 +2138,20 @@ def wanx_generate_video( | |
if fp8_t5: | |
command.append("--fp8_t5") | |
+ | |
+ if exclude_single_blocks: | |
+ command.append("--exclude_single_blocks") | |
+ | |
+ # Add LoRA weights and multipliers if provided | |
+ valid_loras = [] | |
+ for weight, mult in zip([lora1, lora2, lora3, lora4], | |
+ [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): | |
+ if weight and weight != "None": | |
+ valid_loras.append((os.path.join(lora_folder, weight), mult)) | |
+ if valid_loras: | |
+ weights = [weight for weight, _ in valid_loras] | |
+ multipliers = [str(mult) for _, mult in valid_loras] | |
+ command.extend(["--lora_weight"] + weights) | |
+ command.extend(["--lora_multiplier"] + multipliers) | |
print(f"Running: {' '.join(command)}") | |
# Update wanx_generate_video_batch function to accept and forward the LoRA and exclude_single_blocks arguments | |
@@ -2176,9 +2221,19 @@ def wanx_generate_video_batch( | |
save_path, | |
output_type, | |
sample_solver, | |
+ exclude_single_blocks, | |
attn_mode, | |
block_swap, | |
fp8, | |
- fp8_t5, | |
+ fp8_t5, | |
+ lora_folder, | |
+ lora1="None", | |
+ lora2="None", | |
+ lora3="None", | |
+ lora4="None", | |
+ lora1_multiplier=1.0, | |
+ lora2_multiplier=1.0, | |
+ lora3_multiplier=1.0, | |
+ lora4_multiplier=1.0, | |
batch_size=1, | |
input_image=None # Make input_image optional and place it at the end | |
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
@@ -2201,9 +2256,19 @@ def wanx_generate_video_batch( | |
save_path, | |
output_type, | |
sample_solver, | |
+ exclude_single_blocks, | |
attn_mode, | |
block_swap, | |
fp8, | |
- fp8_t5 | |
+ fp8_t5, | |
+ lora_folder, | |
+ lora1, | |
+ lora2, | |
+ lora3, | |
+ lora4, | |
+ lora1_multiplier, | |
+ lora2_multiplier, | |
+ lora3_multiplier, | |
+ lora4_multiplier | |
), | |
outputs=[wanx_output, wanx_batch_progress, wanx_progress_text], | |
queue=True | |
# Update WanX-i2v generate button click handler | |
@@ -2423,6 +2488,15 @@ wanx_generate_btn.click( | |
wanx_save_path, | |
wanx_output_type, | |
wanx_sample_solver, | |
+ wanx_exclude_single_blocks, | |
wanx_attn_mode, | |
wanx_block_swap, | |
wanx_fp8, | |
- wanx_fp8_t5, | |
+ wanx_fp8_t5, | |
+ wanx_lora_folder, | |
+ *wanx_lora_weights, | |
+ *wanx_lora_multipliers, | |
wanx_batch_size, | |
wanx_input # Include the image input for this tab | |
], | |
outputs=[wanx_output, wanx_batch_progress, wanx_progress_text], | |
queue=True | |
) | |
+ | |
+ # Add refresh button handler for WanX-i2v tab | |
+ wanx_refresh_outputs = [] | |
+ for i in range(4): | |
+ wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]]) | |
+ | |
+ wanx_refresh_btn.click( | |
+ fn=update_lora_dropdowns, | |
+ inputs=[wanx_lora_folder] + wanx_lora_weights + wanx_lora_multipliers, | |
+ outputs=wanx_refresh_outputs | |
+ ) | |
# Update WanX-t2v generate button click handler | |
@@ -2470,9 +2544,19 @@ wanx_t2v_generate_btn.click( | |
wanx_t2v_save_path, | |
wanx_t2v_output_type, | |
wanx_t2v_sample_solver, | |
+ wanx_t2v_exclude_single_blocks, | |
wanx_t2v_attn_mode, | |
wanx_t2v_block_swap, | |
wanx_t2v_fp8, | |
- wanx_t2v_fp8_t5, | |
+ wanx_t2v_fp8_t5, | |
+ wanx_t2v_lora_folder, | |
+ *wanx_t2v_lora_weights, | |
+ *wanx_t2v_lora_multipliers, | |
wanx_t2v_batch_size, | |
# input_image is now optional and not included here | |
], | |
outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text], | |
queue=True | |
) | |
+ | |
+ # Add refresh button handler for WanX-t2v tab | |
+ wanx_t2v_refresh_outputs = [] | |
+ for i in range(4): | |
+ wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]]) | |
+ | |
+ wanx_t2v_refresh_btn.click( | |
+ fn=update_lora_dropdowns, | |
+ inputs=[wanx_t2v_lora_folder] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers, | |
+ outputs=wanx_t2v_refresh_outputs | |
+ ) |