# Extracted from a Hugging Face Spaces file view (Space status: "Running on L40S").
# File size: 14,083 bytes; commit 099dc67. The viewer's line-number gutter (1-330) was dropped.
import os
import torch
import torch.nn as nn
import logging
from pathlib import Path
from huggingface_hub import hf_hub_download
from diffsynth import ModelManager, WanVideoReCamMasterPipeline
logger = logging.getLogger(__name__)

# Root directory for every downloaded model; override it with the
# RECAMMASTER_MODELS_DIR environment variable.
MODELS_ROOT_DIR = os.environ.get("RECAMMASTER_MODELS_DIR", "/data/models")
logger.info(f"Using models root directory: {MODELS_ROOT_DIR}")

# Wan2.1 base model: Hub repository, local mirror directory (same relative
# path as the repository id) and the weight files fetched from it.
WAN21_REPO_ID = "Wan-AI/Wan2.1-T2V-1.3B"
WAN21_LOCAL_DIR = f"{MODELS_ROOT_DIR}/{WAN21_REPO_ID}"
WAN21_FILES = [
    "diffusion_pytorch_model.safetensors",
    "models_t5_umt5-xxl-enc-bf16.pth",
    "Wan2.1_VAE.pth",
]

# UMT5-XXL tokenizer files, given as paths relative to the Wan2.1 repository.
UMT5_XXL_TOKENIZER_FILES = [
    f"google/umt5-xxl/{tokenizer_file}"
    for tokenizer_file in (
        "special_tokens_map.json",
        "spiece.model",
        "tokenizer.json",
        "tokenizer_config.json",
    )
]

