Spaces:
Runtime error
Runtime error
import os | |
import shutil | |
import requests | |
from tqdm import tqdm | |
from pathlib import Path | |
import hashlib | |
import json | |
import time | |
class ModelManager:
    """Download and install the DittoTalkingHead model checkpoints.

    Each model is first downloaded into ``cache_dir`` (resuming partial
    downloads via HTTP Range requests) and then copied or extracted into its
    ``dest_dir``. The ``checkpoints/...`` destinations are relative paths, so
    they resolve against the current working directory. Per-model completion
    state is persisted in ``<cache_dir>/download_progress.json`` so that an
    interrupted run can pick up where it left off.

    Attributes:
        cache_dir: Directory holding ``<name>.download`` cache files.
        use_pytorch: True selects the PyTorch model set, False the TensorRT one.
        model_configs: List of config dicts with keys ``name``, ``url``,
            ``dest_dir``, ``type`` ("file" or "archive"), plus ``dest_file``
            (files), ``extract_subdir`` (archives) and an optional
            human-readable ``size``.
        progress_file: Path of the JSON progress file.
        download_progress: Mapping of model name to ``{"status": ...}``.
    """

    # PyTorch model set: (repo sub-directory, file name, displayed size or None).
    # Each file is fetched from "<base_url>/<subdir>/<name>" and installed into
    # "checkpoints/<subdir>/<name>".
    _PYTORCH_FILES = (
        ("ditto_pytorch/models", "appearance_extractor.pth", None),
        ("ditto_pytorch/models", "decoder.pth", None),
        ("ditto_pytorch/models", "lmdm_v0.4_hubert.pth", None),
        ("ditto_pytorch/models", "motion_extractor.pth", None),
        ("ditto_pytorch/models", "stitch_network.pth", None),
        ("ditto_pytorch/models", "warp_network.pth", None),
        ("ditto_cfg", "v0.4_hubert_cfg_pytorch.pkl", "31 kB"),
        # Auxiliary models (aux_models)
        ("ditto_pytorch/aux_models", "2d106det.onnx", "5.03 MB"),
        ("ditto_pytorch/aux_models", "det_10g.onnx", "16.9 MB"),
        ("ditto_pytorch/aux_models", "face_landmarker.task", "3.76 MB"),
        ("ditto_pytorch/aux_models", "hubert_streaming_fix_kv.onnx", "1.46 GB"),
        ("ditto_pytorch/aux_models", "landmark203.onnx", "115 MB"),
    )

    def __init__(self, cache_dir="/tmp/models", use_pytorch=False):
        """Create the cache directory, build the model list and load progress.

        Args:
            cache_dir: Where downloads are cached (created if missing).
            use_pytorch: Select the PyTorch models instead of TensorRT ones.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.use_pytorch = use_pytorch
        # Models are fetched from the official Hugging Face repository.
        base_url = "https://huggingface.co/digital-avatar/ditto-talkinghead/resolve/main"
        if use_pytorch:
            self.model_configs = [
                self._file_config(base_url, subdir, filename, size)
                for subdir, filename, size in self._PYTORCH_FILES
            ]
        else:
            # TensorRT: one archive plus its config file.
            # NOTE(review): confirm the default archive URL exists in the
            # repository; it can be overridden with DITTO_TRT_URL.
            self.model_configs = [
                {
                    "name": "ditto_trt_models",
                    "url": os.environ.get(
                        "DITTO_TRT_URL",
                        f"{base_url}/checkpoints/ditto_trt_Ampere_Plus.tar.gz",
                    ),
                    "dest_dir": "checkpoints",
                    "type": "archive",
                    "extract_subdir": "ditto_trt_Ampere_Plus",
                },
                self._file_config(base_url, "ditto_cfg", "v0.4_hubert_cfg_trt.pkl"),
            ]
        self.progress_file = self.cache_dir / "download_progress.json"
        self.download_progress = self.load_progress()

    @staticmethod
    def _file_config(base_url, subdir, filename, size=None):
        """Build the config dict for a single-file model under ``subdir``."""
        config = {
            "name": filename,
            "url": f"{base_url}/{subdir}/{filename}",
            "dest_dir": f"checkpoints/{subdir}",
            "dest_file": filename,
            "type": "file",
        }
        if size is not None:
            config["size"] = size
        return config

    def load_progress(self):
        """Load the persisted download progress.

        Returns:
            The stored dict (model name -> state dict), or ``{}`` when the
            progress file is missing, unreadable or not a JSON object.
        """
        try:
            with open(self.progress_file, "r", encoding="utf-8") as f:
                progress = json.load(f)
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as e:
            # A corrupt file (e.g. a crash mid-write) must not make the
            # manager unusable; starting from scratch only costs re-checks.
            print(f"警告: 進捗ファイル {self.progress_file} を読み込めなかったため無視します: {e}")
            return {}
        return progress if isinstance(progress, dict) else {}

    def save_progress(self):
        """Persist ``self.download_progress`` to the progress file.

        The JSON is written to a sibling temporary file and swapped in with
        ``os.replace`` so an interruption never leaves a half-written file.
        """
        tmp_path = self.progress_file.with_name(self.progress_file.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.download_progress, f)
        os.replace(tmp_path, self.progress_file)

    def get_file_hash(self, filepath, chunk_size=1024 * 1024):
        """Return the hex SHA-256 digest of the file at ``filepath``.

        ``chunk_size`` only bounds memory use; it does not affect the digest.
        """
        sha256_hash = hashlib.sha256()
        with open(filepath, "rb") as f:
            for byte_block in iter(lambda: f.read(chunk_size), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def download_file(self, url, dest_path, retries=3):
        """Download ``url`` to ``dest_path``, resuming a partial file if present.

        Args:
            url: Source URL.
            dest_path: Destination file; parent directories are created.
            retries: Number of attempts; waits 5 seconds between attempts.

        Returns:
            True on success; False only when ``retries`` < 1.

        Raises:
            The exception of the last attempt when every attempt fails.
        """
        dest_path = Path(dest_path)
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        for attempt in range(retries):
            try:
                self._download_once(url, dest_path)
                return True
            except Exception as e:
                print(f"ダウンロードエラー (試行 {attempt + 1}/{retries}): {e}")
                if attempt < retries - 1:
                    time.sleep(5)  # wait before retrying
                else:
                    raise
        return False

    @staticmethod
    def _download_once(url, dest_path):
        """Perform one (possibly resumed) download attempt into ``dest_path``."""
        # Recomputed on every attempt: a failed attempt may already have
        # appended bytes, so reusing an earlier offset would duplicate data.
        resume_pos = dest_path.stat().st_size if dest_path.exists() else 0
        headers = {"Range": f"bytes={resume_pos}-"} if resume_pos > 0 else {}
        with requests.get(url, headers=headers, stream=True, timeout=30) as response:
            if resume_pos > 0 and response.status_code == 416:
                # The requested range starts at/after the remote end. If the
                # server reports a total size equal to our local size, the
                # file is already complete.
                if response.headers.get("Content-Range", "").strip() == f"bytes */{resume_pos}":
                    return
                # Otherwise the local file does not match the remote one;
                # discard it so the next attempt starts from byte 0.
                dest_path.unlink()
            response.raise_for_status()
            if resume_pos > 0 and response.status_code != 206:
                # The server ignored the Range header and sent the whole file;
                # appending it would corrupt the download, so start over.
                resume_pos = 0
            mode = "ab" if resume_pos > 0 else "wb"
            content_length = response.headers.get("content-length")
            total_size = int(content_length) + resume_pos if content_length else None
            with open(dest_path, mode) as f, tqdm(
                total=total_size,
                initial=resume_pos,
                unit="B",
                unit_scale=True,
                desc=dest_path.name,
            ) as pbar:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))

    def extract_archive(self, archive_path, dest_dir, extract_subdir=None):
        """Extract a tar (any compression) or zip archive into ``dest_dir``.

        The format is detected from the file contents, not the extension:
        ``download_models`` caches archives as ``<name>.download``.

        Args:
            archive_path: Path of the archive.
            dest_dir: Directory to extract into.
            extract_subdir: If given, only this top-level directory of the
                archive is installed, as ``dest_dir / extract_subdir``,
                replacing any existing copy.

        Raises:
            ValueError: The format is unsupported, or ``extract_subdir`` is
                not present in the archive.
        """
        import tarfile
        import zipfile

        archive_path = Path(archive_path)
        dest_dir = Path(dest_dir)

        if tarfile.is_tarfile(str(archive_path)):
            kind = "tar"
        elif zipfile.is_zipfile(str(archive_path)):
            kind = "zip"
        else:
            raise ValueError(f"Unsupported archive format: {archive_path}")

        def extract_to(target):
            if kind == "tar":
                with tarfile.open(archive_path, "r:*") as tar:
                    if hasattr(tarfile, "data_filter"):
                        # The "data" filter rejects absolute paths and members
                        # or links that would land outside ``target``.
                        tar.extractall(target, filter="data")
                    else:
                        tar.extractall(target)
            else:
                with zipfile.ZipFile(archive_path, "r") as zip_ref:
                    zip_ref.extractall(target)

        if not extract_subdir:
            extract_to(dest_dir)
            return

        # Extract into a scratch directory, then move only the wanted subtree.
        temp_dir = dest_dir / "temp_extract"
        if temp_dir.exists():
            shutil.rmtree(temp_dir)  # leftovers from an interrupted run
        temp_dir.mkdir(parents=True)
        try:
            extract_to(temp_dir)
            src_dir = temp_dir / extract_subdir
            if not src_dir.is_dir():
                raise ValueError(
                    f"Subdirectory '{extract_subdir}' not found in archive: {archive_path}"
                )
            target_dir = dest_dir / extract_subdir
            # shutil.move would nest src_dir *inside* an existing directory,
            # so remove any previous (possibly partial) install first.
            if target_dir.is_dir():
                shutil.rmtree(target_dir)
            elif target_dir.exists():
                target_dir.unlink()
            shutil.move(str(src_dir), str(target_dir))
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

    def check_models_exist(self):
        """Return the configs whose installed files/directories are missing.

        A file model counts as present when ``dest_dir/dest_file`` exists. An
        archive model counts as present when the directory it installs
        (``dest_dir/extract_subdir``, or ``dest_dir`` if no subdir is
        configured) exists and is non-empty.
        """
        missing_models = []
        for config in self.model_configs:
            dest_dir = Path(config["dest_dir"])
            if config["type"] == "file":
                if not (dest_dir / config["dest_file"]).exists():
                    missing_models.append(config)
                continue
            # Archive: inspect the subdir it installs, not just the shared
            # parent (which other models such as ditto_cfg also populate).
            subdir = config.get("extract_subdir")
            target = dest_dir / subdir if subdir else dest_dir
            if not target.is_dir() or not any(target.iterdir()):
                missing_models.append(config)
        return missing_models

    def download_models(self):
        """Download and install every missing model.

        Returns:
            True when all models are present afterwards; False as soon as one
            model fails (the error is printed, not raised).
        """
        missing_models = self.check_models_exist()
        if not missing_models:
            print("すべてのモデルが既に存在します。")
            return True
        print(f"{len(missing_models)}個のモデルをダウンロードします...")
        for config in missing_models:
            size_info = config.get("size", "不明")
            print(f"\n{config['name']} をダウンロード中... (サイズ: {size_info})")
            cache_path = self.cache_dir / f"{config['name']}.download"
            try:
                self._fetch_to_cache(config, cache_path)
                self._install(config, cache_path)
                print(f"{config['name']} のセットアップ完了")
            except Exception as e:
                print(f"エラー: {config['name']} のダウンロード中にエラーが発生しました: {e}")
                return False
        return True

    def _fetch_to_cache(self, config, cache_path):
        """Download ``config['url']`` into ``cache_path`` unless already complete."""
        name = config["name"]
        entry = self.download_progress.get(name)
        completed = isinstance(entry, dict) and entry.get("status") == "completed"
        if cache_path.exists() and completed:
            return
        # Clear any stale "completed" flag first: if this download is
        # interrupted, the partial cache file must not be mistaken for a
        # finished one on the next run.
        self.download_progress[name] = {"status": "downloading"}
        self.save_progress()
        self.download_file(config["url"], cache_path)
        self.download_progress[name] = {"status": "completed"}
        self.save_progress()

    def _install(self, config, cache_path):
        """Copy (file models) or extract (archive models) a cached download."""
        dest_dir = Path(config["dest_dir"])
        dest_dir.mkdir(parents=True, exist_ok=True)
        if config["type"] == "file":
            dest_path = dest_dir / config["dest_file"]
            # Copy under a temporary name and rename atomically, so an
            # interrupted copy never leaves a truncated file that
            # check_models_exist would accept as installed.
            tmp_path = dest_path.with_name(dest_path.name + ".part")
            shutil.copy2(cache_path, tmp_path)
            os.replace(tmp_path, dest_path)
        else:  # archive
            print(f"{config['name']} を展開中...")
            self.extract_archive(cache_path, dest_dir, config.get("extract_subdir"))

    def setup_models(self):
        """Run the full model setup and print a summary.

        Returns:
            The boolean result of ``download_models``.
        """
        print("=== DittoTalkingHead モデルセットアップ ===")
        print(f"キャッシュディレクトリ: {self.cache_dir}")
        success = self.download_models()
        if success:
            print("\n✅ すべてのモデルのセットアップが完了しました!")
        else:
            print("\n❌ モデルのセットアップ中にエラーが発生しました。")
        return success
if __name__ == "__main__":
    # Manual run: set up the default (TensorRT) models and report the outcome
    # through the process exit status (0 = success, 1 = failure), so that
    # shell scripts and startup hooks can detect a failed setup.
    import sys

    manager = ModelManager()
    sys.exit(0 if manager.setup_models() else 1)