# talkingAvater_bgk / model_manager.py
# Uploaded by oKen38461, commit 09ceafd.
# Commit message: モデルマネージャーに補助モデルの情報を追加し、`requirements.txt`にMediaPipeを新たに追加しました。また、NumPyのバージョン制約を設定しました。
import os
import shutil
import requests
from tqdm import tqdm
from pathlib import Path
import hashlib
import json
import time
class ModelManager:
    """Download and install the Ditto talking-head model files.

    Models are fetched from the official Hugging Face repository into a
    local cache directory, then either copied (plain files) or extracted
    (archives) into the ``checkpoints/`` tree relative to the current
    working directory.  Download progress is persisted so interrupted
    runs can resume.
    """

    def __init__(self, cache_dir="/tmp/models", use_pytorch=False):
        """Initialize the manager and build the model manifest.

        Args:
            cache_dir: Directory holding raw downloads and the progress
                file between runs.  Created if missing.
            use_pytorch: When True, use the PyTorch checkpoints plus
                auxiliary ONNX/MediaPipe models; otherwise use the
                TensorRT engine archive and its config pickle.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.use_pytorch = use_pytorch

        # Official Hugging Face repository hosting the model weights.
        base_url = "https://huggingface.co/digital-avatar/ditto-talkinghead/resolve/main"

        if use_pytorch:
            # PyTorch model manifest.
            self.model_configs = [
                {
                    "name": "appearance_extractor.pth",
                    "url": f"{base_url}/ditto_pytorch/models/appearance_extractor.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "appearance_extractor.pth",
                    "type": "file"
                },
                {
                    "name": "decoder.pth",
                    "url": f"{base_url}/ditto_pytorch/models/decoder.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "decoder.pth",
                    "type": "file"
                },
                {
                    "name": "lmdm_v0.4_hubert.pth",
                    "url": f"{base_url}/ditto_pytorch/models/lmdm_v0.4_hubert.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "lmdm_v0.4_hubert.pth",
                    "type": "file"
                },
                {
                    "name": "motion_extractor.pth",
                    "url": f"{base_url}/ditto_pytorch/models/motion_extractor.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "motion_extractor.pth",
                    "type": "file"
                },
                {
                    "name": "stitch_network.pth",
                    "url": f"{base_url}/ditto_pytorch/models/stitch_network.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "stitch_network.pth",
                    "type": "file"
                },
                {
                    "name": "warp_network.pth",
                    "url": f"{base_url}/ditto_pytorch/models/warp_network.pth",
                    "dest_dir": "checkpoints/ditto_pytorch/models",
                    "dest_file": "warp_network.pth",
                    "type": "file"
                },
                {
                    "name": "v0.4_hubert_cfg_pytorch.pkl",
                    "url": f"{base_url}/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl",
                    "dest_dir": "checkpoints/ditto_cfg",
                    "dest_file": "v0.4_hubert_cfg_pytorch.pkl",
                    "type": "file",
                    "size": "31 kB"
                },
                # Auxiliary models (aux_models)
                {
                    "name": "2d106det.onnx",
                    "url": f"{base_url}/ditto_pytorch/aux_models/2d106det.onnx",
                    "dest_dir": "checkpoints/ditto_pytorch/aux_models",
                    "dest_file": "2d106det.onnx",
                    "type": "file",
                    "size": "5.03 MB"
                },
                {
                    "name": "det_10g.onnx",
                    "url": f"{base_url}/ditto_pytorch/aux_models/det_10g.onnx",
                    "dest_dir": "checkpoints/ditto_pytorch/aux_models",
                    "dest_file": "det_10g.onnx",
                    "type": "file",
                    "size": "16.9 MB"
                },
                {
                    "name": "face_landmarker.task",
                    "url": f"{base_url}/ditto_pytorch/aux_models/face_landmarker.task",
                    "dest_dir": "checkpoints/ditto_pytorch/aux_models",
                    "dest_file": "face_landmarker.task",
                    "type": "file",
                    "size": "3.76 MB"
                },
                {
                    "name": "hubert_streaming_fix_kv.onnx",
                    "url": f"{base_url}/ditto_pytorch/aux_models/hubert_streaming_fix_kv.onnx",
                    "dest_dir": "checkpoints/ditto_pytorch/aux_models",
                    "dest_file": "hubert_streaming_fix_kv.onnx",
                    "type": "file",
                    "size": "1.46 GB"
                },
                {
                    "name": "landmark203.onnx",
                    "url": f"{base_url}/ditto_pytorch/aux_models/landmark203.onnx",
                    "dest_dir": "checkpoints/ditto_pytorch/aux_models",
                    "dest_file": "landmark203.onnx",
                    "type": "file",
                    "size": "115 MB"
                }
            ]
        else:
            # TensorRT model manifest.
            self.model_configs = [
                {
                    "name": "ditto_trt_models",
                    # URL can be overridden for GPUs other than Ampere+.
                    "url": os.environ.get("DITTO_TRT_URL", f"{base_url}/checkpoints/ditto_trt_Ampere_Plus.tar.gz"),
                    "dest_dir": "checkpoints",
                    "type": "archive",
                    "extract_subdir": "ditto_trt_Ampere_Plus"
                },
                {
                    "name": "v0.4_hubert_cfg_trt.pkl",
                    "url": f"{base_url}/ditto_cfg/v0.4_hubert_cfg_trt.pkl",
                    "dest_dir": "checkpoints/ditto_cfg",
                    "dest_file": "v0.4_hubert_cfg_trt.pkl",
                    "type": "file"
                }
            ]

        self.progress_file = self.cache_dir / "download_progress.json"
        self.download_progress = self.load_progress()

    def load_progress(self):
        """Load the persisted download-progress map.

        Returns an empty dict on first run or when the progress file is
        corrupted/unreadable (a stale partial write must not abort the
        whole setup — downloads will simply be re-verified).
        """
        if self.progress_file.exists():
            try:
                with open(self.progress_file, 'r') as f:
                    return json.load(f)
            except (json.JSONDecodeError, OSError):
                # Corrupt or unreadable progress file: start fresh.
                return {}
        return {}

    def save_progress(self):
        """Persist the download-progress map to the cache directory."""
        with open(self.progress_file, 'w') as f:
            json.dump(self.download_progress, f)

    def get_file_hash(self, filepath):
        """Return the SHA-256 hex digest of *filepath*, read in 4 KiB chunks."""
        sha256_hash = hashlib.sha256()
        with open(filepath, "rb") as f:
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def download_file(self, url, dest_path, retries=3):
        """Download *url* to *dest_path* with resume support.

        Resume state (file size / Range header / open mode) is recomputed
        on every attempt: a retry after a partial write must continue from
        the *new* end of file, not the offset captured before the loop,
        otherwise duplicate bytes get appended and the file is corrupted.

        Returns True on success; re-raises the last error after *retries*
        failed attempts.
        """
        dest_path = Path(dest_path)
        dest_path.parent.mkdir(parents=True, exist_ok=True)

        for attempt in range(retries):
            # Recompute resume position for this attempt.
            headers = {}
            mode = 'wb'
            resume_pos = 0
            if dest_path.exists():
                resume_pos = dest_path.stat().st_size
                if resume_pos > 0:
                    headers['Range'] = f'bytes={resume_pos}-'
                    mode = 'ab'
            try:
                response = requests.get(url, headers=headers, stream=True, timeout=30)
                if resume_pos > 0 and response.status_code != 206:
                    # Server ignored the Range request (plain 200): it is
                    # sending the whole file, so appending would corrupt
                    # the output. Restart from scratch.
                    resume_pos = 0
                    mode = 'wb'
                response.raise_for_status()

                total_size = int(response.headers.get('content-length', 0))
                if resume_pos > 0:
                    # content-length covers only the remaining range.
                    total_size += resume_pos

                with open(dest_path, mode) as f:
                    with tqdm(total=total_size, initial=resume_pos, unit='B', unit_scale=True, desc=dest_path.name) as pbar:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:
                                f.write(chunk)
                                pbar.update(len(chunk))
                return True
            except Exception as e:
                print(f"ダウンロードエラー (試行 {attempt + 1}/{retries}): {e}")
                if attempt < retries - 1:
                    time.sleep(5)  # Wait before retrying.
                else:
                    raise
        return False

    @staticmethod
    def _safe_extractall(tar, dest):
        """Extract a tarfile with the 'data' filter when available.

        The filter blocks path-traversal entries (absolute paths, '..').
        Older interpreters without the ``filter=`` parameter fall back to
        plain extraction — acceptable here since the archive comes from
        the official model repository.
        """
        try:
            tar.extractall(dest, filter="data")
        except TypeError:
            tar.extractall(dest)

    def extract_archive(self, archive_path, dest_dir, extract_subdir=None):
        """Extract a tar(.gz) or zip archive into *dest_dir*.

        Args:
            archive_path: Path to the downloaded archive.
            dest_dir: Directory the contents end up in.
            extract_subdir: If given (tar only), extract into a temporary
                directory and move only that subdirectory into place.

        Raises:
            ValueError: For unsupported archive extensions.
        """
        import tarfile
        import zipfile

        archive_path = Path(archive_path)
        dest_dir = Path(dest_dir)
        temp_dir = dest_dir / "temp_extract"

        try:
            is_tar = (archive_path.suffix in ('.gz', '.tar')
                      or str(archive_path).endswith('.tar.gz'))
            if is_tar:
                with tarfile.open(archive_path, 'r:*') as tar:
                    if extract_subdir:
                        # Extract into a temp directory, then move just
                        # the wanted subdirectory into place.
                        temp_dir.mkdir(exist_ok=True)
                        self._safe_extractall(tar, temp_dir)
                        src_dir = temp_dir / extract_subdir
                        if src_dir.exists():
                            shutil.move(str(src_dir), str(dest_dir / extract_subdir))
                        # Always clean the temp directory, even when the
                        # expected subdirectory was absent.
                        shutil.rmtree(temp_dir, ignore_errors=True)
                    else:
                        self._safe_extractall(tar, dest_dir)
            elif archive_path.suffix == '.zip':
                with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                    zip_ref.extractall(dest_dir)
            else:
                raise ValueError(f"Unsupported archive format: {archive_path.suffix}")
        except Exception as e:
            if temp_dir.exists():
                shutil.rmtree(temp_dir)
            raise e

    def check_models_exist(self):
        """Return the manifest entries whose installed files are missing.

        NOTE(review): destinations are resolved relative to the current
        working directory, so this class assumes it runs from the project
        root — confirm against callers.
        """
        missing_models = []
        for config in self.model_configs:
            if config['type'] == 'file':
                dest_path = Path(config['dest_dir']) / config['dest_file']
                if not dest_path.exists():
                    missing_models.append(config)
            else:  # archive: directory must exist and be non-empty
                dest_dir = Path(config['dest_dir'])
                if not dest_dir.exists() or not any(dest_dir.iterdir()):
                    missing_models.append(config)
        return missing_models

    def download_models(self):
        """Download and install every missing model.

        Returns True when all models are in place, False on the first
        failure (remaining models are not attempted).
        """
        missing_models = self.check_models_exist()

        if not missing_models:
            print("すべてのモデルが既に存在します。")
            return True

        print(f"{len(missing_models)}個のモデルをダウンロードします...")

        for config in missing_models:
            size_info = config.get('size', '不明')
            print(f"\n{config['name']} をダウンロード中... (サイズ: {size_info})")

            # Raw download lives in the cache under a .download name.
            cache_filename = f"{config['name']}.download"
            cache_path = self.cache_dir / cache_filename

            try:
                # Download unless a completed copy is already cached.
                if not cache_path.exists() or self.download_progress.get(config['name'], {}).get('status') != 'completed':
                    self.download_file(config['url'], cache_path)
                    self.download_progress[config['name']] = {'status': 'completed'}
                    self.save_progress()

                # Install: copy plain files, extract archives.
                if config['type'] == 'file':
                    dest_dir = Path(config['dest_dir'])
                    dest_dir.mkdir(parents=True, exist_ok=True)
                    dest_path = dest_dir / config['dest_file']
                    shutil.copy2(cache_path, dest_path)
                else:  # archive
                    dest_dir = Path(config['dest_dir'])
                    dest_dir.mkdir(parents=True, exist_ok=True)
                    print(f"{config['name']} を展開中...")
                    extract_subdir = config.get('extract_subdir')
                    self.extract_archive(cache_path, dest_dir, extract_subdir)

                print(f"{config['name']} のセットアップ完了")
            except Exception as e:
                print(f"エラー: {config['name']} のダウンロード中にエラーが発生しました: {e}")
                return False

        return True

    def setup_models(self):
        """Run the full model setup and report the outcome.

        Returns:
            bool: True when every model was installed successfully.
        """
        print("=== DittoTalkingHead モデルセットアップ ===")
        print(f"キャッシュディレクトリ: {self.cache_dir}")

        success = self.download_models()

        if success:
            print("\n✅ すべてのモデルのセットアップが完了しました!")
        else:
            print("\n❌ モデルのセットアップ中にエラーが発生しました。")

        return success
if __name__ == "__main__":
    # Script entry point: run the full model setup with default settings.
    ModelManager().setup_models()