import spaces
import gradio as gr
import tempfile
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import decord
from PIL import Image
import numpy as np
from diffsynth import ModelManager, WanVideoPipeline, save_video

num_frames, width, height = 49, 832, 480
# gpu_id = 3
# device = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
from modelscope import snapshot_download
model_dir = snapshot_download(  # https://www.modelscope.cn/models/AI-ModelScope/RMBG-2.0
    model_id='AI-ModelScope/RMBG-2.0',
    local_dir='ckpt/RMBG-2.0',
    ignore_file_pattern=['onnx*'],
)

# NB: this import shadows modelscope's snapshot_download from here on
from huggingface_hub import snapshot_download, hf_hub_download
snapshot_download(  # download the whole repo; fetching briaai/RMBG-2.0 directly would require a token
    repo_id="alibaba-pai/Wan2.1-Fun-1.3B-Control",
    local_dir="ckpt/Wan2.1-Fun-1.3B-Control",
    local_dir_use_symlinks=False,
    resume_download=True,
    repo_type="model",
)
hf_hub_download(
    repo_id="Kunbyte/Lumen",
    filename="Lumen-T2V-1.3B-V1.0.ckpt",
    local_dir="ckpt/Lumen",
    local_dir_use_symlinks=False,
    resume_download=True,
)

rmbg_model = AutoModelForImageSegmentation.from_pretrained('ckpt/RMBG-2.0', trust_remote_code=True)
torch.set_float32_matmul_precision(['high', 'highest'][0])
rmbg_model.to(device)
rmbg_model.eval()

model_manager = ModelManager(device="cpu")  # 1.3B with device='cpu': ~6 GB VRAM; device=device: ~16 GB VRAM; about 1-2 min per video
wan_dit_path = 'ckpt/Lumen/Lumen-T2V-1.3B-V1.0.ckpt'
if 'wan14b' in wan_dit_path.lower():  # 14B: about 36 GB VRAM, about 10 min per video
    model_manager.load_models(
        [
            wan_dit_path if wan_dit_path else 'ckpt/Wan2.1-Fun-14B-Control/diffusion_pytorch_model.safetensors',
            'ckpt/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth',
            'ckpt/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth',
            'ckpt/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth',
        ],
        torch_dtype=torch.bfloat16,  # float8_e4m3fn for fp8 quantization; bfloat16 otherwise
    )
else:
    model_manager.load_models(
        [
            wan_dit_path if wan_dit_path else 'ckpt/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors',
            'ckpt/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth',
            'ckpt/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth',
            'ckpt/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth',
        ],
        torch_dtype=torch.bfloat16,
    )

wan_pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device=device)
wan_pipe.enable_vram_management(num_persistent_param_in_dit=None)

gr_info_duration = 2  # gradio popup info duration (seconds)
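
# Added note (not in the original script): decord's get_batch(range(num_frames))
# fails for clips with fewer than `num_frames` (49) frames. A minimal optional
# helper that loops a short clip up to the required length could look like the
# sketch below; the functions that follow assume every input is long enough and
# do not use it. `read_clip_looped` is a hypothetical name, not part of the app.
def read_clip_looped(path, n=num_frames):
    """Read n frames from `path`, looping the clip if it is shorter than n."""
    reader = decord.VideoReader(uri=path, width=width, height=height)
    indices = [i % len(reader) for i in range(n)]  # wrap indices past the last frame
    return reader.get_batch(indices).asnumpy().astype(np.uint8)
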
@spaces.GPU
def rmbg_mask(video_path, mask_path=None, progress=gr.Progress()):
    """Extract the foreground from a video; return the foreground video path."""
    if not video_path:
        gr.Warning("Please upload a video first!", duration=gr_info_duration)
        return None
    try:
        progress(0, desc="Preparing foreground extraction...")
        if mask_path and os.path.exists(mask_path):
            gr.Info("Using uploaded mask video for foreground extraction.", duration=gr_info_duration)
            video_frames = decord.VideoReader(uri=video_path, width=width, height=height)
            video_frames = video_frames.get_batch(range(num_frames)).asnumpy().astype(np.uint8)
            mask_frames = decord.VideoReader(uri=mask_path, width=width, height=height)
            mask_frames = mask_frames.get_batch(range(num_frames)).asnumpy().astype(np.uint8)
            # Keep pixels where the mask is (near-)white, zero out the rest
            fg_frames = np.where(mask_frames >= 127, video_frames, 0)
            fg_frames = [Image.fromarray(frame) for frame in fg_frames]
        else:
            image_size = (width, height)  # NB: torchvision Resize expects (h, w); this passes (w, h)
            transform_image = transforms.Compose([
                transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
            video_reader = decord.VideoReader(uri=video_path, width=width, height=height)
            video_frames = video_reader.get_batch(range(num_frames)).asnumpy()
            fg_frames = []
            # Update the progress bar as frames are processed
            for i in range(num_frames):
                progress((i + 1) / num_frames, desc=f"Processing frame {i+1}/{num_frames}...")
                image = Image.fromarray(video_frames[i])
                input_images = transform_image(image).unsqueeze(0).to(device)
                with torch.no_grad():
                    preds = rmbg_model(input_images)[-1].sigmoid().cpu()
                pred = preds[0].squeeze()
                pred_pil = transforms.ToPILImage()(pred)
                mask = pred_pil.resize(image.size)  # PIL.Image, mode='L'
                # Extract the foreground: white mask areas take `image`, black areas take the blank background
                fg_image = Image.composite(image, Image.new('RGB', image.size), mask)
                fg_frames.append(fg_image)
        progress(1.0, desc="Saving video...")
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
            fg_video_path = temp_file.name
        save_video(fg_frames, fg_video_path, fps=16, quality=7)
        progress(1.0, desc="Foreground extraction completed!")
        # gr.Info("Foreground extraction successful!")
        # gr.Video.update(value=fg_video_path, visible=True)
        return fg_video_path
    except Exception as e:
        error_msg = f"Foreground extraction error: {str(e)}"
        gr.Error(error_msg)
        return None
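
# Added usage sketch (hypothetical, not part of the original app): the two
# stages can also be chained outside the Gradio UI. 'my_data/input.mp4' and the
# prompt are assumed example values.
#   fg_path = rmbg_mask('my_data/input.mp4')
#   relit_path = video_relighting(fg_path, prompt='a neon-lit street at night, cinematic lighting')
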
@spaces.GPU
def video_relighting(fg_video_path, prompt, seed=-1, num_inference_steps=50, video_quality=7, progress=gr.Progress()):
    """Relight the foreground video based on the text prompt."""
    if not fg_video_path or not os.path.exists(fg_video_path):
        gr.Warning("Please extract the foreground first!", duration=gr_info_duration)
        return None
    if not prompt:
        gr.Warning("Please provide a text prompt for relighting!", duration=gr_info_duration)
        return None
    try:
        fg_video = decord.VideoReader(uri=fg_video_path, width=width, height=height)
        fg_video = fg_video.get_batch(range(num_frames)).asnumpy().astype('uint8')
        progress(0.1, desc="Relighting video...")
        relit_video = wan_pipe(
            prompt=prompt,
            # English equivalent of the negative prompt below:
            # negative_prompt='Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards',
            negative_prompt='色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走',
            num_inference_steps=num_inference_steps,
            control_video=fg_video,
            height=height,
            width=width,
            num_frames=num_frames,
            seed=seed,
            tiled=True,
        )
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
            relit_video_path = temp_file.name
        save_video(relit_video, relit_video_path, fps=16, quality=video_quality)
        progress(1.0, desc="Relighting processing completed!")
        gr.Info(f"Relighting successful! Used seed={seed}, steps={num_inference_steps}", duration=gr_info_duration)
        return relit_video_path
    except Exception as e:
        error_msg = f"Relighting processing error: {str(e)}"
        gr.Error(error_msg)
        return None


# Launch with: gradio app_lumen.py  (or: python app_lumen.py)

# Examples
bg_prompt_path = 'my_data/zh_en_short_prompts.txt'
with open(bg_prompt_path, 'r') as f:
    bg_prompts = f.readlines()
bg_prompts = [bg.strip() for bg in bg_prompts if bg.strip()]  # drop empty lines
bg_prompts_zh = bg_prompts[:len(bg_prompts) // 2]
bg_prompts_en = bg_prompts[len(bg_prompts) // 2:]
video_names = [
    191947, 922930, 1217498, 1302135, 1371894, 1428515, 1628805,
    1873403, 2259812, 2445920, 2639840, 2779867, 2974076,
]  # 13 clips
video_names = video_names * 2
video_dir = 'test/pachong_test/video/single_13'
relight_dir = 'test/pachong_test/video/single_13_2-res-v1.0-gradio_demo'

header = """ #