oKen38461's picture
出力ディレクトリの作成とUUIDを用いたファイル名生成を追加しました。また、クリーンアップ機能の初期化と状態確認ボタンをGradioインターフェースに追加しました。これにより、ファイル管理が改善され、ユーザーがクリーンアップ状況を確認できるようになりました。
55535c7
import gradio as gr
import os
import tempfile
import shutil
from pathlib import Path
from model_manager import ModelManager
# サンプルファイルのディレクトリを定義(絶対パスに解決)
EXAMPLES_DIR = (Path(__file__).parent / "example").resolve()
OUTPUT_DIR = (Path(__file__).parent / "output").resolve()
# 出力ディレクトリの作成
OUTPUT_DIR.mkdir(exist_ok=True)
# インポートエラーのデバッグ情報を表示
try:
import filetype
print("✅ filetype module imported successfully")
except ImportError as e:
print(f"⚠️ filetype import failed: {e}")
print("Using fallback file type detection")
# モデルの初期化
print("=== モデルの初期化開始 ===")
# PyTorchモデルを使用(TensorRTモデルは非常に大きいため)
USE_PYTORCH = True
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
if not model_manager.setup_models():
raise RuntimeError("モデルのセットアップに失敗しました。")
# SDKの初期化
if USE_PYTORCH:
data_root = "./checkpoints/ditto_pytorch"
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
else:
data_root = "./checkpoints/ditto_trt_Ampere_Plus"
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"
# SDK初期化のためのグローバル変数
SDK = None
try:
# モジュールをインポート
from stream_pipeline_offline import StreamSDK
from inference import run, seed_everything
SDK = StreamSDK(cfg_pkl, data_root)
print("✅ SDK初期化成功")
except Exception as e:
print(f"❌ SDK初期化エラー: {e}")
import traceback
traceback.print_exc()
raise
def process_talking_head(audio_file, source_image):
"""音声とソース画像からTalking Headビデオを生成"""
if audio_file is None:
return None, "音声ファイルをアップロードしてください。"
if source_image is None:
return None, "ソース画像をアップロードしてください。"
try:
# 出力ファイルの作成(出力ディレクトリ内)
import uuid
output_filename = f"{uuid.uuid4()}.mp4"
output_path = str(OUTPUT_DIR / output_filename)
# 処理実行
print(f"処理開始: audio={audio_file}, image={source_image}")
seed_everything(1024)
run(SDK, audio_file, source_image, output_path)
# 結果の確認
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
return output_path, "✅ 処理が完了しました!"
else:
return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。"
except Exception as e:
import traceback
error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
return None, error_msg
# Gradio UI
with gr.Blocks(title="DittoTalkingHead") as demo:
gr.Markdown("""
# DittoTalkingHead - Talking Head Generation
音声とソース画像から、リアルなTalking Headビデオを生成します。
## 使い方
1. **音声ファイル**(WAV形式)をアップロード
2. **ソース画像**(PNG/JPG形式)をアップロード
3. **生成**ボタンをクリック
⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。
### 技術仕様
- **モデル**: DittoTalkingHead (PyTorch版)
- **GPU**: NVIDIA A100推奨
- **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead)
""")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
label="音声ファイル (WAV)",
type="filepath"
)
image_input = gr.Image(
label="ソース画像",
type="filepath"
)
generate_btn = gr.Button("生成", variant="primary")
with gr.Column():
video_output = gr.Video(
label="生成されたビデオ"
)
status_output = gr.Textbox(
label="ステータス",
lines=3
)
# サンプル
example_audio = EXAMPLES_DIR / "audio.wav"
example_image = EXAMPLES_DIR / "image.png"
if example_audio.exists() and example_image.exists():
gr.Examples(
examples=[
[str(example_audio), str(example_image)]
],
inputs=[audio_input, image_input],
outputs=[video_output, status_output],
fn=process_talking_head,
cache_examples=True
)
# イベントハンドラ
generate_btn.click(
fn=process_talking_head,
inputs=[audio_input, image_input],
outputs=[video_output, status_output]
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)]
)