import gradio as gr import os import tempfile import shutil from pathlib import Path from model_manager import ModelManager # サンプルファイルのディレクトリを定義(絶対パスに解決) EXAMPLES_DIR = (Path(__file__).parent / "example").resolve() OUTPUT_DIR = (Path(__file__).parent / "output").resolve() # 出力ディレクトリの作成 OUTPUT_DIR.mkdir(exist_ok=True) # インポートエラーのデバッグ情報を表示 try: import filetype print("✅ filetype module imported successfully") except ImportError as e: print(f"⚠️ filetype import failed: {e}") print("Using fallback file type detection") # モデルの初期化 print("=== モデルの初期化開始 ===") # PyTorchモデルを使用(TensorRTモデルは非常に大きいため) USE_PYTORCH = True model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH) if not model_manager.setup_models(): raise RuntimeError("モデルのセットアップに失敗しました。") # SDKの初期化 if USE_PYTORCH: data_root = "./checkpoints/ditto_pytorch" cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" else: data_root = "./checkpoints/ditto_trt_Ampere_Plus" cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" # SDK初期化のためのグローバル変数 SDK = None try: # モジュールをインポート from stream_pipeline_offline import StreamSDK from inference import run, seed_everything SDK = StreamSDK(cfg_pkl, data_root) print("✅ SDK初期化成功") except Exception as e: print(f"❌ SDK初期化エラー: {e}") import traceback traceback.print_exc() raise def process_talking_head(audio_file, source_image): """音声とソース画像からTalking Headビデオを生成""" if audio_file is None: return None, "音声ファイルをアップロードしてください。" if source_image is None: return None, "ソース画像をアップロードしてください。" try: # 出力ファイルの作成(出力ディレクトリ内) import uuid output_filename = f"{uuid.uuid4()}.mp4" output_path = str(OUTPUT_DIR / output_filename) # 処理実行 print(f"処理開始: audio={audio_file}, image={source_image}") seed_everything(1024) run(SDK, audio_file, source_image, output_path) # 結果の確認 if os.path.exists(output_path) and os.path.getsize(output_path) > 0: return output_path, "✅ 処理が完了しました!" else: return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。" except Exception as e: import traceback error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}" print(error_msg) return None, error_msg # Gradio UI with gr.Blocks(title="DittoTalkingHead") as demo: gr.Markdown(""" # DittoTalkingHead - Talking Head Generation 音声とソース画像から、リアルなTalking Headビデオを生成します。 ## 使い方 1. **音声ファイル**(WAV形式)をアップロード 2. **ソース画像**(PNG/JPG形式)をアップロード 3. **生成**ボタンをクリック ⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。 ### 技術仕様 - **モデル**: DittoTalkingHead (PyTorch版) - **GPU**: NVIDIA A100推奨 - **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead) """) with gr.Row(): with gr.Column(): audio_input = gr.Audio( label="音声ファイル (WAV)", type="filepath" ) image_input = gr.Image( label="ソース画像", type="filepath" ) generate_btn = gr.Button("生成", variant="primary") with gr.Column(): video_output = gr.Video( label="生成されたビデオ" ) status_output = gr.Textbox( label="ステータス", lines=3 ) # サンプル example_audio = EXAMPLES_DIR / "audio.wav" example_image = EXAMPLES_DIR / "image.png" if example_audio.exists() and example_image.exists(): gr.Examples( examples=[ [str(example_audio), str(example_image)] ], inputs=[audio_input, image_input], outputs=[video_output, status_output], fn=process_talking_head, cache_examples=True ) # イベントハンドラ generate_btn.click( fn=process_talking_head, inputs=[audio_input, image_input], outputs=[video_output, status_output] ) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, share=False, allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)] )