File size: 5,252 Bytes
43f5a2b
 
 
 
 
 
f4998a2
146df9e
 
55535c7
 
 
 
0a00365
f4998a2
 
 
 
 
 
 
 
43f5a2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b17b0c
 
 
43f5a2b
f4998a2
 
 
 
43f5a2b
 
 
 
f4998a2
 
43f5a2b
 
 
 
 
 
 
 
 
 
 
 
55535c7
 
 
 
43f5a2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146df9e
 
 
 
 
 
 
 
 
 
 
 
 
43f5a2b
 
 
 
 
 
 
 
 
 
 
 
0a00365
55535c7
43f5a2b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import gradio as gr
import os
import tempfile
import shutil
from pathlib import Path
from model_manager import ModelManager

# サンプルファイルのディレクトリを定義(絶対パスに解決)
EXAMPLES_DIR = (Path(__file__).parent / "example").resolve()
OUTPUT_DIR = (Path(__file__).parent / "output").resolve()

# 出力ディレクトリの作成
OUTPUT_DIR.mkdir(exist_ok=True)

# インポートエラーのデバッグ情報を表示
try:
    import filetype
    print("✅ filetype module imported successfully")
except ImportError as e:
    print(f"⚠️ filetype import failed: {e}")
    print("Using fallback file type detection")

# モデルの初期化
print("=== モデルの初期化開始 ===")

# PyTorchモデルを使用(TensorRTモデルは非常に大きいため)
USE_PYTORCH = True

model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
if not model_manager.setup_models():
    raise RuntimeError("モデルのセットアップに失敗しました。")

# SDKの初期化
if USE_PYTORCH:
    data_root = "./checkpoints/ditto_pytorch"
    cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
else:
    data_root = "./checkpoints/ditto_trt_Ampere_Plus"
    cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"

# SDK初期化のためのグローバル変数
SDK = None

try:
    # モジュールをインポート
    from stream_pipeline_offline import StreamSDK
    from inference import run, seed_everything
    
    SDK = StreamSDK(cfg_pkl, data_root)
    print("✅ SDK初期化成功")
except Exception as e:
    print(f"❌ SDK初期化エラー: {e}")
    import traceback
    traceback.print_exc()
    raise

def process_talking_head(audio_file, source_image):
    """音声とソース画像からTalking Headビデオを生成"""
    
    if audio_file is None:
        return None, "音声ファイルをアップロードしてください。"
    
    if source_image is None:
        return None, "ソース画像をアップロードしてください。"
    
    try:
        # 出力ファイルの作成(出力ディレクトリ内)
        import uuid
        output_filename = f"{uuid.uuid4()}.mp4"
        output_path = str(OUTPUT_DIR / output_filename)
        
        # 処理実行
        print(f"処理開始: audio={audio_file}, image={source_image}")
        seed_everything(1024)
        run(SDK, audio_file, source_image, output_path)
        
        # 結果の確認
        if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
            return output_path, "✅ 処理が完了しました!"
        else:
            return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。"
            
    except Exception as e:
        import traceback
        error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, error_msg

# Gradio UI
with gr.Blocks(title="DittoTalkingHead") as demo:
    gr.Markdown("""
    # DittoTalkingHead - Talking Head Generation
    
    音声とソース画像から、リアルなTalking Headビデオを生成します。
    
    ## 使い方
    1. **音声ファイル**(WAV形式)をアップロード
    2. **ソース画像**(PNG/JPG形式)をアップロード
    3. **生成**ボタンをクリック
    
    ⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。
    
    ### 技術仕様
    - **モデル**: DittoTalkingHead (PyTorch版)
    - **GPU**: NVIDIA A100推奨
    - **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead)
    """)
    
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="音声ファイル (WAV)",
                type="filepath"
            )
            image_input = gr.Image(
                label="ソース画像",
                type="filepath"
            )
            generate_btn = gr.Button("生成", variant="primary")
            
        with gr.Column():
            video_output = gr.Video(
                label="生成されたビデオ"
            )
            status_output = gr.Textbox(
                label="ステータス",
                lines=3
            )
    
    # サンプル
    example_audio = EXAMPLES_DIR / "audio.wav"
    example_image = EXAMPLES_DIR / "image.png"
    
    if example_audio.exists() and example_image.exists():
        gr.Examples(
            examples=[
                [str(example_audio), str(example_image)]
            ],
            inputs=[audio_input, image_input],
            outputs=[video_output, status_output],
            fn=process_talking_head,
            cache_examples=True
        )
    
    # イベントハンドラ
    generate_btn.click(
        fn=process_talking_head,
        inputs=[audio_input, image_input],
        outputs=[video_output, status_output]
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)]
    )