# ReCamMaster fine-tuned checkpoint: Hub repository, file name and local directory.
RECAMMASTER_REPO_ID = "KwaiVGI/ReCamMaster-Wan2.1"
RECAMMASTER_CHECKPOINT_FILE = "step20000.ckpt"
RECAMMASTER_LOCAL_DIR = "/".join((MODELS_ROOT_DIR, "ReCamMaster", "checkpoints"))
class ModelLoader:
    """Fetch the required weights and build the ReCamMaster video pipeline.

    Attributes (all populated by :meth:`load_models` on success):

    * ``model_manager`` -- the ``ModelManager`` holding the Wan2.1 weights.
    * ``pipe`` -- the ``WanVideoReCamMasterPipeline`` built from it.
    * ``is_loaded`` -- ``True`` once loading has completed successfully.
    """

    def __init__(self):
        self.model_manager = None
        self.pipe = None
        self.is_loaded = False

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _scaled_progress(progress_callback, start, end):
        """Return a callback that maps a 0..1 fraction into ``[start, end]``.

        The download methods report their own progress on a 0..1 scale; this
        wrapper lets ``load_models`` embed that into one slice of the overall
        progress bar instead of letting it jump to 100% and back.

        Returns ``None`` when no (truthy) callback was supplied.
        """
        if not progress_callback:
            return None
        span = end - start

        def report(fraction, *args, **kwargs):
            return progress_callback(start + span * fraction, *args, **kwargs)

        return report

    @staticmethod
    def _fetch_from_hub(repo_id, filename, local_dir):
        """Download ``filename`` from ``repo_id`` into ``local_dir``; return its path."""
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )

    @staticmethod
    def _fail(message):
        """Log ``message`` at ERROR level and hand it back as a return value."""
        logger.error(message)
        return message

    # ------------------------------------------------------------------
    # Downloads
    # ------------------------------------------------------------------
    def download_umt5_xxl_tokenizer(self, progress_callback=None):
        """Download UMT5-XXL tokenizer files from HuggingFace.

        Files already present under ``WAN21_LOCAL_DIR`` are not downloaded
        again.

        Args:
            progress_callback: optional callable invoked as
                ``progress_callback(fraction, desc=...)`` with a 0..1 fraction.

        Returns:
            List of local paths (strings), one per tokenizer file.

        Raises:
            Exception: whatever ``hf_hub_download`` raised, after logging it.
        """
        total_files = len(UMT5_XXL_TOKENIZER_FILES)
        downloaded_paths = []

        for i, file_path in enumerate(UMT5_XXL_TOKENIZER_FILES):
            filename = os.path.basename(file_path)
            full_local_path = f"{WAN21_LOCAL_DIR}/{file_path}"

            if progress_callback:
                progress_callback(
                    i / total_files,
                    desc=f"Checking tokenizer file {i+1}/{total_files}: {filename}",
                )

            # Skip files that are already on disk.
            if os.path.exists(full_local_path):
                logger.info(f"β Tokenizer file {filename} already exists at {full_local_path}")
                downloaded_paths.append(full_local_path)
                continue

            # Make sure the nested google/umt5-xxl directory exists.
            os.makedirs(f"{WAN21_LOCAL_DIR}/{os.path.dirname(file_path)}", exist_ok=True)

            logger.info(f"Downloading tokenizer file {filename} from {WAN21_REPO_ID}/{file_path}...")
            if progress_callback:
                progress_callback(
                    i / total_files,
                    desc=f"Downloading tokenizer file {i+1}/{total_files}: {filename}",
                )

            try:
                downloaded_path = self._fetch_from_hub(WAN21_REPO_ID, file_path, WAN21_LOCAL_DIR)
            except Exception as e:
                logger.error(f"β Error downloading tokenizer file {filename}: {e}")
                raise
            logger.info(f"β Successfully downloaded tokenizer file {filename} to {downloaded_path}!")
            downloaded_paths.append(downloaded_path)

        if progress_callback:
            progress_callback(1.0, desc="All tokenizer files downloaded successfully!")
        return downloaded_paths

    def download_wan21_models(self, progress_callback=None):
        """Download Wan2.1 model files from HuggingFace.

        Files already present under ``WAN21_LOCAL_DIR`` are not downloaded
        again.

        Args:
            progress_callback: optional callable invoked as
                ``progress_callback(fraction, desc=...)`` with a 0..1 fraction.

        Returns:
            List of local paths (strings), one per entry of ``WAN21_FILES``.

        Raises:
            Exception: whatever ``hf_hub_download`` raised, after logging it.
        """
        total_files = len(WAN21_FILES)
        downloaded_paths = []

        target_dir = Path(WAN21_LOCAL_DIR)
        target_dir.mkdir(parents=True, exist_ok=True)

        for i, filename in enumerate(WAN21_FILES):
            local_path = target_dir / filename

            if progress_callback:
                progress_callback(
                    i / total_files,
                    desc=f"Checking Wan2.1 file {i+1}/{total_files}: {filename}",
                )

            # Skip files that are already on disk.
            if local_path.exists():
                logger.info(f"β {filename} already exists at {local_path}")
                downloaded_paths.append(str(local_path))
                continue

            logger.info(f"Downloading {filename} from {WAN21_REPO_ID}...")
            if progress_callback:
                progress_callback(
                    i / total_files,
                    desc=f"Downloading Wan2.1 file {i+1}/{total_files}: {filename}",
                )

            try:
                downloaded_path = self._fetch_from_hub(WAN21_REPO_ID, filename, WAN21_LOCAL_DIR)
            except Exception as e:
                logger.error(f"β Error downloading {filename}: {e}")
                raise
            logger.info(f"β Successfully downloaded {filename} to {downloaded_path}!")
            downloaded_paths.append(downloaded_path)

        if progress_callback:
            progress_callback(1.0, desc="All Wan2.1 models downloaded successfully!")
        return downloaded_paths

    def download_recammaster_checkpoint(self, progress_callback=None):
        """Download ReCamMaster checkpoint from HuggingFace using huggingface_hub.

        Args:
            progress_callback: optional callable invoked as
                ``progress_callback(fraction, desc=...)`` with a 0..1 fraction.

        Returns:
            Path of the checkpoint on disk. NOTE: this is a ``pathlib.Path``
            when the file was already present and the value returned by
            ``hf_hub_download`` when it was freshly downloaded; callers should
            accept any path-like value.

        Raises:
            Exception: whatever ``hf_hub_download`` raised, after logging it.
        """
        checkpoint_path = Path(RECAMMASTER_LOCAL_DIR) / RECAMMASTER_CHECKPOINT_FILE

        # Nothing to do when the checkpoint is already on disk.
        if checkpoint_path.exists():
            logger.info(f"β ReCamMaster checkpoint already exists at {checkpoint_path}")
            return checkpoint_path

        Path(RECAMMASTER_LOCAL_DIR).mkdir(parents=True, exist_ok=True)

        logger.info("Downloading ReCamMaster checkpoint from HuggingFace...")
        logger.info(f"Repository: {RECAMMASTER_REPO_ID}")
        logger.info(f"File: {RECAMMASTER_CHECKPOINT_FILE}")
        logger.info(f"Destination: {checkpoint_path}")

        if progress_callback:
            progress_callback(0.0, desc="Downloading ReCamMaster checkpoint...")

        try:
            downloaded_path = self._fetch_from_hub(
                RECAMMASTER_REPO_ID, RECAMMASTER_CHECKPOINT_FILE, RECAMMASTER_LOCAL_DIR
            )
            logger.info(f"β Successfully downloaded ReCamMaster checkpoint to {downloaded_path}!")
            if progress_callback:
                progress_callback(1.0, desc="ReCamMaster checkpoint downloaded successfully!")
            return downloaded_path
        except Exception as e:
            logger.error(f"β Error downloading checkpoint: {e}")
            raise

    # ------------------------------------------------------------------
    # Tokenizer symlink
    # ------------------------------------------------------------------
    def create_symlink_for_tokenizer(self):
        """Create symlink for google/umt5-xxl to handle potential path issues.

        Links ``<MODELS_ROOT_DIR>/google/umt5-xxl`` to the tokenizer directory
        inside the Wan2.1 download. This is best-effort: any failure is logged
        as a warning and otherwise ignored.
        """
        google_dir = f"{MODELS_ROOT_DIR}/google"
        umt5_xxl_symlink = f"{google_dir}/umt5-xxl"
        umt5_xxl_source = f"{WAN21_LOCAL_DIR}/google/umt5-xxl"

        try:
            os.makedirs(google_dir, exist_ok=True)

            if not os.path.exists(umt5_xxl_symlink) and os.path.exists(umt5_xxl_source):
                # os.symlink raises on failure on every platform, so the
                # success message below is only logged for a real link.
                # target_is_directory only has an effect on Windows.
                os.symlink(umt5_xxl_source, umt5_xxl_symlink, target_is_directory=True)
                logger.info(f"Created symlink from {umt5_xxl_source} to {umt5_xxl_symlink}")
        except Exception as e:
            # Deliberately broad: this is a warning, not an error, as loading
            # is attempted anyway.
            logger.warning(f"Could not create symlink for google/umt5-xxl: {e}")

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def load_models(self, progress_callback=None):
        """Load the ReCamMaster models.

        Sets up the test data structure, downloads any missing files, builds
        the Wan2.1 pipeline on CUDA, attaches the per-block camera modules and
        loads the ReCamMaster checkpoint into the DiT.

        ``self.model_manager`` and ``self.pipe`` are only assigned once every
        step has succeeded, so a failed call never leaves a half-initialised
        pipeline on the instance.

        Args:
            progress_callback: optional callable invoked as
                ``progress_callback(fraction, desc=...)``; ``fraction`` is the
                overall progress in 0..1.

        Returns:
            A human-readable status string. Errors are reported through the
            returned string (prefixed with ``"Error"``), not raised.
        """
        if self.is_loaded:
            return "Models already loaded!"

        try:
            logger.info("Starting model loading...")

            # Imported lazily, as in the original code.
            from test_data import create_test_data_structure

            # 1) Test data structure.
            if progress_callback:
                progress_callback(0.05, desc="Setting up test data structure...")
            try:
                create_test_data_structure(progress_callback)
            except Exception as e:
                return self._fail(f"Error creating test data structure: {e}")

            # 2) ReCamMaster checkpoint (overall progress 0.1 -> 0.2).
            if progress_callback:
                progress_callback(0.1, desc="Checking for ReCamMaster checkpoint...")
            try:
                ckpt_path = self.download_recammaster_checkpoint(
                    self._scaled_progress(progress_callback, 0.1, 0.2)
                )
                logger.info(f"Using checkpoint at {ckpt_path}")
            except Exception as e:
                return self._fail(f"Error downloading ReCamMaster checkpoint: {e}")

            # 3) Wan2.1 weights (overall progress 0.2 -> 0.3).
            if progress_callback:
                progress_callback(0.2, desc="Checking for Wan2.1 models...")
            try:
                wan21_paths = self.download_wan21_models(
                    self._scaled_progress(progress_callback, 0.2, 0.3)
                )
                logger.info(f"Using Wan2.1 models: {wan21_paths}")
            except Exception as e:
                return self._fail(f"Error downloading Wan2.1 models: {e}")

            # 4) UMT5-XXL tokenizer files (overall progress 0.3 -> 0.4).
            if progress_callback:
                progress_callback(0.3, desc="Checking for UMT5-XXL tokenizer files...")
            try:
                tokenizer_paths = self.download_umt5_xxl_tokenizer(
                    self._scaled_progress(progress_callback, 0.3, 0.4)
                )
                logger.info(f"Using UMT5-XXL tokenizer files: {tokenizer_paths}")
            except Exception as e:
                return self._fail(f"Error downloading UMT5-XXL tokenizer files: {e}")

            # 5) Build the model manager and load the Wan2.1 weights on CPU.
            if progress_callback:
                progress_callback(0.4, desc="Loading model manager...")

            self.create_symlink_for_tokenizer()

            model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")

            if progress_callback:
                progress_callback(0.5, desc="Loading Wan2.1 models...")

            model_files = [f"{WAN21_LOCAL_DIR}/{filename}" for filename in WAN21_FILES]
            for model_file in model_files:
                logger.info(f"Loading model from: {model_file}")
                if not os.path.exists(model_file):
                    return self._fail(f"Error: Model file not found: {model_file}")

            # Set environment variable for transformers to find the tokenizer.
            os.environ["TRANSFORMERS_CACHE"] = MODELS_ROOT_DIR
            # Disable tokenizers parallelism warning.
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

            model_manager.load_models(model_files)

            # 6) Pipeline.
            if progress_callback:
                progress_callback(0.7, desc="Creating pipeline...")
            pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

            # 7) Extra modules introduced by ReCamMaster, added to every DiT
            #    block: a zero-initialised 12 -> dim linear ``cam_encoder`` and
            #    an identity-initialised dim -> dim ``projector``.
            if progress_callback:
                progress_callback(0.8, desc="Initializing ReCamMaster modules...")
            dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
            for block in pipe.dit.blocks:
                block.cam_encoder = nn.Linear(12, dim)
                block.projector = nn.Linear(dim, dim)
                block.cam_encoder.weight.data.zero_()
                block.cam_encoder.bias.data.zero_()
                block.projector.weight = nn.Parameter(torch.eye(dim))
                block.projector.bias = nn.Parameter(torch.zeros(dim))

            # 8) ReCamMaster checkpoint.
            if progress_callback:
                progress_callback(0.9, desc="Loading ReCamMaster checkpoint...")
            if not os.path.exists(ckpt_path):
                return self._fail(
                    f"Error: ReCamMaster checkpoint not found at {ckpt_path} even after download attempt."
                )

            # SECURITY NOTE(review): torch.load unpickles the downloaded file.
            # Consider weights_only=True once it is confirmed that the
            # checkpoint contains nothing but tensors.
            state_dict = torch.load(ckpt_path, map_location="cpu")
            pipe.dit.load_state_dict(state_dict, strict=True)

            pipe.to("cuda")
            pipe.to(dtype=torch.bfloat16)

            # Publish the fully-initialised objects only now.
            self.model_manager = model_manager
            self.pipe = pipe
            self.is_loaded = True

            if progress_callback:
                progress_callback(1.0, desc="Models loaded successfully!")
            logger.info("Models loaded successfully!")
            return "Models loaded successfully!"

        except Exception as e:
            # logger.exception keeps the traceback for debugging.
            logger.exception(f"Error loading models: {e}")
            return f"Error loading models: {e}"