import torch
import torch.nn.functional as F
from diffusers import AutoencoderKLWan, WanVideoTextToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from diffusers.models import Transformer2DModel
import gradio as gr
import tempfile
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
import random
import logging
import os
import gc
from typing import List, Optional, Union
# MMAudio imports
try:
import mmaudio
except ImportError:
os.system("pip install -e .")
import mmaudio
# Set environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
os.environ['HF_HUB_CACHE'] = '/tmp/hub'
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils
# NAG-enhanced Pipeline
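# Wraps the Wan text-to-video pipeline; when nag_scale > 0 the transformer's
# forward pass is temporarily patched to add a normalized-attention guidance term
# before delegating to the parent pipeline's __call__.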
class NAGWanPipeline(WanVideoTextToVideoPipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.nag_scale = 0.0
self.nag_tau = 3.5
self.nag_alpha = 0.5
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
nag_negative_prompt: Optional[Union[str, List[str]]] = None,
nag_scale: float = 0.0,
nag_tau: float = 3.5,
nag_alpha: float = 0.5,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: int = 16,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[dict] = None,
clip_skip: Optional[int] = None,
):
# Use NAG negative prompt if provided
if nag_negative_prompt is not None:
negative_prompt = nag_negative_prompt
# Store NAG parameters
self.nag_scale = nag_scale
self.nag_tau = nag_tau
self.nag_alpha = nag_alpha
        # Override the transformer's forward method to apply NAG
        original_forward = None
        if hasattr(self, 'transformer') and nag_scale > 0:
            original_forward = self.transformer.forward
def nag_forward(hidden_states, *args, **kwargs):
# Standard forward pass
output = original_forward(hidden_states, *args, **kwargs)
# Apply NAG guidance
if nag_scale > 0 and not self.transformer.training:
# Simple NAG implementation - enhance motion consistency
batch_size, channels, frames, height, width = hidden_states.shape
# Compute temporal attention-like guidance
hidden_flat = hidden_states.view(batch_size, channels, -1)
attention = F.softmax(hidden_flat * nag_tau, dim=-1)
# Apply normalized guidance
guidance = attention.mean(dim=2, keepdim=True) * nag_alpha
guidance = guidance.unsqueeze(-1).unsqueeze(-1)
# Scale and add guidance
if hasattr(output, 'sample'):
output.sample = output.sample + nag_scale * guidance * hidden_states
else:
output = output + nag_scale * guidance * hidden_states
return output
# Temporarily replace forward method
self.transformer.forward = nag_forward
# Call parent pipeline
result = super().__call__(
prompt=prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
eta=eta,
generator=generator,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
cross_attention_kwargs=cross_attention_kwargs,
clip_skip=clip_skip,
)
        # Restore the original forward method only if it was patched above
        if original_forward is not None:
            self.transformer.forward = original_forward
return result
# Clean up stale media files left in the temp directory by earlier runs
def cleanup_temp_files():
    temp_dir = tempfile.gettempdir()
    for filename in os.listdir(temp_dir):
        if not filename.endswith(('.mp4', '.flac', '.wav')):
            continue
        filepath = os.path.join(temp_dir, filename)
        try:
            os.remove(filepath)
        except OSError:
            pass
# Video generation model setup
MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
# Load the model components
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = NAGWanPipeline.from_pretrained(
MODEL_ID, vae=vae, torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.to("cuda")
# Load LoRA weights for faster generation
causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
pipe.fuse_lora()
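# With the CausVid LoRA fused into the base weights, the pipeline is intended to
# produce usable results in roughly 4-8 inference steps (see DEFAULT_STEPS below).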
# Audio generation model setup
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
log = logging.getLogger()
device = 'cuda'
dtype = torch.bfloat16
# Global variables for audio model
audio_model = None
audio_net = None
audio_feature_utils = None
audio_seq_cfg = None
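# Lazily build the MMAudio network, feature utilities, and sequence config on first
# use so GPU memory is only consumed when audio generation is actually requested.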
def load_audio_model():
global audio_model, audio_net, audio_feature_utils, audio_seq_cfg
if audio_net is None:
audio_model = all_model_cfg['small_16k']
audio_model.download_if_needed()
setup_eval_logging()
seq_cfg = audio_model.seq_cfg
net = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
net.load_weights(torch.load(audio_model.model_path, map_location=device, weights_only=True))
log.info(f'Loaded weights from {audio_model.model_path}')
feature_utils = FeaturesUtils(tod_vae_ckpt=audio_model.vae_path,
synchformer_ckpt=audio_model.synchformer_ckpt,
enable_conditions=True,
mode=audio_model.mode,
bigvgan_vocoder_ckpt=audio_model.bigvgan_16k_path,
need_vae_encoder=False)
feature_utils = feature_utils.to(device, dtype).eval()
audio_net = net
audio_feature_utils = feature_utils
audio_seq_cfg = seq_cfg
return audio_net, audio_feature_utils, audio_seq_cfg
# Constants
MOD_VALUE = 32
DEFAULT_DURATION_SECONDS = 4
DEFAULT_STEPS = 4
DEFAULT_SEED = 2025
DEFAULT_H_SLIDER_VALUE = 480
DEFAULT_W_SLIDER_VALUE = 832
SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 129
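# At FIXED_FPS = 16, the 1-8 s duration slider maps to 17-129 frames (16 * s + 1),
# which stays within the MIN_FRAMES_MODEL..MAX_FRAMES_MODEL bounds above.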
DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
default_prompt = "A ginger cat passionately plays electric guitar with intensity and emotion on a stage"
default_audio_prompt = ""
default_audio_negative_prompt = "music"
# CSS
custom_css = """
/* Overall background gradient */
.gradio-container {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 25%, #f093fb 50%, #f5576c 75%, #fa709a 100%) !important;
background-size: 400% 400% !important;
animation: gradientShift 15s ease infinite !important;
}
@keyframes gradientShift {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
/* Main container style */
.main-container {
backdrop-filter: blur(10px);
background: rgba(255, 255, 255, 0.1) !important;
border-radius: 20px !important;
padding: 30px !important;
box-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37) !important;
border: 1px solid rgba(255, 255, 255, 0.18) !important;
}
/* Header style */
h1 {
background: linear-gradient(45deg, #ffffff, #f0f0f0) !important;
-webkit-background-clip: text !important;
-webkit-text-fill-color: transparent !important;
background-clip: text !important;
font-weight: 800 !important;
font-size: 2.5rem !important;
text-align: center !important;
margin-bottom: 2rem !important;
text-shadow: 2px 2px 4px rgba(0,0,0,0.1) !important;
}
/* Component container style */
.input-container, .output-container {
background: rgba(255, 255, 255, 0.08) !important;
border-radius: 15px !important;
padding: 20px !important;
margin: 10px 0 !important;
backdrop-filter: blur(5px) !important;
border: 1px solid rgba(255, 255, 255, 0.1) !important;
}
/* Input field style */
input, textarea, .gr-box {
background: rgba(255, 255, 255, 0.9) !important;
border: 1px solid rgba(255, 255, 255, 0.3) !important;
border-radius: 10px !important;
color: #333 !important;
transition: all 0.3s ease !important;
}
input:focus, textarea:focus {
background: rgba(255, 255, 255, 1) !important;
border-color: #667eea !important;
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
}
/* Button style */
.generate-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
color: white !important;
font-weight: 600 !important;
font-size: 1.1rem !important;
padding: 12px 30px !important;
border-radius: 50px !important;
border: none !important;
cursor: pointer !important;
transition: all 0.3s ease !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
}
.generate-btn:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
/* Slider style */
input[type="range"] {
background: transparent !important;
}
input[type="range"]::-webkit-slider-track {
background: rgba(255, 255, 255, 0.3) !important;
border-radius: 5px !important;
height: 6px !important;
}
input[type="range"]::-webkit-slider-thumb {
background: linear-gradient(135deg, #667eea, #764ba2) !important;
border: 2px solid white !important;
border-radius: 50% !important;
cursor: pointer !important;
width: 18px !important;
height: 18px !important;
-webkit-appearance: none !important;
}
/* Accordion style */
.gr-accordion {
background: rgba(255, 255, 255, 0.05) !important;
border-radius: 10px !important;
border: 1px solid rgba(255, 255, 255, 0.1) !important;
margin: 15px 0 !important;
}
/* Label style */
label {
color: #ffffff !important;
font-weight: 500 !important;
font-size: 0.95rem !important;
margin-bottom: 5px !important;
}
/* Video output area */
video {
border-radius: 15px !important;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3) !important;
}
/* Examples section style */
.gr-examples {
background: rgba(255, 255, 255, 0.05) !important;
border-radius: 15px !important;
padding: 20px !important;
margin-top: 20px !important;
}
/* Checkbox style */
input[type="checkbox"] {
accent-color: #667eea !important;
}
/* Radio button style */
input[type="radio"] {
accent-color: #667eea !important;
}
/* Info box */
.info-box {
background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
border-radius: 10px;
padding: 15px;
margin: 10px 0;
border-left: 4px solid #667eea;
}
/* Responsive adjustments for small screens */
@media (max-width: 768px) {
h1 { font-size: 2rem !important; }
.main-container { padding: 20px !important; }
}
"""
def clear_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()
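# Estimate the @spaces.GPU allocation window (seconds) from the requested clip
# length and step count, adding a fixed budget when audio generation is enabled.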
def get_duration(prompt, nag_negative_prompt, nag_scale,
height, width, duration_seconds,
steps, seed, randomize_seed,
audio_mode, audio_prompt, audio_negative_prompt,
audio_seed, audio_steps, audio_cfg_strength,
progress):
duration = int(duration_seconds) * int(steps) * 2.25 + 5
if audio_mode == "Enable Audio":
duration += 60
return duration
@torch.inference_mode()
def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
audio_seed, audio_steps, audio_cfg_strength):
net, feature_utils, seq_cfg = load_audio_model()
rng = torch.Generator(device=device)
if audio_seed >= 0:
rng.manual_seed(audio_seed)
else:
rng.seed()
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=audio_steps)
video_info = load_video(video_path, duration_sec)
clip_frames = video_info.clip_frames.unsqueeze(0)
sync_frames = video_info.sync_frames.unsqueeze(0)
duration = video_info.duration_sec
seq_cfg.duration = duration
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
audios = generate(clip_frames,
sync_frames, [audio_prompt],
negative_text=[audio_negative_prompt],
feature_utils=feature_utils,
net=net,
fm=fm,
rng=rng,
cfg_strength=audio_cfg_strength)
audio = audios.float().cpu()[0]
video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
make_video(video_info, video_with_audio_path, audio, sampling_rate=seq_cfg.sampling_rate)
return video_with_audio_path
@spaces.GPU(duration=get_duration)
def generate_video(prompt, nag_negative_prompt, nag_scale,
height, width, duration_seconds,
steps, seed, randomize_seed,
audio_mode, audio_prompt, audio_negative_prompt,
audio_seed, audio_steps, audio_cfg_strength,
progress=gr.Progress(track_tqdm=True)):
if not prompt.strip():
raise gr.Error("Please enter a text prompt to generate video.")
target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
# Generate video using NAG
with torch.inference_mode():
output_frames_list = pipe(
prompt=prompt,
nag_negative_prompt=nag_negative_prompt,
nag_scale=nag_scale,
nag_tau=3.5,
nag_alpha=0.5,
height=target_h,
width=target_w,
num_frames=num_frames,
guidance_scale=0., # NAG replaces traditional guidance
num_inference_steps=int(steps),
generator=torch.Generator(device="cuda").manual_seed(current_seed)
).frames[0]
# Save video without audio
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
video_path = tmpfile.name
export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
# Generate audio if enabled
video_with_audio_path = None
if audio_mode == "Enable Audio":
progress(0.5, desc="Generating audio...")
video_with_audio_path = add_audio_to_video(
video_path, duration_seconds,
audio_prompt, audio_negative_prompt,
audio_seed, audio_steps, audio_cfg_strength
)
clear_cache()
cleanup_temp_files()
return video_path, video_with_audio_path, current_seed
def update_audio_visibility(audio_mode):
    # The change handler drives two outputs: the audio settings column and the audio video player
    visible = audio_mode == "Enable Audio"
    return gr.update(visible=visible), gr.update(visible=visible)
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
with gr.Column(elem_classes=["main-container"]):
gr.Markdown("# ✨ Fast NAG T2V (14B) with Audio Generation")
gr.Markdown("### πŸš€ Normalized Attention Guidance + CausVid LoRA + MMAudio")
gr.HTML("""
<div class="info-box">
            <p>🎯 <strong>NAG (Normalized Attention Guidance)</strong>: Enhanced motion consistency and quality</p>
            <p>⚡ <strong>Speed</strong>: Generate videos in just 4-8 steps with CausVid LoRA</p>
            <p>🎵 <strong>Audio</strong>: Optional synchronized audio generation with MMAudio</p>
</div>
""")
with gr.Row():
with gr.Column(elem_classes=["input-container"]):
prompt_input = gr.Textbox(
label="✨ Video Prompt",
value=default_prompt,
placeholder="Describe your video scene in detail...",
lines=3
)
with gr.Accordion("🎨 NAG Settings", open=True):
nag_negative_prompt = gr.Textbox(
label="❌ NAG Negative Prompt",
value=DEFAULT_NAG_NEGATIVE_PROMPT,
lines=2
)
nag_scale = gr.Slider(
label="🎯 NAG Scale",
minimum=0.0,
maximum=20.0,
step=0.25,
value=11.0,
info="0 = No NAG, 11 = Recommended, 20 = Maximum guidance"
)
duration_seconds_input = gr.Slider(
minimum=1,
maximum=8,
step=1,
value=DEFAULT_DURATION_SECONDS,
label="⏱️ Duration (seconds)",
info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
)
audio_mode = gr.Radio(
choices=["Video Only", "Enable Audio"],
value="Video Only",
label="🎡 Audio Mode",
info="Enable to add audio to your generated video"
)
with gr.Column(visible=False) as audio_settings:
audio_prompt = gr.Textbox(
label="🎡 Audio Prompt",
value=default_audio_prompt,
placeholder="Describe the audio you want (e.g., 'waves, seagulls', 'footsteps on gravel')",
lines=2
)
audio_negative_prompt = gr.Textbox(
label="❌ Audio Negative Prompt",
value=default_audio_negative_prompt,
lines=2
)
with gr.Row():
audio_seed = gr.Number(
label="🎲 Audio Seed",
value=-1,
precision=0,
minimum=-1
)
audio_steps = gr.Slider(
minimum=1,
maximum=50,
step=1,
value=25,
label="πŸš€ Audio Steps"
)
audio_cfg_strength = gr.Slider(
minimum=1.0,
maximum=10.0,
step=0.5,
value=4.5,
label="🎯 Audio Guidance"
)
with gr.Accordion("βš™οΈ Advanced Settings", open=False):
with gr.Row():
height_input = gr.Slider(
minimum=SLIDER_MIN_H,
maximum=SLIDER_MAX_H,
step=MOD_VALUE,
value=DEFAULT_H_SLIDER_VALUE,
label=f"πŸ“ Output Height (Γ—{MOD_VALUE})"
)
width_input = gr.Slider(
minimum=SLIDER_MIN_W,
maximum=SLIDER_MAX_W,
step=MOD_VALUE,
value=DEFAULT_W_SLIDER_VALUE,
label=f"πŸ“ Output Width (Γ—{MOD_VALUE})"
)
with gr.Row():
steps_slider = gr.Slider(
minimum=1,
maximum=8,
step=1,
value=DEFAULT_STEPS,
label="πŸš€ Inference Steps"
)
seed_input = gr.Slider(
label="🎲 Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=DEFAULT_SEED,
interactive=True
)
randomize_seed_checkbox = gr.Checkbox(
label="πŸ”€ Randomize seed",
value=True,
interactive=True
)
generate_button = gr.Button(
"🎬 Generate Video",
variant="primary",
elem_classes=["generate-btn"]
)
with gr.Column(elem_classes=["output-container"]):
video_output = gr.Video(
label="πŸŽ₯ Generated Video",
autoplay=True,
interactive=False
)
video_with_audio_output = gr.Video(
label="πŸŽ₯ Generated Video with Audio",
autoplay=True,
interactive=False,
visible=False
)
gr.HTML("""
<div style="text-align: center; margin-top: 20px; color: #ffffff;">
                    <p>💡 Tip: Try different NAG scales for varied artistic effects!</p>
</div>
""")
# Event handlers
audio_mode.change(
fn=update_audio_visibility,
inputs=[audio_mode],
outputs=[audio_settings, video_with_audio_output]
)
ui_inputs = [
prompt_input, nag_negative_prompt, nag_scale,
height_input, width_input, duration_seconds_input,
steps_slider, seed_input, randomize_seed_checkbox,
audio_mode, audio_prompt, audio_negative_prompt,
audio_seed, audio_steps, audio_cfg_strength
]
generate_button.click(
fn=generate_video,
inputs=ui_inputs,
outputs=[video_output, video_with_audio_output, seed_input]
)
with gr.Column():
gr.Examples(
examples=[
["A ginger cat passionately plays electric guitar with intensity and emotion on a stage. The background is shrouded in deep darkness. Spotlights cast dramatic shadows.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, DEFAULT_DURATION_SECONDS,
DEFAULT_STEPS, DEFAULT_SEED, False,
"Enable Audio", "electric guitar riffs, cat meowing", default_audio_negative_prompt, -1, 25, 4.5],
["A red vintage Porsche convertible flying over a rugged coastal cliff. Monstrous waves violently crashing against the rocks below. A lighthouse stands tall atop the cliff.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, DEFAULT_DURATION_SECONDS,
DEFAULT_STEPS, DEFAULT_SEED, False,
"Enable Audio", "car engine roaring, ocean waves crashing, wind", default_audio_negative_prompt, -1, 25, 4.5],
["Enormous glowing jellyfish float slowly across a sky filled with soft clouds. Their tentacles shimmer with iridescent light as they drift above a peaceful mountain landscape. Magical and dreamlike, captured in a wide shot. Surreal realism style with detailed textures.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, DEFAULT_DURATION_SECONDS,
DEFAULT_STEPS, DEFAULT_SEED, False,
"Video Only", "", default_audio_negative_prompt, -1, 25, 4.5],
],
inputs=ui_inputs,
outputs=[video_output, video_with_audio_output, seed_input],
fn=generate_video,
cache_examples="lazy",
label="🌟 Example Gallery"
)
if __name__ == "__main__":
demo.queue().launch()