Spaces:
Runtime error
Runtime error
""" | |
Optimized DittoTalkingHead App with Phase 3 Performance Improvements | |
""" | |
import gradio as gr | |
import os | |
import tempfile | |
import shutil | |
from pathlib import Path | |
import torch | |
import time | |
from typing import Optional, Dict, Any | |
import io | |
from model_manager import ModelManager | |
from core.optimization import ( | |
FixedResolutionProcessor, | |
GPUOptimizer, | |
AvatarCache, | |
AvatarTokenManager, | |
ColdStartOptimizer, | |
InferenceCache, | |
CachedInference, | |
ParallelProcessor, | |
ParallelInference, | |
OptimizedInferenceWrapper | |
) | |
from cleanup_old_files import initialize_cleanup, get_cleanup_status | |
# サンプルファイルのディレクトリを定義 | |
EXAMPLES_DIR = (Path(__file__).parent / "example").resolve() | |
OUTPUT_DIR = (Path(__file__).parent / "output").resolve() | |
# 出力ディレクトリの作成 | |
OUTPUT_DIR.mkdir(exist_ok=True) | |
# ファイルクリーンアップの初期化(24時間後に自動削除) | |
initialize_cleanup(OUTPUT_DIR, max_age_hours=24) | |
# 初期化フラグ | |
print("=== Phase 3 最適化版 - 初期化開始 ===") | |
# 1. 解像度最適化の初期化 | |
resolution_optimizer = FixedResolutionProcessor() | |
FIXED_RESOLUTION = resolution_optimizer.get_max_dim() # 320 | |
print(f"✅ 解像度固定: {FIXED_RESOLUTION}×{FIXED_RESOLUTION}") | |
# 2. GPU最適化の初期化 | |
gpu_optimizer = GPUOptimizer() | |
print(gpu_optimizer.get_optimization_summary()) | |
# 3. Cold Start最適化の初期化 | |
cold_start_optimizer = ColdStartOptimizer() | |
# 4. アバターキャッシュの初期化 | |
avatar_cache = AvatarCache(cache_dir="/tmp/avatar_cache", ttl_days=14) | |
token_manager = AvatarTokenManager(avatar_cache) | |
print(f"✅ アバターキャッシュ初期化: {avatar_cache.get_cache_info()}") | |
# 5. 推論キャッシュの初期化 | |
inference_cache = InferenceCache( | |
cache_dir="/tmp/inference_cache", | |
memory_cache_size=50, | |
file_cache_size_gb=5.0, | |
ttl_hours=24 | |
) | |
cached_inference = CachedInference(inference_cache) | |
print(f"✅ 推論キャッシュ初期化: {inference_cache.get_cache_stats()}") | |
# 6. 並列処理の初期化(SDK初期化後に移動) | |
# モデルの初期化(最適化版) | |
USE_PYTORCH = True | |
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH) | |
# Cold start最適化: 永続ストレージのセットアップ | |
if not cold_start_optimizer.setup_persistent_model_cache("./checkpoints"): | |
print("⚠️ 永続ストレージのセットアップに失敗") | |
if not model_manager.setup_models(): | |
raise RuntimeError("モデルのセットアップに失敗しました。") | |
# SDKの初期化 | |
if USE_PYTORCH: | |
data_root = "./checkpoints/ditto_pytorch" | |
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" | |
else: | |
data_root = "./checkpoints/ditto_trt_Ampere_Plus" | |
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" | |
# SDK初期化 | |
SDK = None | |
try: | |
from stream_pipeline_offline import StreamSDK | |
from inference import run, seed_everything | |
# SDKを最適化設定で初期化 | |
SDK = StreamSDK(cfg_pkl, data_root) | |
print("✅ SDK初期化成功(最適化版)") | |
# GPU最適化を適用(torch.nn.Moduleの場合のみ) | |
if hasattr(SDK, 'decode_f3d') and hasattr(SDK.decode_f3d, 'decoder'): | |
try: | |
import torch.nn as nn | |
if isinstance(SDK.decode_f3d.decoder, nn.Module): | |
SDK.decode_f3d.decoder = gpu_optimizer.optimize_model(SDK.decode_f3d.decoder) | |
print("✅ デコーダーモデルに最適化を適用") | |
else: | |
print("ℹ️ デコーダーはnn.Moduleではないため、最適化をスキップ") | |
except Exception as e: | |
print(f"⚠️ GPU最適化の適用をスキップ: {e}") | |
except Exception as e: | |
print(f"❌ SDK初期化エラー: {e}") | |
import traceback | |
traceback.print_exc() | |
raise | |
# 並列処理の初期化(SDK初期化成功後) | |
parallel_processor = ParallelProcessor(num_threads=4, num_processes=2) | |
parallel_inference = ParallelInference(SDK, parallel_processor) | |
optimized_wrapper = OptimizedInferenceWrapper( | |
SDK, | |
use_parallel=True, | |
use_cache=True, | |
use_gpu_opt=True | |
) | |
print(f"✅ 並列処理初期化: {parallel_inference.get_performance_stats()}") | |
def prepare_avatar(image_file) -> Dict[str, Any]: | |
""" | |
画像を事前処理してアバタートークンを生成 | |
Args: | |
image_file: アップロードされた画像ファイル | |
Returns: | |
アバタートークン情報 | |
""" | |
if image_file is None: | |
return {"error": "画像ファイルをアップロードしてください。"} | |
try: | |
# 画像データを読み込む | |
with open(image_file, 'rb') as f: | |
image_data = f.read() | |
# 外観エンコーダーで埋め込みを生成 | |
def encode_appearance(img_data): | |
# ここでは簡略化のため、SDKの外観抽出を使用 | |
# 実際の実装では appearance_extractor を直接呼び出す | |
import numpy as np | |
from PIL import Image | |
# 画像を読み込んで処理 | |
img = Image.open(io.BytesIO(img_data)) | |
img = img.convert('RGB') | |
img = img.resize((FIXED_RESOLUTION, FIXED_RESOLUTION)) | |
# 仮の埋め込みベクトル(実際はモデルで生成) | |
# TODO: 実際の appearance_extractor を使用 | |
embedding = np.random.randn(512).astype(np.float32) | |
return embedding | |
# トークンを生成 | |
result = token_manager.prepare_avatar( | |
image_data, | |
encode_appearance | |
) | |
return { | |
"status": "✅ アバター準備完了", | |
"avatar_token": result['avatar_token'], | |
"expires": result['expires'], | |
"cached": "キャッシュ済み" if result['cached'] else "新規生成" | |
} | |
except Exception as e: | |
import traceback | |
return { | |
"error": f"❌ エラー: {str(e)}\n{traceback.format_exc()}" | |
} | |
def process_talking_head_optimized( | |
audio_file, | |
source_image, | |
avatar_token: Optional[str] = None, | |
use_resolution_optimization: bool = True, | |
use_inference_cache: bool = True, | |
use_parallel_processing: bool = True | |
): | |
""" | |
最適化されたTalking Head生成処理(キャッシュ対応) | |
Args: | |
audio_file: 音声ファイル | |
source_image: ソース画像(avatar_tokenがない場合に使用) | |
avatar_token: 事前生成されたアバタートークン | |
use_resolution_optimization: 解像度最適化を使用するか | |
use_inference_cache: 推論キャッシュを使用するか | |
""" | |
if audio_file is None: | |
return None, "音声ファイルをアップロードしてください。" | |
if avatar_token is None and source_image is None: | |
return None, "ソース画像またはアバタートークンが必要です。" | |
try: | |
start_time = time.time() | |
# 出力ファイルの作成(出力ディレクトリ内) | |
import uuid | |
output_filename = f"{uuid.uuid4()}.mp4" | |
output_path = str(OUTPUT_DIR / output_filename) | |
# アバタートークンから埋め込みを取得 | |
if avatar_token: | |
embedding = avatar_cache.load_embedding(avatar_token) | |
if embedding is None: | |
return None, "❌ 無効または期限切れのアバタートークンです。" | |
print(f"✅ キャッシュから埋め込みを取得: {avatar_token[:8]}...") | |
# 解像度最適化設定を適用 | |
if use_resolution_optimization: | |
setup_kwargs = { | |
"max_size": FIXED_RESOLUTION, # 320固定 | |
"sampling_timesteps": resolution_optimizer.get_diffusion_steps() # 25 | |
} | |
print(f"✅ 解像度最適化適用: {FIXED_RESOLUTION}×{FIXED_RESOLUTION}, ステップ数: {setup_kwargs['sampling_timesteps']}") | |
else: | |
setup_kwargs = {} | |
# 処理方法の選択 | |
if use_parallel_processing and source_image: | |
# 並列処理を使用 | |
print("🔄 並列処理モードで実行...") | |
if use_inference_cache: | |
# キャッシュ + 並列処理 | |
def inference_func(audio_path, image_path, out_path, **kwargs): | |
# 並列処理ラッパーを使用 | |
optimized_wrapper.process( | |
audio_path, image_path, out_path, | |
seed=1024, | |
more_kwargs={"setup_kwargs": kwargs.get('setup_kwargs', {})} | |
) | |
# キャッシュシステムを通じて処理 | |
result_path, cache_hit, process_time = cached_inference.process_with_cache( | |
inference_func, | |
audio_file, | |
source_image, | |
output_path, | |
resolution=f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}" if use_resolution_optimization else "default", | |
steps=setup_kwargs.get('sampling_timesteps', 50), | |
setup_kwargs=setup_kwargs | |
) | |
cache_status = "キャッシュヒット(並列)" if cache_hit else "新規生成(並列)" | |
else: | |
# 並列処理のみ | |
_, process_time, stats = optimized_wrapper.process( | |
audio_file, source_image, output_path, | |
seed=1024, | |
more_kwargs={"setup_kwargs": setup_kwargs} | |
) | |
cache_hit = False | |
cache_status = "並列処理(キャッシュ未使用)" | |
elif use_inference_cache and source_image: | |
# キャッシュのみ(並列処理なし) | |
def inference_func(audio_path, image_path, out_path, **kwargs): | |
seed_everything(1024) | |
run(SDK, audio_path, image_path, out_path, | |
more_kwargs={"setup_kwargs": kwargs.get('setup_kwargs', {})}) | |
# キャッシュシステムを通じて処理 | |
result_path, cache_hit, process_time = cached_inference.process_with_cache( | |
inference_func, | |
audio_file, | |
source_image, | |
output_path, | |
resolution=f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}" if use_resolution_optimization else "default", | |
steps=setup_kwargs.get('sampling_timesteps', 50), | |
setup_kwargs=setup_kwargs | |
) | |
cache_status = "キャッシュヒット" if cache_hit else "新規生成" | |
else: | |
# 通常処理(並列処理もキャッシュもなし) | |
print(f"処理開始: audio={audio_file}, image={source_image}, token={avatar_token is not None}") | |
seed_everything(1024) | |
run(SDK, audio_file, source_image, output_path, more_kwargs={"setup_kwargs": setup_kwargs}) | |
process_time = time.time() - start_time | |
cache_hit = False | |
cache_status = "通常処理" | |
# 結果の確認 | |
if os.path.exists(output_path) and os.path.getsize(output_path) > 0: | |
# パフォーマンス統計 | |
perf_info = f""" | |
✅ 処理完了! | |
処理時間: {process_time:.2f}秒 | |
解像度: {FIXED_RESOLUTION}×{FIXED_RESOLUTION} | |
最適化設定: | |
- 解像度最適化: {'有効' if use_resolution_optimization else '無効'} | |
- 並列処理: {'有効' if use_parallel_processing else '無効'} | |
- アバターキャッシュ: {'使用' if avatar_token else '未使用'} | |
- 推論キャッシュ: {cache_status} | |
キャッシュ統計: {inference_cache.get_cache_stats()['memory_cache_entries']}件(メモリ), {inference_cache.get_cache_stats()['file_cache_entries']}件(ファイル) | |
""" | |
return output_path, perf_info | |
else: | |
return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。" | |
except Exception as e: | |
import traceback | |
error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}" | |
print(error_msg) | |
return None, error_msg | |
# Gradio UI(最適化版) | |
with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo: | |
gr.Markdown(""" | |
# DittoTalkingHead - Phase 3 高速化実装 | |
**🚀 最適化機能:** | |
- 📐 解像度320×320固定による高速化 | |
- 🎯 画像事前アップロード&キャッシュ機能 | |
- ⚡ GPU最適化(Mixed Precision, torch.compile) | |
- 💾 Cold Start最適化 | |
- 🔄 推論キャッシュ(同じ入力で即座に結果を返す) | |
- 🚀 並列処理(音声・画像の前処理を並列化) | |
## 使い方 | |
### 方法1: 通常の使用 | |
1. 音声ファイル(WAV)と画像をアップロード | |
2. 「生成」ボタンをクリック | |
### 方法2: 高速化(推奨) | |
1. 「アバター準備」タブで画像を事前アップロード | |
2. 生成されたトークンをコピー | |
3. 「動画生成」タブで音声とトークンを使用 | |
""") | |
with gr.Tabs(): | |
# タブ1: 通常の動画生成 | |
with gr.TabItem("🎬 動画生成"): | |
with gr.Row(): | |
with gr.Column(): | |
audio_input = gr.Audio( | |
label="音声ファイル (WAV)", | |
type="filepath" | |
) | |
with gr.Row(): | |
image_input = gr.Image( | |
label="ソース画像(オプション)", | |
type="filepath" | |
) | |
token_input = gr.Textbox( | |
label="アバタートークン(オプション)", | |
placeholder="事前準備したトークンを入力", | |
lines=1 | |
) | |
use_optimization = gr.Checkbox( | |
label="解像度最適化を使用(320×320)", | |
value=True | |
) | |
use_cache = gr.Checkbox( | |
label="推論キャッシュを使用(同じ入力で高速化)", | |
value=True | |
) | |
use_parallel = gr.Checkbox( | |
label="並列処理を使用(前処理を高速化)", | |
value=True | |
) | |
generate_btn = gr.Button("🎬 生成", variant="primary") | |
with gr.Column(): | |
video_output = gr.Video( | |
label="生成されたビデオ" | |
) | |
status_output = gr.Textbox( | |
label="ステータス", | |
lines=6 | |
) | |
# タブ2: アバター準備 | |
with gr.TabItem("👤 アバター準備"): | |
gr.Markdown(""" | |
### 画像を事前にアップロードして高速化 | |
画像の埋め込みベクトルを事前計算し、トークンとして保存します。 | |
このトークンを使用することで、動画生成時の処理時間を短縮できます。 | |
""") | |
with gr.Row(): | |
with gr.Column(): | |
avatar_image_input = gr.Image( | |
label="アバター画像", | |
type="filepath" | |
) | |
prepare_btn = gr.Button("📤 アバター準備", variant="primary") | |
with gr.Column(): | |
prepare_output = gr.JSON( | |
label="準備結果" | |
) | |
# タブ3: 最適化情報 | |
with gr.TabItem("📊 最適化情報"): | |
with gr.Row(): | |
refresh_btn = gr.Button("🔄 情報を更新", scale=1) | |
info_display = gr.Markdown(f""" | |
### 現在の最適化設定 | |
{resolution_optimizer.get_optimization_summary()} | |
{gpu_optimizer.get_optimization_summary()} | |
### アバターキャッシュ情報 | |
{avatar_cache.get_cache_info()} | |
### 推論キャッシュ情報 | |
{inference_cache.get_cache_stats()} | |
""") | |
# キャッシュ管理ボタン | |
with gr.Row(): | |
clear_inference_cache_btn = gr.Button("🗑️ 推論キャッシュをクリア", variant="secondary") | |
clear_avatar_cache_btn = gr.Button("🗑️ アバターキャッシュをクリア", variant="secondary") | |
cleanup_status_btn = gr.Button("📊 クリーンアップ状態", variant="secondary") | |
cache_status = gr.Textbox(label="キャッシュ操作ステータス", lines=2) | |
# サンプル | |
example_audio = EXAMPLES_DIR / "audio.wav" | |
example_image = EXAMPLES_DIR / "image.png" | |
if example_audio.exists() and example_image.exists(): | |
gr.Examples( | |
examples=[ | |
[str(example_audio), str(example_image), None, True, True, True] | |
], | |
inputs=[audio_input, image_input, token_input, use_optimization, use_cache, use_parallel], | |
outputs=[video_output, status_output], | |
fn=process_talking_head_optimized | |
) | |
# イベントハンドラ | |
generate_btn.click( | |
fn=process_talking_head_optimized, | |
inputs=[audio_input, image_input, token_input, use_optimization, use_cache, use_parallel], | |
outputs=[video_output, status_output] | |
) | |
prepare_btn.click( | |
fn=prepare_avatar, | |
inputs=[avatar_image_input], | |
outputs=[prepare_output] | |
) | |
# キャッシュ管理関数 | |
def refresh_info(): | |
return f""" | |
### 現在の最適化設定 | |
{resolution_optimizer.get_optimization_summary()} | |
{gpu_optimizer.get_optimization_summary()} | |
### アバターキャッシュ情報 | |
{avatar_cache.get_cache_info()} | |
### 推論キャッシュ情報 | |
{inference_cache.get_cache_stats()} | |
### 並列処理情報 | |
{parallel_inference.get_performance_stats()} | |
""" | |
def clear_inference_cache(): | |
inference_cache.clear_cache() | |
return "✅ 推論キャッシュをクリアしました" | |
def clear_avatar_cache(): | |
avatar_cache.clear_cache() | |
return "✅ アバターキャッシュをクリアしました" | |
# キャッシュ管理イベント | |
refresh_btn.click( | |
fn=refresh_info, | |
outputs=[info_display] | |
) | |
clear_inference_cache_btn.click( | |
fn=clear_inference_cache, | |
outputs=[cache_status] | |
) | |
clear_avatar_cache_btn.click( | |
fn=clear_avatar_cache, | |
outputs=[cache_status] | |
) | |
cleanup_status_btn.click( | |
fn=lambda: get_cleanup_status(), | |
outputs=[cache_status] | |
) | |
if __name__ == "__main__": | |
# Cold Start最適化設定でGradioを起動 | |
launch_settings = cold_start_optimizer.optimize_gradio_settings() | |
# allowed_pathsを追加 | |
launch_settings['allowed_paths'] = [str(EXAMPLES_DIR), str(OUTPUT_DIR)] | |
demo.launch(**launch_settings) |