import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import logging
import gc

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
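
# Base Wan 2.1 image-to-video checkpoint plus the CausVid LoRA (a few-step
# distillation adapter; the filename says T2V, but this app applies it to the
# I2V pipeline).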
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
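
# Generation geometry: output height/width must be multiples of MOD_VALUE, and
# NEW_FORMULA_MAX_AREA (480 * 832, the model's 480p resolution) is the area
# budget used when deriving dimensions from an uploaded image.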
MOD_VALUE = 32
DEFAULT_H_SLIDER_VALUE = 512
DEFAULT_W_SLIDER_VALUE = 512
NEW_FORMULA_MAX_AREA = 480.0 * 832.0

SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
MAX_SEED = np.iinfo(np.int32).max

FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81

default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "static, blurred, low quality, watermark, text"
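
# Precision split as in the diffusers Wan examples: keep the CLIP image
# encoder and VAE in float32 for numerical stability, and run the 14B
# transformer in bfloat16 to save memory.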
logger.info("Loading model components...")
image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
# pipe.to("cuda") is intentionally not called here: enable_model_cpu_offload()
# below manages device placement itself, and diffusers moves an already-CUDA
# pipeline back to the CPU when offload is enabled.
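
# Fetch and fuse the CausVid LoRA at 0.95 strength. fuse_lora() folds the
# adapter into the base weights, so it adds no overhead at inference; if the
# download fails the app still runs, though few-step quality will suffer.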
try:
    causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
    pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
    pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
    pipe.fuse_lora()
    logger.info("LoRA loaded successfully")
except Exception as e:
    logger.warning(f"LoRA loading failed: {e}")
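
# Memory savers: VAE slicing/tiling reduce peak VRAM during decode, and model
# CPU offload moves submodules to the GPU only while they run. Offload also
# handles device placement, which is why pipe.to("cuda") is omitted above.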
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()

logger.info("Model loaded and ready")
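

# Derive default output dimensions from an uploaded image: preserve the aspect
# ratio while keeping h * w within an area budget, rounded down to multiples
# of mod_val. For example (off Zero GPU), a 1024x768 upload has aspect 0.75,
# so h = round(sqrt(399360 * 0.75)) = 547 -> 544 and
# w = round(sqrt(399360 / 0.75)) = 730 -> 704.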
def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
                                  min_slider_h, max_slider_h,
                                  min_slider_w, max_slider_w,
                                  default_h, default_w):
    orig_w, orig_h = pil_image.size
    if orig_w <= 0 or orig_h <= 0:
        return default_h, default_w

    aspect_ratio = orig_h / orig_w

    # On Zero GPU, use a tighter area budget for the auto-computed defaults.
    if hasattr(spaces, 'GPU'):
        calculation_max_area = min(calculation_max_area, 320.0 * 320.0)

    # Solve h * w = max_area with h / w = aspect_ratio.
    calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
    calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))

    calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
    calc_w = max(mod_val, (calc_w // mod_val) * mod_val)

    if hasattr(spaces, 'GPU'):
        max_slider_h = min(max_slider_h, 640)
        max_slider_w = min(max_slider_w, 640)

    new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
    new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))

    return new_h, new_w


def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
    if uploaded_pil_image is None:
        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
    try:
        new_h, new_w = _calculate_new_dimensions_wan(
            uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
        )
        return gr.update(value=new_h), gr.update(value=new_w)
    except Exception:
        gr.Warning("Error attempting to calculate new dimensions")
        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
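

# The `spaces` package accepts a callable for `duration`: it is invoked with
# the same arguments as the decorated function and returns the number of
# seconds of GPU time to request for that call.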
def get_duration(input_image, prompt, height, width,
                 negative_prompt, duration_seconds,
                 guidance_scale, steps,
                 seed, randomize_seed,
                 progress):
    # Fixed estimates (in seconds) per configuration bucket.
    if hasattr(spaces, 'GPU'):
        if steps > 4 and duration_seconds > 2:
            return 90
        elif steps > 4 or duration_seconds > 2:
            return 80
        else:
            return 70
    else:
        if steps > 4 and duration_seconds > 2:
            return 90
        elif steps > 4 or duration_seconds > 2:
            return 75
        else:
            return 60


@spaces.GPU(duration=get_duration)
def generate_video(input_image, prompt, height, width,
                   negative_prompt=default_negative_prompt, duration_seconds=2,
                   guidance_scale=1, steps=4,
                   seed=42, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
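    # Zero GPU guardrails below cap resolution, duration, and step count so
    # the job finishes within the allocation estimated by get_duration().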
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    if hasattr(spaces, 'GPU'):
        max_pixels = 409600  # 640 * 640
        if height * width > max_pixels:
            raise gr.Error(f"Resolution too high for Zero GPU. Maximum {max_pixels:,} pixels (e.g., 640×640)")

        if duration_seconds > 2.5:
            duration_seconds = 2.5
            gr.Warning("Duration limited to 2.5s in Zero GPU environment")

        if steps > 8:
            steps = 8
            gr.Warning("Steps limited to 8 in Zero GPU environment")

    # Snap the requested size down to multiples of MOD_VALUE.
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

    if hasattr(spaces, 'GPU'):
        max_frames_zerogpu = int(2.5 * FIXED_FPS)  # 60 frames at 24 fps
        num_frames = min(num_frames, max_frames_zerogpu)

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    logger.info(f"Generating video: {target_h}x{target_w}, {num_frames} frames, seed={current_seed}")

    resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)

    try:
        with torch.inference_mode():
            output_frames_list = pipe(
                image=resized_image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=target_h,
                width=target_w,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed)
            ).frames[0]
    except torch.cuda.OutOfMemoryError:
        gc.collect()
        torch.cuda.empty_cache()
        raise gr.Error("GPU out of memory. Try smaller resolution or shorter duration.")
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        raise gr.Error(f"Video generation failed: {str(e)[:100]}")

    # Write frames to a temporary .mp4 (delete=False so Gradio can serve it).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

    # Free frame buffers and the CUDA cache before returning.
    del output_frames_list
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return video_path, current_seed


css = """
.container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}

.header {
    text-align: center;
    margin-bottom: 30px;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 40px;
    border-radius: 20px;
    color: white;
    box-shadow: 0 10px 30px rgba(0,0,0,0.2);
    position: relative;
    overflow: hidden;
}

.header::before {
    content: '';
    position: absolute;
    top: -50%;
    left: -50%;
    width: 200%;
    height: 200%;
    background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, transparent 70%);
    animation: pulse 4s ease-in-out infinite;
}

@keyframes pulse {
    0%, 100% { transform: scale(1); opacity: 0.5; }
    50% { transform: scale(1.1); opacity: 0.8; }
}

.header h1 {
    font-size: 3em;
    margin-bottom: 10px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
    position: relative;
    z-index: 1;
}

.header p {
    font-size: 1.2em;
    opacity: 0.95;
    position: relative;
    z-index: 1;
}

.gpu-status {
    position: absolute;
    top: 10px;
    right: 10px;
    background: rgba(0,0,0,0.3);
    padding: 5px 15px;
    border-radius: 20px;
    font-size: 0.8em;
}

.main-content {
    background: rgba(255, 255, 255, 0.95);
    border-radius: 20px;
    padding: 30px;
    box-shadow: 0 5px 20px rgba(0,0,0,0.1);
    backdrop-filter: blur(10px);
}

.input-section {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    padding: 25px;
    border-radius: 15px;
    margin-bottom: 20px;
}

.generate-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    font-size: 1.3em;
    padding: 15px 40px;
    border-radius: 30px;
    border: none;
    cursor: pointer;
    transition: all 0.3s ease;
    box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
    width: 100%;
    margin-top: 20px;
}

.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 7px 20px rgba(102, 126, 234, 0.6);
}

.generate-btn:active {
    transform: translateY(0);
}

.video-output {
    background: #f8f9fa;
    padding: 20px;
    border-radius: 15px;
    text-align: center;
    min-height: 400px;
    display: flex;
    align-items: center;
    justify-content: center;
}

.accordion {
    background: rgba(255, 255, 255, 0.7);
    border-radius: 10px;
    margin-top: 15px;
    padding: 15px;
}

.slider-container {
    background: rgba(255, 255, 255, 0.5);
    padding: 15px;
    border-radius: 10px;
    margin: 10px 0;
}

body {
    background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
    background-size: 400% 400%;
    animation: gradient 15s ease infinite;
}

@keyframes gradient {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}

.warning-box {
    background: rgba(255, 193, 7, 0.1);
    border: 1px solid rgba(255, 193, 7, 0.3);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    color: #856404;
    font-size: 0.9em;
}

.info-box {
    background: rgba(52, 152, 219, 0.1);
    border: 1px solid rgba(52, 152, 219, 0.3);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    color: #2c5282;
    font-size: 0.9em;
}

.footer {
    text-align: center;
    margin-top: 30px;
    color: #666;
    font-size: 0.9em;
}
"""
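

# Gradio UI: a styled two-column layout with inputs on the left and the
# rendered video on the right.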
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        gr.HTML("""
        <div class="header">
            <h1>🎬 AI Video Magic Studio</h1>
            <p>Transform your images into captivating videos with Wan 2.1 + CausVid LoRA</p>
            <div class="gpu-status">🖥️ Zero GPU Optimized</div>
        </div>
        """)

        if hasattr(spaces, 'GPU'):
            gr.HTML("""
            <div class="warning-box">
                <strong>💡 Zero GPU Performance Tips:</strong>
                <ul style="margin: 5px 0; padding-left: 20px;">
                    <li>Maximum duration: 2.5 seconds</li>
                    <li>Maximum resolution: 640×640 pixels</li>
                    <li>Recommended: 512×512 at 2 seconds</li>
                    <li>Use 4-6 steps for optimal speed/quality balance</li>
                    <li>Processing time: ~60-90 seconds</li>
                </ul>
            </div>
            """)

        gr.HTML("""
        <div class="info-box">
            <strong>🎯 Quick Start Guide:</strong>
            <ol style="margin: 5px 0; padding-left: 20px;">
                <li>Upload your image - AI will calculate optimal dimensions</li>
                <li>Enter a creative prompt or use the default</li>
                <li>Adjust duration (2s recommended for best results)</li>
                <li>Click Generate and wait for completion</li>
            </ol>
        </div>
        """)

        with gr.Row(elem_classes="main-content"):
            with gr.Column(scale=1):
                gr.Markdown("### 📸 Input Settings")

                with gr.Column(elem_classes="input-section"):
                    input_image = gr.Image(
                        type="pil",
                        label="🖼️ Upload Your Image",
                        elem_classes="image-upload"
                    )

                    prompt_input = gr.Textbox(
                        label="✨ Animation Prompt",
                        value=default_prompt_i2v,
                        placeholder="Describe how you want your image to move...",
                        lines=2
                    )

                    duration_input = gr.Slider(
                        minimum=round(MIN_FRAMES_MODEL / FIXED_FPS, 1),
                        maximum=round(MAX_FRAMES_MODEL / FIXED_FPS, 1) if not hasattr(spaces, 'GPU') else 2.5,
                        step=0.1,
                        value=2,
                        label=f"⏱️ Video Duration (seconds) - Clamped to {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps",
                        elem_classes="slider-container"
                    )

                with gr.Accordion("🎛️ Advanced Settings", open=False, elem_classes="accordion"):
                    negative_prompt = gr.Textbox(
                        label="🚫 Negative Prompt",
                        value=default_negative_prompt,
                        lines=3
                    )

                    with gr.Row():
                        seed = gr.Slider(
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42,
                            label="🎲 Seed"
                        )
                        randomize_seed = gr.Checkbox(
                            label="🔄 Randomize",
                            value=True
                        )

                    with gr.Row():
                        height_slider = gr.Slider(
                            minimum=SLIDER_MIN_H,
                            maximum=SLIDER_MAX_H if not hasattr(spaces, 'GPU') else 640,
                            step=MOD_VALUE,
                            value=DEFAULT_H_SLIDER_VALUE,
                            label=f"📏 Height (multiple of {MOD_VALUE})"
                        )
                        width_slider = gr.Slider(
                            minimum=SLIDER_MIN_W,
                            maximum=SLIDER_MAX_W if not hasattr(spaces, 'GPU') else 640,
                            step=MOD_VALUE,
                            value=DEFAULT_W_SLIDER_VALUE,
                            label=f"📐 Width (multiple of {MOD_VALUE})"
                        )

                    steps_slider = gr.Slider(
                        minimum=1,
                        maximum=30 if not hasattr(spaces, 'GPU') else 8,
                        step=1,
                        value=4,
                        label="🔧 Quality Steps (4-6 recommended)"
                    )

                    guidance_scale = gr.Slider(
                        minimum=0.0,
                        maximum=20.0,
                        step=0.5,
                        value=1.0,
                        label="🎯 Guidance Scale",
                        visible=False
                    )

                generate_btn = gr.Button(
                    "🎬 Generate Video",
                    variant="primary",
                    elem_classes="generate-btn"
                )

            with gr.Column(scale=1):
                gr.Markdown("### 🎥 Generated Video")
                video_output = gr.Video(
                    label="",
                    autoplay=True,
                    elem_classes="video-output"
                )

                gr.HTML("""
                <div class="footer">
                    <p>💡 Tip: For best results, use clear images with good lighting and distinct subjects</p>
                </div>
                """)

        gr.Examples(
            examples=[
                ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 512, 512],
                ["forg.jpg", "the frog jumps around", 448, 576],
            ],
            inputs=[input_image, prompt_input, height_slider, width_slider],
            outputs=[video_output, seed],
            fn=generate_video,
            cache_examples=False
        )

        gr.HTML("""
        <div style="background: rgba(255,255,255,0.9); border-radius: 10px; padding: 15px; margin-top: 20px; font-size: 0.8em; text-align: center;">
            <p style="margin: 0; color: #666;">
                <strong style="color: #667eea;">Powered by:</strong>
                Wan 2.1 I2V (14B) + CausVid LoRA • 🚀 4-8 steps fast inference • 🎬 Up to 81 frames
            </p>
        </div>
        """)
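
    # Event wiring: recompute slider defaults on upload/clear; the click
    # handler's input order must match generate_video's positional signature.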
    input_image.upload(
        fn=handle_image_upload_for_dims_wan,
        inputs=[input_image, height_slider, width_slider],
        outputs=[height_slider, width_slider]
    )

    input_image.clear(
        fn=handle_image_upload_for_dims_wan,
        inputs=[input_image, height_slider, width_slider],
        outputs=[height_slider, width_slider]
    )

    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image, prompt_input, height_slider, width_slider,
            negative_prompt, duration_input, guidance_scale,
            steps_slider, seed, randomize_seed
        ],
        outputs=[video_output, seed]
    )


if __name__ == "__main__":
    demo.queue().launch()