import os
import threading
from dataclasses import dataclass
from urllib.parse import urlparse

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers.models import AutoencoderKLWan
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from einops import rearrange
from jaxtyping import Float
from peft import LoraConfig
from PIL import Image
from torch import Tensor

from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel

TEX_PIPE = None
VAE = None
LATENTS_MEAN, LATENTS_STD = None, None
TEX_PIPE_LOCK = threading.Lock()


@dataclass
class Config:
    video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    seqtex_path: str = "https://huggingface.co/VAST-AI/SeqTex/resolve/main/.gitattributes/edm2_ema_12176_clean.pth"
    min_noise_level_index: int = 15  # same as the paper [WorldMem](https://arxiv.org/pdf/2504.12369v1)
    use_causal_mask: bool = False
    addtional_qk_geometry: bool = False
    use_normal: bool = True
    use_position: bool = True
    randomly_init: bool = True  # new layers may be randomly initialized; the trained weights are loaded from the checkpoint below
    num_views: int = 4
    uv_num_views: int = 1
    mv_height: int = 512
    mv_width: int = 512
    uv_height: int = 1024
    uv_width: int = 1024
    flow_shift: float = 5.0
    eval_guidance_scale: float = 1.0
    eval_num_inference_steps: int = 30
    eval_seed: int = 42
    lora_rank: int = 128
    lora_alpha: int = 64


cfg = Config()


def load_model_weights(model_path: str, map_location="cpu"):
    """
    Load model weights from either a URL or a local file path.

    Args:
        model_path (str): Path to model weights; can be a URL or a local file path.
        map_location (str): Device to map the model to.

    Returns:
        Dict: Loaded state dictionary.
    """
    # Check if the path is a URL
    parsed_url = urlparse(model_path)
    if parsed_url.scheme in ('http', 'https'):
        # Load from URL using torch.hub
        try:
            state_dict = torch.hub.load_state_dict_from_url(
                model_path,
                map_location=map_location,
                progress=True,
            )
            return state_dict
        except Exception as e:
            gr.Warning(f"Failed to load from URL: {e}")
            raise e
    else:
        # Load from local file path
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Local model file not found: {model_path}")
        return torch.load(model_path, map_location=map_location)
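
# Usage sketch (illustrative only; the checkpoint paths below are placeholders, not
# files shipped with this Space):
#   state_dict = load_model_weights("https://huggingface.co/.../model.pth")
#   state_dict = load_model_weights("./checkpoints/model.pth", map_location="cuda")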


def lazy_get_seqtex_pipe():
    """
    Lazily load the SeqTex pipeline for texture generation.
    """
    global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
    if TEX_PIPE is not None:
        return TEX_PIPE
    gr.Info("Loading the SeqTex pipeline for the first call... This may take about 1 minute.")
    with TEX_PIPE_LOCK:
        # Double-checked locking: another thread may have finished loading while we waited.
        if TEX_PIPE is not None:
            return TEX_PIPE

        # Pipeline
        TEX_PIPE = WanT2TexPipeline.from_pretrained(cfg.video_base_name)

        # Models
        transformer = WanT2TexTransformer3DModel(
            TEX_PIPE.transformer,
            use_causal_mask=cfg.use_causal_mask,
            addtional_qk_geo=cfg.addtional_qk_geometry,
            use_normal=cfg.use_normal,
            use_position=cfg.use_position,
            randomly_init=cfg.randomly_init,
        )
        transformer.add_adapter(
            LoraConfig(
                r=cfg.lora_rank,
                lora_alpha=cfg.lora_alpha,
                init_lora_weights=True,
                target_modules=["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0", "attn1.to_out.2",
                                "ffn.net.0.proj", "ffn.net.2"],
            )
        )

        # Load the trained transformer weights
        state_dict = load_model_weights(cfg.seqtex_path, map_location="cpu")
        transformer.load_state_dict(state_dict, strict=True)
        TEX_PIPE.transformer = transformer

        VAE = AutoencoderKLWan.from_pretrained(
            cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32
        ).to("cuda").requires_grad_(False)
        TEX_PIPE.vae = VAE

        # Some useful parameters
        LATENTS_MEAN = torch.tensor(VAE.config.latents_mean).view(
            1, VAE.config.z_dim, 1, 1, 1
        ).to("cuda", dtype=torch.float32)
        LATENTS_STD = 1.0 / torch.tensor(VAE.config.latents_std).view(
            1, VAE.config.z_dim, 1, 1, 1
        ).to("cuda", dtype=torch.float32)

        scheduler: FlowMatchEulerDiscreteScheduler = (
            FlowMatchEulerDiscreteScheduler.from_config(
                TEX_PIPE.scheduler.config, shift=cfg.flow_shift
            )
        )
        # The scheduler's timestep sequence starts at full noise, so the index is counted
        # from the end: typically 1000 - 15.
        min_noise_level_index = scheduler.config.num_train_timesteps - cfg.min_noise_level_index
        setattr(TEX_PIPE, "min_noise_level_index", min_noise_level_index)
        min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
        setattr(TEX_PIPE, "min_noise_level_timestep", min_noise_level_timestep)
        setattr(TEX_PIPE, "min_noise_level_sigma", min_noise_level_timestep / 1000.)

        TEX_PIPE = TEX_PIPE.to("cuda", dtype=torch.float32)  # use float32 for inference
    return TEX_PIPE
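
# Note (sketch): the app never constructs the pipeline directly; every request goes
# through lazy_get_seqtex_pipe(), so the first request pays the loading cost and later
# requests reuse the cached TEX_PIPE. Calling lazy_get_seqtex_pipe() once at startup
# would warm the cache eagerly; whether that suits the ZeroGPU allocation model is an
# assumption, not something the original app does.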


def encode_images(
    images: Float[Tensor, "B F H W C"], encode_as_first: bool = False
) -> Float[Tensor, "B C' F H/8 W/8"]:
    """
    Encode images into latent space using the VAE.
    Every frame is treated as a separate image, without any awareness of the temporal dimension.

    :param images: Input images tensor with shape [B, F, H, W, C].
    :param encode_as_first: Whether to encode every frame as if it were the first frame.
    :return: Encoded latents with shape [B, C', F, H/8, W/8].
    """
    if images.min() < -0.1:
        # images are in [-1, 1] range
        images = (images + 1.0) / 2.0  # Normalize to [0, 1] range
    if encode_as_first:
        # encode every frame as if it were the first frame
        B = images.shape[0]
        images = rearrange(images, "B F H W C -> (B F) C 1 H W")
        latents = (VAE.encode(images).latent_dist.sample() - LATENTS_MEAN) * LATENTS_STD
        latents = rearrange(latents, "(B F) C 1 H W -> B C F H W", B=B)
    else:
        raise NotImplementedError("Currently only encoding as the first frame is supported.")
    return latents
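
# Shape sketch (illustrative assumption, not executed by the app; requires
# lazy_get_seqtex_pipe() to have populated the VAE globals first): four 512x512 RGB
# views in [0, 1] are encoded frame-by-frame, shrinking spatial dims by 8x.
#   views = torch.rand(1, cfg.num_views, cfg.mv_height, cfg.mv_width, 3, device="cuda")
#   view_latents = encode_images(views, encode_as_first=True)  # [1, z_dim, 4, 64, 64]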


def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
    """
    Decode latents back to images using the VAE.

    :param latents: Input latents with shape [B, C, F, H, W].
    :param decode_as_first: Whether to decode all frames as the first frame.
    :return: Decoded images with shape [B, C, F*Nv, H*8, W*8].
    """
    if decode_as_first:
        F = latents.shape[2]
        latents = latents.to(VAE.dtype)
        latents = latents / LATENTS_STD + LATENTS_MEAN
        latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
        images = VAE.decode(latents, return_dict=False)[0]
        images = rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=F, Nv=1)
    else:
        raise NotImplementedError("Currently only decoding as the first frame is supported.")
    return images
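
# Round-trip sketch (assumption): decode_images approximately inverts encode_images,
# mapping latents back to images at 8x the latent resolution, up to VAE reconstruction
# error and with channels moved in front of the frame axis.
#   recon = decode_images(view_latents, decode_as_first=True)  # [1, 3, 4, 512, 512]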


def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H W C"]:
    """
    Convert a PIL Image to a tensor. If the image has an alpha channel, composite it
    onto a black background using the alpha mask.

    :param image: PIL Image to convert, with values in [0, 255].
    :return: Tensor representation of the image, with values in [0.0, 1.0] and shape [H, W, C].
    """
    # Convert to RGBA to ensure an alpha channel exists
    image = image.convert("RGBA")
    np_img = np.array(image)
    rgb = np_img[..., :3]
    alpha = np_img[..., 3:4] / 255.0  # Normalize alpha to [0, 1]
    # Blend with black background using the alpha mask
    rgb = rgb * alpha
    rgb = rgb.astype(np.float32) / 255.0  # Normalize to [0, 1]
    tensor = torch.from_numpy(rgb).to(device)
    return tensor
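
# Usage sketch (the file name is a placeholder): transparent pixels end up black
# because the RGB values are multiplied by alpha before normalization.
#   cond = Image.open("condition_view.png")
#   cond_tensor = convert_img_to_tensor(cond)  # float32, [H, W, 3] in [0, 1], on "cuda"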


def generate_texture(position_map, normal_map, position_images, normal_images, condition_image, text_prompt, selected_view, negative_prompt=None, device="cuda", progress=gr.Progress()):
    """
    Use SeqTex to generate a texture for the mesh based on the image condition.

    :param position_map: Position map rendered in UV space.
    :param normal_map: Normal map rendered in UV space.
    :param position_images: Position images rendered from the different views.
    :param normal_images: Normal images rendered from the different views.
    :param condition_image: Image condition generated from the selected view.
    :param text_prompt: Text prompt for texture generation.
    :param selected_view: The view selected for generating the image condition.
    :param negative_prompt: Optional negative prompt for texture generation.
    :return: Generated texture map, multi-view frames as a tensor, and a status message.
    """
    progress(0, desc="Loading SeqTex pipeline...")
    tex_pipe = lazy_get_seqtex_pipe()
    progress(0.2, desc="SeqTex pipeline loaded successfully.")

    view_id_map = {
        "First View": 0,
        "Second View": 1,
        "Third View": 2,
        "Fourth View": 3,
    }
    view_id = view_id_map[selected_view]

    progress(0.3, desc="Encoding position and normal images...")
    nat_seq = torch.cat([position_images.unsqueeze(0), normal_images.unsqueeze(0)], dim=0)  # 2 F H W C
    uv_seq = torch.cat([position_map.unsqueeze(0), normal_map.unsqueeze(0)], dim=0)
    nat_latents = encode_images(nat_seq, encode_as_first=True)  # B C F H W
    uv_latents = encode_images(uv_seq, encode_as_first=True)  # B C F' H' W'
    nat_pos_latents, nat_norm_latents = torch.chunk(nat_latents, 2, dim=0)
    uv_pos_latents, uv_norm_latents = torch.chunk(uv_latents, 2, dim=0)
    nat_geo_latents = torch.cat([nat_pos_latents, nat_norm_latents], dim=1)
    uv_geo_latents = torch.cat([uv_pos_latents, uv_norm_latents], dim=1)
    cond_model_latents = (nat_geo_latents, uv_geo_latents)

    num_frames = cfg.num_views * (2 ** sum(VAE.config.temperal_downsample))
    uv_num_frames = cfg.uv_num_views * (2 ** sum(VAE.config.temperal_downsample))

    progress(0.4, desc="Encoding condition image...")
    if isinstance(condition_image, Image.Image):
        condition_image = condition_image.resize((cfg.mv_width, cfg.mv_height), Image.LANCZOS)
        # Convert PIL Image to tensor
        condition_image = convert_img_to_tensor(condition_image, device=device)
    condition_image = condition_image.unsqueeze(0).unsqueeze(0)
    gt_latents = (encode_images(condition_image, encode_as_first=True), None)

    progress(0.5, desc="Generating texture with SeqTex...")
    latents = tex_pipe(
        prompt=text_prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        generator=torch.Generator(device=device).manual_seed(cfg.eval_seed),
        num_inference_steps=cfg.eval_num_inference_steps,
        guidance_scale=cfg.eval_guidance_scale,
        height=cfg.mv_height,
        width=cfg.mv_width,
        output_type="latent",
        cond_model_latents=cond_model_latents,
        # mask_indices=test_mask_indices,
        uv_height=cfg.uv_height,
        uv_width=cfg.uv_width,
        uv_num_frames=uv_num_frames,
        treat_as_first=True,
        gt_condition=gt_latents,
        inference_img_cond_frame=view_id,
        use_qk_geometry=True,
        task_type="img2tex",
        progress=progress,
    ).frames
    mv_latents, uv_latents = latents

    progress(0.9, desc="Decoding generated latents to images...")
    mv_frames = decode_images(mv_latents, decode_as_first=True)  # B C 4 H W
    uv_frames = decode_images(uv_latents, decode_as_first=True)  # B C 1 H W
    uv_map_pred = uv_frames[:, :, -1, ...]
    uv_map_pred.squeeze_(0)
    mv_out = rearrange(mv_frames[:, :, :cfg.num_views, ...], "B C (F N) H W -> N C (B H) (F W)", N=1)[0]
    mv_out = torch.clamp(mv_out, 0.0, 1.0)
    uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)
    progress(1, desc="Texture generated successfully.")
    return uv_map_pred.float(), mv_out.float(), "Step 3: Texture generated successfully."
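
# Local smoke-test sketch (assumptions: random tensors stand in for the real geometry
# renders, a CUDA device is available, and the default gr.Progress() works outside a
# Gradio request; in the actual Space these inputs come from the preceding UI steps).
#   pos_imgs = torch.rand(cfg.num_views, cfg.mv_height, cfg.mv_width, 3, device="cuda")
#   nrm_imgs = torch.rand(cfg.num_views, cfg.mv_height, cfg.mv_width, 3, device="cuda")
#   pos_map = torch.rand(cfg.uv_num_views, cfg.uv_height, cfg.uv_width, 3, device="cuda")
#   nrm_map = torch.rand(cfg.uv_num_views, cfg.uv_height, cfg.uv_width, 3, device="cuda")
#   cond = Image.new("RGBA", (cfg.mv_width, cfg.mv_height), (128, 128, 128, 255))
#   uv_map, mv_grid, msg = generate_texture(
#       pos_map, nrm_map, pos_imgs, nrm_imgs, cond,
#       text_prompt="a wooden crate", selected_view="First View",
#   )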