""" 自动下载MagicArticulate和Michelangelo所需的模型文件 在HF Space启动时调用 """ import os import logging from pathlib import Path logger = logging.getLogger(__name__) def download_models(): """下载所有必需的模型文件""" try: from huggingface_hub import hf_hub_download logger.info("🔄 开始下载模型文件...") # 1. 下载Michelangelo模型 michelangelo_path = "third_party/Michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt" if not os.path.exists(michelangelo_path): logger.info("📥 下载Michelangelo模型...") try: file_path = hf_hub_download( repo_id="Maikou/Michelangelo", filename="checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="third_party/Michelangelo" ) logger.info(f"✅ Michelangelo模型下载完成: {file_path}") except Exception as e: logger.error(f"❌ Michelangelo模型下载失败: {e}") else: logger.info("✅ Michelangelo模型已存在") # 2. 下载MagicArticulate层次模型 hier_path = "skeleton_ckpt/checkpoint_trainonv2_hier.pth" if not os.path.exists(hier_path): logger.info("📥 下载MagicArticulate层次模型...") try: os.makedirs("skeleton_ckpt", exist_ok=True) file_path = hf_hub_download( repo_id="Seed3D/MagicArticulate", filename="skeleton_ckpt/checkpoint_trainonv2_hier.pth", local_dir="" ) logger.info(f"✅ MagicArticulate层次模型下载完成: {file_path}") except Exception as e: logger.error(f"❌ MagicArticulate层次模型下载失败: {e}") else: logger.info("✅ MagicArticulate层次模型已存在") # 3. 下载MagicArticulate空间模型 spatial_path = "skeleton_ckpt/checkpoint_trainonv2_spatial.pth" if not os.path.exists(spatial_path): logger.info("📥 下载MagicArticulate空间模型...") try: os.makedirs("skeleton_ckpt", exist_ok=True) file_path = hf_hub_download( repo_id="Seed3D/MagicArticulate", filename="skeleton_ckpt/checkpoint_trainonv2_spatial.pth", local_dir="" ) logger.info(f"✅ MagicArticulate空间模型下载完成: {file_path}") except Exception as e: logger.error(f"❌ MagicArticulate空间模型下载失败: {e}") else: logger.info("✅ MagicArticulate空间模型已存在") logger.info("🎯 模型下载过程完成") return True except ImportError: logger.error("❌ huggingface_hub未安装,无法下载模型") return False except Exception as e: logger.error(f"💥 模型下载过程出错: {e}") return False if __name__ == "__main__": download_models()