# SeqTex/utils/texture_generation.py
# Texture-generation utilities for the SeqTex Gradio Space.
import os
import threading
from dataclasses import dataclass
from urllib.parse import urlparse
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers.models import AutoencoderKLWan
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from einops import rearrange
from jaxtyping import Float
from peft import LoraConfig
from PIL import Image
from torch import Tensor
from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel
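# Lazily initialized globals: the SeqTex pipeline, its VAE, and the VAE latent
# normalization statistics. TEX_PIPE_LOCK guards one-time initialization across
# concurrent Gradio requests.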
TEX_PIPE = None
VAE = None
LATENTS_MEAN, LATENTS_STD = None, None
TEX_PIPE_LOCK = threading.Lock()
@dataclass
class Config:
video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
seqtex_path: str = "https://huggingface.co/VAST-AI/SeqTex/resolve/main/.gitattributes/edm2_ema_12176_clean.pth"
    min_noise_level_index: int = 15  # same setting as the WorldMem paper (https://arxiv.org/pdf/2504.12369v1)
use_causal_mask: bool = False
addtional_qk_geometry: bool = False
use_normal: bool = True
use_position: bool = True
    randomly_init: bool = True  # new modules are randomly initialized; the trained weights are then loaded from seqtex_path
num_views: int = 4
uv_num_views: int = 1
mv_height: int = 512
mv_width: int = 512
uv_height: int = 1024
uv_width: int = 1024
flow_shift: float = 5.0
eval_guidance_scale: float = 1.0
eval_num_inference_steps: int = 30
eval_seed: int = 42
lora_rank: int = 128
lora_alpha: int = 64
cfg = Config()
def load_model_weights(model_path: str, map_location="cpu"):
"""
Load model weights from either a URL or local file path.
Args:
model_path (str): Path to model weights, can be URL or local file path
map_location (str): Device to map the model to
Returns:
Dict: Loaded state dictionary
"""
# Check if the path is a URL
parsed_url = urlparse(model_path)
if parsed_url.scheme in ('http', 'https'):
# Load from URL using torch.hub
try:
state_dict = torch.hub.load_state_dict_from_url(
model_path,
map_location=map_location,
progress=True
)
return state_dict
except Exception as e:
gr.Warning(f"Failed to load from URL: {e}")
raise e
else:
# Load from local file path
if not os.path.exists(model_path):
raise FileNotFoundError(f"Local model file not found: {model_path}")
return torch.load(model_path, map_location=map_location)
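# Example (sketch): remote checkpoints go through torch.hub's download cache, local paths through torch.load.
#   state_dict = load_model_weights(cfg.seqtex_path)
#   state_dict = load_model_weights("/path/to/local_ckpt.pth")   # hypothetical local path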
def lazy_get_seqtex_pipe():
"""
Lazy load the SeqTex pipeline for texture generation.
"""
global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
if TEX_PIPE is not None:
return TEX_PIPE
gr.Info("First called, loading SeqTex pipeline... It may take about 1 minute.")
with TEX_PIPE_LOCK:
if TEX_PIPE is not None:
return TEX_PIPE
# Pipeline
TEX_PIPE = WanT2TexPipeline.from_pretrained(cfg.video_base_name)
# Models
transformer = WanT2TexTransformer3DModel(
TEX_PIPE.transformer,
use_causal_mask=cfg.use_causal_mask,
addtional_qk_geo=cfg.addtional_qk_geometry,
use_normal=cfg.use_normal,
use_position=cfg.use_position,
randomly_init=cfg.randomly_init,
)
transformer.add_adapter(
LoraConfig(
r=cfg.lora_rank,
lora_alpha=cfg.lora_alpha,
init_lora_weights=True,
target_modules=["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0", "attn1.to_out.2",
"ffn.net.0.proj", "ffn.net.2"],
)
)
# load transformer
state_dict = load_model_weights(cfg.seqtex_path, map_location="cpu")
transformer.load_state_dict(state_dict, strict=True)
TEX_PIPE.transformer = transformer
VAE = AutoencoderKLWan.from_pretrained(cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32).to("cuda").requires_grad_(False)
TEX_PIPE.vae = VAE
# Some useful parameters
LATENTS_MEAN = torch.tensor(VAE.config.latents_mean).view(
1, VAE.config.z_dim, 1, 1, 1
).to("cuda", dtype=torch.float32)
LATENTS_STD = 1.0 / torch.tensor(VAE.config.latents_std).view(
1, VAE.config.z_dim, 1, 1, 1
).to("cuda", dtype=torch.float32)
scheduler: FlowMatchEulerDiscreteScheduler = (
FlowMatchEulerDiscreteScheduler.from_config(
TEX_PIPE.scheduler.config, shift=cfg.flow_shift
)
)
        min_noise_level_index = scheduler.config.num_train_timesteps - cfg.min_noise_level_index  # in this scheduler the first timestep is pure noise, so this is typically 1000 - 15
setattr(TEX_PIPE, "min_noise_level_index", min_noise_level_index)
min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
setattr(TEX_PIPE, "min_noise_level_timestep", min_noise_level_timestep)
setattr(TEX_PIPE, "min_noise_level_sigma", min_noise_level_timestep / 1000.)
TEX_PIPE = TEX_PIPE.to("cuda", dtype=torch.float32) # use float32 for inference
return TEX_PIPE
@torch.amp.autocast('cuda', dtype=torch.float32)
def encode_images(
images: Float[Tensor, "B F H W C"], encode_as_first: bool = False
) -> Float[Tensor, "B C' F H/8 W/8"]:
"""
Encode images to latent space using VAE.
Every frame is seen as a separate image, without any awareness of the temporal dimension.
:param images: Input images tensor with shape [B, F, H, W, C].
:param encode_as_first: Whether to encode all frames as the first frame.
:return: Encoded latents with shape [B, C', F, H/8, W/8].
"""
    if images.min() < -0.1:
        # images are in the [-1, 1] range; remap to [0, 1]
        images = (images + 1.0) / 2.0
if encode_as_first:
# encode all the frame as the first one
B = images.shape[0]
images = rearrange(images, "B F H W C -> (B F) C 1 H W")
latents = (VAE.encode(images).latent_dist.sample() - LATENTS_MEAN) * LATENTS_STD
latents = rearrange(latents, "(B F) C 1 H W -> B C F H W", B=B)
else:
raise NotImplementedError("Currently only support encode as first frame.")
return latents
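# Example (sketch): encoding four 512x512 views that are already on the GPU.
#   views = torch.rand(1, 4, 512, 512, 3, device="cuda")    # [B, F, H, W, C] in [0, 1]
#   lat = encode_images(views, encode_as_first=True)        # [1, VAE.config.z_dim, 4, 64, 64]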
@torch.amp.autocast('cuda', dtype=torch.float32)
def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
"""
Decode latents back to images using VAE.
:param latents: Input latents with shape [B, C, F, H, W].
:param decode_as_first: Whether to decode all frames as the first frame.
:return: Decoded images with shape [B, C, F*Nv, H*8, W*8].
"""
if decode_as_first:
F = latents.shape[2]
latents = latents.to(VAE.dtype)
latents = latents / LATENTS_STD + LATENTS_MEAN
latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
images = VAE.decode(latents, return_dict=False)[0]
images = rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=F, Nv=1)
else:
raise NotImplementedError("Currently only support decode as first frame.")
return images
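# Example (sketch): round-tripping the latents from the encode_images example above.
#   imgs = decode_images(lat, decode_as_first=True)          # [1, 3, 4, 512, 512]; the caller later clamps to [0, 1]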
def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H W C"]:
"""
Convert a PIL Image to a tensor. If Image is RGBA, mask it with black background using a-channel mask.
:param image: PIL Image to convert. [0, 255]
:return: Tensor representation of the image. [0.0, 1.0], still [H, W, C]
"""
# Convert to RGBA to ensure alpha channel exists
image = image.convert("RGBA")
np_img = np.array(image)
rgb = np_img[..., :3]
alpha = np_img[..., 3:4] / 255.0 # Normalize alpha to [0, 1]
# Blend with black background using alpha mask
rgb = rgb * alpha
rgb = rgb.astype(np.float32) / 255.0 # Normalize to [0, 1]
tensor = torch.from_numpy(rgb).to(device)
return tensor
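# Example (sketch): transparent RGBA pixels are composited onto black before normalization.
#   cond = convert_img_to_tensor(Image.open("view.png"))     # "view.png" is a hypothetical file; returns an [H, W, 3] cuda tensor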
@spaces.GPU(duration=120)
@torch.amp.autocast('cuda', dtype=torch.float32)
@torch.inference_mode
@torch.no_grad
def generate_texture(position_map, normal_map, position_images, normal_images, condition_image, text_prompt, selected_view, negative_prompt=None, device="cuda", progress=gr.Progress()):
"""
Use SeqTex to generate texture for the mesh based on the image condition.
:param position_images: List of position images from different views.
:param normal_images: List of normal images from different views.
:param condition_image: Image condition generated from the selected view.
:param text_prompt: Text prompt for texture generation.
:param selected_view: The view selected for generating the image condition.
:return: Generated texture map, and multi-view frames in tensor.
"""
progress(0, desc="Loading SeqTex pipeline...")
tex_pipe = lazy_get_seqtex_pipe()
progress(0.2, desc="SeqTex pipeline loaded successfully.")
view_id_map = {
"First View": 0,
"Second View": 1,
"Third View": 2,
"Fourth View": 3
}
view_id = view_id_map[selected_view]
progress(0.3, desc="Encoding position and normal images...")
nat_seq = torch.cat([position_images.unsqueeze(0), normal_images.unsqueeze(0)], dim=0) # 1 F H W C
uv_seq = torch.cat([position_map.unsqueeze(0), normal_map.unsqueeze(0)], dim=0)
nat_latents = encode_images(nat_seq, encode_as_first=True) # B C F H W
uv_latents = encode_images(uv_seq, encode_as_first=True) # B C F' H' W'
nat_pos_latents, nat_norm_latents = torch.chunk(nat_latents, 2, dim=0)
uv_pos_latents, uv_norm_latents = torch.chunk(uv_latents, 2, dim=0)
nat_geo_latents = torch.cat([nat_pos_latents, nat_norm_latents], dim=1)
uv_geo_latents = torch.cat([uv_pos_latents, uv_norm_latents], dim=1)
cond_model_latents = (nat_geo_latents, uv_geo_latents)
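    # The Wan VAE temporally compresses by a factor of 2 ** sum(temperal_downsample) (typically 4x),
    # so the pixel-space frame counts requested from the pipeline are inflated accordingly to end up
    # with cfg.num_views multi-view latent frames and cfg.uv_num_views UV latent frames.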
num_frames = cfg.num_views * (2 ** sum(VAE.config.temperal_downsample))
uv_num_frames = cfg.uv_num_views * (2 ** sum(VAE.config.temperal_downsample))
progress(0.4, desc="Encoding condition image...")
if isinstance(condition_image, Image.Image):
condition_image = condition_image.resize((cfg.mv_width, cfg.mv_height), Image.LANCZOS)
# Convert PIL Image to tensor
condition_image = convert_img_to_tensor(condition_image, device=device)
condition_image = condition_image.unsqueeze(0).unsqueeze(0)
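    # gt_condition is a (multi-view, UV) pair: only the multi-view branch receives the encoded
    # image condition; the UV branch has no ground-truth latents.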
gt_latents = (encode_images(condition_image, encode_as_first=True), None)
progress(0.5, desc="Generating texture with SeqTex...")
latents = tex_pipe(
prompt=text_prompt,
negative_prompt=negative_prompt,
num_frames=num_frames,
generator=torch.Generator(device=device).manual_seed(cfg.eval_seed),
num_inference_steps=cfg.eval_num_inference_steps,
guidance_scale=cfg.eval_guidance_scale,
height=cfg.mv_height,
width=cfg.mv_width,
output_type="latent",
cond_model_latents=cond_model_latents,
# mask_indices=test_mask_indices,
uv_height=cfg.uv_height,
uv_width=cfg.uv_width,
uv_num_frames=uv_num_frames,
treat_as_first=True,
gt_condition=gt_latents,
inference_img_cond_frame=view_id,
use_qk_geometry=True,
task_type="img2tex", # img2tex
progress=progress,
).frames
mv_latents, uv_latents = latents
progress(0.9, desc="Decoding generated latents to images...")
mv_frames = decode_images(mv_latents, decode_as_first=True) # B C 4 H W
uv_frames = decode_images(uv_latents, decode_as_first=True) # B C 1 H W
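    # Keep the last decoded UV frame as the predicted texture map.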
uv_map_pred = uv_frames[:, :, -1, ...]
uv_map_pred.squeeze_(0)
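    # Tile the multi-view frames side by side into a single [C, H, num_views * W] image for display.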
mv_out = rearrange(mv_frames[:, :, :cfg.num_views, ...], "B C (F N) H W -> N C (B H) (F W)", N=1)[0]
mv_out = torch.clamp(mv_out, 0.0, 1.0)
uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)
progress(1, desc="Texture generated successfully.")
return uv_map_pred.float(), mv_out.float(), "Step 3: Texture generated successfully."
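# Example (sketch, with hypothetical inputs): generate_texture is meant to be called from the
# Gradio app once the per-view geometry renders and the image condition for the chosen view exist.
#   uv_map, mv_views, msg = generate_texture(
#       position_map, normal_map,            # UV-space geometry maps (tensors)
#       position_images, normal_images,      # per-view geometry renders (tensors)
#       condition_image,                      # PIL image for the selected view
#       "a weathered bronze statue",          # hypothetical text prompt
#       "First View",
#   )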