import os
import shutil
import requests
from tqdm import tqdm
from pathlib import Path
import hashlib
import json
import time
import tarfile
import zipfile


def _file_config(base_url, url_dir, dest_dir, name, size=None):
    """Return a ``type: "file"`` download entry for ``{base_url}/{url_dir}/{name}``.

    The file is installed as ``{dest_dir}/{name}``. ``size`` is a purely
    informational, human-readable string printed while downloading; the key is
    only added to the entry when a size is given.
    """
    config = {
        "name": name,
        "url": f"{base_url}/{url_dir}/{name}",
        "dest_dir": dest_dir,
        "dest_file": name,
        "type": "file",
    }
    if size is not None:
        config["size"] = size
    return config


class ModelManager:
    """Download and install the DittoTalkingHead checkpoints.

    Downloads are staged in ``cache_dir`` (resumable, with completion tracked
    in ``download_progress.json``) and then copied or extracted into paths
    under ``checkpoints/``, which are relative to the current working directory.
    """

    def __init__(self, cache_dir="/tmp/models", use_pytorch=False):
        """
        Args:
            cache_dir: Directory for staged downloads and the progress file.
            use_pytorch: If True, fetch the PyTorch checkpoints; otherwise fetch
                the TensorRT archive (its URL can be overridden with the
                ``DITTO_TRT_URL`` environment variable).
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.use_pytorch = use_pytorch

        # Models are fetched from the official Hugging Face repository.
        base_url = "https://huggingface.co/digital-avatar/ditto-talkinghead/resolve/main"
        cfg_dest = "checkpoints/ditto_cfg"

        if use_pytorch:
            models = "ditto_pytorch/models"
            aux = "ditto_pytorch/aux_models"
            models_dest = f"checkpoints/{models}"
            aux_dest = f"checkpoints/{aux}"
            self.model_configs = [
                # PyTorch networks
                _file_config(base_url, models, models_dest, "appearance_extractor.pth"),
                _file_config(base_url, models, models_dest, "decoder.pth"),
                _file_config(base_url, models, models_dest, "lmdm_v0.4_hubert.pth"),
                _file_config(base_url, models, models_dest, "motion_extractor.pth"),
                _file_config(base_url, models, models_dest, "stitch_network.pth"),
                _file_config(base_url, models, models_dest, "warp_network.pth"),
                # Inference configuration
                _file_config(base_url, "ditto_cfg", cfg_dest,
                             "v0.4_hubert_cfg_pytorch.pkl", size="31 kB"),
                # Auxiliary models (aux_models)
                _file_config(base_url, aux, aux_dest, "2d106det.onnx", size="5.03 MB"),
                _file_config(base_url, aux, aux_dest, "det_10g.onnx", size="16.9 MB"),
                _file_config(base_url, aux, aux_dest, "face_landmarker.task", size="3.76 MB"),
                _file_config(base_url, aux, aux_dest,
                             "hubert_streaming_fix_kv.onnx", size="1.46 GB"),
                _file_config(base_url, aux, aux_dest, "landmark203.onnx", size="115 MB"),
            ]
        else:
            # TensorRT engines ship as one archive; only its extract_subdir is installed.
            self.model_configs = [
                {
                    "name": "ditto_trt_models",
                    "url": os.environ.get(
                        "DITTO_TRT_URL",
                        f"{base_url}/checkpoints/ditto_trt_Ampere_Plus.tar.gz",
                    ),
                    "dest_dir": "checkpoints",
                    "type": "archive",
                    "extract_subdir": "ditto_trt_Ampere_Plus",
                },
                _file_config(base_url, "ditto_cfg", cfg_dest, "v0.4_hubert_cfg_trt.pkl"),
            ]

        self.progress_file = self.cache_dir / "download_progress.json"
        self.download_progress = self.load_progress()

    def load_progress(self):
        """Load the per-model download status dict.

        A missing file yields ``{}``. So does an unreadable or corrupt one, with
        a warning, because the file is only a cache of progress and must not
        block setup.
        """
        if not self.progress_file.exists():
            return {}
        try:
            with open(self.progress_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"警告: 進捗ファイルを読み込めませんでした: {e}")
            return {}
        return data if isinstance(data, dict) else {}

    def save_progress(self):
        """Persist the download status dict atomically (temp file + ``os.replace``)."""
        tmp_path = self.progress_file.with_name(self.progress_file.name + '.tmp')
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(self.download_progress, f)
        os.replace(tmp_path, self.progress_file)

    def get_file_hash(self, filepath):
        """Return the hex SHA-256 digest of ``filepath``, read in 1 MiB chunks."""
        sha256_hash = hashlib.sha256()
        with open(filepath, "rb") as f:
            for block in iter(lambda: f.read(1 << 20), b""):
                sha256_hash.update(block)
        return sha256_hash.hexdigest()

    def download_file(self, url, dest_path, retries=3):
        """Download ``url`` to ``dest_path``, resuming a partial file if present.

        Every attempt re-reads the current size of ``dest_path``. Bytes written
        by a failed attempt are therefore resumed from, never duplicated.

        Args:
            url: Source URL.
            dest_path: Target file; parent directories are created.
            retries: Number of attempts, with a 5 second pause between them.

        Returns:
            True on success (False only when ``retries`` is less than 1).

        Raises:
            Exception: Whatever the last failed attempt raised.
        """
        dest_path = Path(dest_path)
        dest_path.parent.mkdir(parents=True, exist_ok=True)

        for attempt in range(retries):
            try:
                self._download_once(url, dest_path)
                return True
            except Exception as e:
                print(f"ダウンロードエラー (試行 {attempt + 1}/{retries}): {e}")
                if attempt < retries - 1:
                    time.sleep(5)  # back off before retrying
                else:
                    raise
        return False

    def _download_once(self, url, dest_path):
        """Perform one (possibly ranged) GET into ``dest_path``; raise on any failure."""
        resume_pos = dest_path.stat().st_size if dest_path.exists() else 0
        headers = {'Range': f'bytes={resume_pos}-'} if resume_pos > 0 else {}

        with requests.get(url, headers=headers, stream=True, timeout=30) as response:
            if resume_pos > 0 and response.status_code == 416:
                # The requested range starts at/after the remote EOF. Matching
                # sizes mean the local file is already complete; otherwise the
                # local copy is inconsistent, so discard it and start over.
                total = self._parse_content_range_total(response.headers.get('Content-Range'))
                if total == resume_pos:
                    return
                dest_path.unlink()
                return self._download_once(url, dest_path)

            response.raise_for_status()
            if resume_pos > 0 and response.status_code != 206:
                # The server ignored the Range header and is sending the whole
                # file; appending it would corrupt the download, so overwrite.
                resume_pos = 0

            length = response.headers.get('content-length')
            expected = int(length) + resume_pos if length else None
            # With a Content-Encoding, decoded bytes need not match Content-Length.
            check_size = (expected is not None
                          and response.headers.get('content-encoding', 'identity') == 'identity')

            with open(dest_path, 'ab' if resume_pos > 0 else 'wb') as f:
                with tqdm(total=expected, initial=resume_pos, unit='B',
                          unit_scale=True, desc=dest_path.name) as pbar:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                            pbar.update(len(chunk))

        if check_size:
            actual = dest_path.stat().st_size
            if actual != expected:
                # A silently truncated stream; the retry loop resumes from here.
                raise IOError(f"ダウンロードが不完全です: {actual}/{expected} バイト")

    @staticmethod
    def _parse_content_range_total(content_range):
        """Return the complete length from a ``Content-Range`` header (e.g. ``bytes */1234``), or None."""
        if not content_range or '/' not in content_range:
            return None
        total = content_range.rsplit('/', 1)[1].strip()
        return int(total) if total.isdigit() else None

    def extract_archive(self, archive_path, dest_dir, extract_subdir=None):
        """Extract a tar (optionally compressed) or zip archive into ``dest_dir``.

        The format is detected from the file contents, not its name. Staged
        downloads are named ``<name>.download``, so a suffix check would reject
        every archive.

        Args:
            archive_path: Archive file to extract.
            dest_dir: Destination directory.
            extract_subdir: If given, only this top-level directory of the
                archive is installed, as ``dest_dir / extract_subdir``. The
                archive is first unpacked into ``dest_dir / "temp_extract"``,
                which is removed afterwards.

        Raises:
            ValueError: The file is neither tar nor zip, or (on interpreters
                without tar extraction filters) a member path escapes the
                destination.
            FileNotFoundError: ``extract_subdir`` is not present in the archive.
        """
        archive_path = Path(archive_path)
        dest_dir = Path(dest_dir)

        if tarfile.is_tarfile(archive_path):
            extract = self._extract_tar
        elif zipfile.is_zipfile(archive_path):
            extract = self._extract_zip
        else:
            raise ValueError(f"Unsupported archive format: {archive_path.suffix}")

        if not extract_subdir:
            extract(archive_path, dest_dir)
            return

        temp_dir = dest_dir / "temp_extract"
        if temp_dir.exists():
            shutil.rmtree(temp_dir)  # leftovers from an interrupted run
        temp_dir.mkdir(parents=True)
        try:
            extract(archive_path, temp_dir)
            src_dir = temp_dir / extract_subdir
            if not src_dir.is_dir():
                # Fail loudly instead of reporting a setup that installed nothing.
                raise FileNotFoundError(
                    f"'{extract_subdir}' not found in archive {archive_path}")
            target = dest_dir / extract_subdir
            if target.exists():
                # shutil.move would nest src inside target; merge instead.
                shutil.copytree(src_dir, target, dirs_exist_ok=True)
            else:
                shutil.move(str(src_dir), str(target))
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

    @staticmethod
    def _extract_tar(archive_path, target_dir):
        """Extract a tarball, refusing members that would land outside ``target_dir``."""
        with tarfile.open(archive_path, 'r:*') as tar:
            if hasattr(tarfile, 'data_filter'):
                # Extraction filters exist on 3.12+ and recent security releases.
                tar.extractall(target_dir, filter='data')
                return
            # Older interpreters: at least reject absolute and '..' member paths.
            root = Path(target_dir).resolve()
            for member in tar.getmembers():
                member_path = (root / member.name).resolve()
                if member_path != root and root not in member_path.parents:
                    raise ValueError(f"Unsafe path in archive: {member.name}")
            tar.extractall(target_dir)

    @staticmethod
    def _extract_zip(archive_path, target_dir):
        """Extract a zip file (``ZipFile.extractall`` sanitises member paths itself)."""
        with zipfile.ZipFile(archive_path, 'r') as zip_ref:
            zip_ref.extractall(target_dir)

    def check_models_exist(self):
        """Return the config entries whose installed output is missing.

        A ``file`` entry is missing when ``dest_dir/dest_file`` does not exist.
        An ``archive`` entry is missing when its install directory is absent or
        empty. The install directory is ``dest_dir/extract_subdir`` if a subdir
        is configured, otherwise ``dest_dir``.
        """
        missing_models = []
        for config in self.model_configs:
            if config['type'] == 'file':
                present = (Path(config['dest_dir']) / config['dest_file']).exists()
            else:  # archive
                target = Path(config['dest_dir'])
                if config.get('extract_subdir'):
                    # dest_dir alone (e.g. "checkpoints") is shared with other
                    # models, so it being non-empty proves nothing.
                    target = target / config['extract_subdir']
                present = target.is_dir() and any(target.iterdir())
            if not present:
                missing_models.append(config)
        return missing_models

    @staticmethod
    def _install_file(src, dest):
        """Copy ``src`` to ``dest`` via a temp file + ``os.replace``.

        A crash mid-copy therefore never leaves a truncated model that
        ``check_models_exist`` would accept.
        """
        tmp_path = dest.with_name(dest.name + '.partial')
        shutil.copy2(src, tmp_path)
        os.replace(tmp_path, dest)

    def download_models(self):
        """Download and install every missing model.

        Returns:
            True if every model is installed afterwards. False on the first
            failure; the error is printed, not raised.
        """
        missing_models = self.check_models_exist()
        if not missing_models:
            print("すべてのモデルが既に存在します。")
            return True

        print(f"{len(missing_models)}個のモデルをダウンロードします...")

        for config in missing_models:
            size_info = config.get('size', '不明')
            print(f"\n{config['name']} をダウンロード中... (サイズ: {size_info})")

            # Staging file in the cache directory.
            cache_path = self.cache_dir / f"{config['name']}.download"

            try:
                status = self.download_progress.get(config['name'], {}).get('status')
                if not cache_path.exists() or status != 'completed':
                    # Drop any stale "completed" flag before downloading, so an
                    # interrupted re-download is never trusted on the next run.
                    if self.download_progress.pop(config['name'], None) is not None:
                        self.save_progress()
                    self.download_file(config['url'], cache_path)
                    self.download_progress[config['name']] = {'status': 'completed'}
                    self.save_progress()

                dest_dir = Path(config['dest_dir'])
                dest_dir.mkdir(parents=True, exist_ok=True)
                if config['type'] == 'file':
                    self._install_file(cache_path, dest_dir / config['dest_file'])
                else:  # archive
                    print(f"{config['name']} を展開中...")
                    self.extract_archive(cache_path, dest_dir, config.get('extract_subdir'))

                print(f"{config['name']} のセットアップ完了")
            except Exception as e:
                print(f"エラー: {config['name']} のダウンロード中にエラーが発生しました: {e}")
                return False

        return True

    def setup_models(self):
        """Entry point: report the cache location, install models, print the outcome.

        Returns:
            The boolean result of ``download_models``.
        """
        print("=== DittoTalkingHead モデルセットアップ ===")
        print(f"キャッシュディレクトリ: {self.cache_dir}")

        success = self.download_models()
        if success:
            print("\n✅ すべてのモデルのセットアップが完了しました!")
        else:
            print("\n❌ モデルのセットアップ中にエラーが発生しました。")
        return success


if __name__ == "__main__":
    # Manual test run with the default settings (TensorRT models).
    manager = ModelManager()
    manager.setup_models()