File size: 7,395 Bytes
d9a2a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
55535c7
 
 
 
d9a2a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55535c7
 
 
d9a2a3d
2089ecf
d9a2a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55535c7
 
d9a2a3d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import os, tempfile, queue, threading, time, numpy as np, soundfile as sf
import gradio as gr
from stream_pipeline_offline import StreamSDK
import torch
from PIL import Image
from pathlib import Path
import cv2

# モデル設定
CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
DATA_ROOT = "checkpoints/ditto_pytorch"

# サンプルファイルのディレクトリ
EXAMPLES_DIR = (Path(__file__).parent / "example").resolve()
OUTPUT_DIR = (Path(__file__).parent / "output").resolve()

# 出力ディレクトリの作成
OUTPUT_DIR.mkdir(exist_ok=True)

# グローバルで一度だけロード(concurrency_count=1 前提)
sdk: StreamSDK | None = None
def init_sdk():
    global sdk
    if sdk is None:
        sdk = StreamSDK(CFG_PKL, DATA_ROOT)
    return sdk

# 音声チャンクサイズ(秒)
CHUNK_SEC = 0.20  # 16000*0.20 = 3200 sample ≒ 5 フレーム

def generator(mic, src_img):
    """
    Gradio 生成関数
    mic     : (sr, np.ndarray) 形式 (Gradio Audio streaming=True)
    src_img : 画像ファイルパス
    Yields  : PIL.Image (現在フレーム) または (最後に mp4)
    """
    if mic is None:
        yield None, None, "マイク入力を開始してください"
        return
    
    if src_img is None:
        yield None, None, "ソース画像をアップロードしてください"
        return
    
    try:
        sr, wav_full = mic
        sdk = init_sdk()
        
        # setup: online_mode=True でストリーミング
        import uuid
        output_filename = f"{uuid.uuid4()}.mp4"
        tmp_out = str(OUTPUT_DIR / output_filename)
        sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
        N_total = int(np.ceil(len(wav_full) / sr * 20))  # 概算フレーム数
        sdk.setup_Nd(N_total)
        
        # 処理開始時刻
        start_time = time.time()
        frame_count = 0
        
        # 音声を CHUNK_SEC ごとに送り込む
        hop = int(sr * CHUNK_SEC)
        for start_idx in range(0, len(wav_full), hop):
            chunk = wav_full[start_idx : start_idx + hop]
            if len(chunk) < hop:
                chunk = np.pad(chunk, (0, hop - len(chunk)))
            sdk.run_chunk(chunk)
            
            # 直近で書き込まれたフレームをキューから取得
            frames_processed = 0
            while sdk.writer_queue.qsize() > 0 and frames_processed < 5:
                try:
                    frame = sdk.writer_queue.get_nowait()
                    if frame is not None:
                        # numpy array (H, W, 3) を PIL Image に変換
                        pil_frame = Image.fromarray(frame)
                        frame_count += 1
                        elapsed = time.time() - start_time
                        fps = frame_count / elapsed if elapsed > 0 else 0
                        yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}"
                    frames_processed += 1
                except queue.Empty:
                    break
            
            # 少し待機(CPU負荷調整)
            time.sleep(0.01)
        
        # 残りのフレームを処理
        print("音声チャンクの送信完了、残りフレームを処理中...")
        timeout_count = 0
        while timeout_count < 50:  # 最大5秒待機
            if sdk.writer_queue.qsize() > 0:
                try:
                    frame = sdk.writer_queue.get_nowait()
                    if frame is not None:
                        pil_frame = Image.fromarray(frame)
                        frame_count += 1
                        elapsed = time.time() - start_time
                        fps = frame_count / elapsed if elapsed > 0 else 0
                        yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}"
                    timeout_count = 0
                except queue.Empty:
                    time.sleep(0.1)
                    timeout_count += 1
            else:
                time.sleep(0.1)
                timeout_count += 1
        
        # SDKを閉じて最終的なMP4を生成
        print("SDKを閉じて最終的なMP4を生成中...")
        sdk.close()  # ワーカー join & mp4 結合
        
        # 処理完了
        elapsed_total = time.time() - start_time
        yield None, gr.Video(tmp_out), f"✅ 完了! 総フレーム数: {frame_count}, 処理時間: {elapsed_total:.1f}秒"
        
    except Exception as e:
        import traceback
        error_msg = f"❌ エラー: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        yield None, None, error_msg

# Gradio UI
with gr.Blocks(title="DittoTalkingHead Streaming") as demo:
    gr.Markdown("""
    # DittoTalkingHead - ストリーミング版
    
    音声をリアルタイムで処理し、生成されたフレームを逐次表示します。
    
    ## 使い方
    1. **ソース画像**(PNG/JPG形式)をアップロード
    2. **Start**ボタンをクリックしてマイク録音開始
    3. 録音中、ライブフレームが更新されます
    4. 録音停止後、最終的なMP4が生成されます
    """)
    
    with gr.Row():
        with gr.Column():
            img_in = gr.Image(
                type="filepath", 
                label="ソース画像 / Source Image",
                value=str(EXAMPLES_DIR / "reference.png") if (EXAMPLES_DIR / "reference.png").exists() else None
            )
            mic_in = gr.Audio(
                sources=["microphone"],
                streaming=True,
                label="マイク入力 (16 kHz)",
                format="wav"
            )
        
        with gr.Column():
            live_img = gr.Image(label="ライブフレーム", type="pil")
            final_mp4 = gr.Video(label="最終結果 (MP4)")
            status_text = gr.Textbox(label="ステータス", value="待機中...")
    
    btn = gr.Button("Start Streaming", variant="primary")
    
    # ストリーミング処理を開始
    btn.click(
        fn=generator,
        inputs=[mic_in, img_in],
        outputs=[live_img, final_mp4, status_text],
        stream_every=0.1  # 100msごとに更新
    )
    
    # サンプル
    if EXAMPLES_DIR.exists():
        gr.Examples(
            examples=[
                [str(EXAMPLES_DIR / "reference.png")]
            ],
            inputs=[img_in],
            label="サンプル画像"
        )

# 起動設定
if __name__ == "__main__":
    # GPU最適化設定
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
    
    # 環境変数設定
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
    
    print("=== DittoTalkingHead ストリーミング版 起動 ===")
    print(f"- チャンクサイズ: {CHUNK_SEC}秒")
    print(f"- 最大解像度: 1024px")
    print(f"- GPU: {'利用可能' if torch.cuda.is_available() else '利用不可'}")
    
    # モデルの事前ロード
    print("モデルを事前ロード中...")
    init_sdk()
    print("✅ モデルロード完了")
    
    demo.queue(concurrency_count=1, max_size=8).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)]
    )