import os, tempfile, queue, threading, time, numpy as np, soundfile as sf import gradio as gr from stream_pipeline_offline import StreamSDK import torch from PIL import Image from pathlib import Path import cv2 # モデル設定 CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" DATA_ROOT = "checkpoints/ditto_pytorch" # サンプルファイルのディレクトリ EXAMPLES_DIR = (Path(__file__).parent / "example").resolve() OUTPUT_DIR = (Path(__file__).parent / "output").resolve() # 出力ディレクトリの作成 OUTPUT_DIR.mkdir(exist_ok=True) # グローバルで一度だけロード(concurrency_count=1 前提) sdk: StreamSDK | None = None def init_sdk(): global sdk if sdk is None: sdk = StreamSDK(CFG_PKL, DATA_ROOT) return sdk # 音声チャンクサイズ(秒) CHUNK_SEC = 0.20 # 16000*0.20 = 3200 sample ≒ 5 フレーム def generator(mic, src_img): """ Gradio 生成関数 mic : (sr, np.ndarray) 形式 (Gradio Audio streaming=True) src_img : 画像ファイルパス Yields : PIL.Image (現在フレーム) または (最後に mp4) """ if mic is None: yield None, None, "マイク入力を開始してください" return if src_img is None: yield None, None, "ソース画像をアップロードしてください" return try: sr, wav_full = mic sdk = init_sdk() # setup: online_mode=True でストリーミング import uuid output_filename = f"{uuid.uuid4()}.mp4" tmp_out = str(OUTPUT_DIR / output_filename) sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024) N_total = int(np.ceil(len(wav_full) / sr * 20)) # 概算フレーム数 sdk.setup_Nd(N_total) # 処理開始時刻 start_time = time.time() frame_count = 0 # 音声を CHUNK_SEC ごとに送り込む hop = int(sr * CHUNK_SEC) for start_idx in range(0, len(wav_full), hop): chunk = wav_full[start_idx : start_idx + hop] if len(chunk) < hop: chunk = np.pad(chunk, (0, hop - len(chunk))) sdk.run_chunk(chunk) # 直近で書き込まれたフレームをキューから取得 frames_processed = 0 while sdk.writer_queue.qsize() > 0 and frames_processed < 5: try: frame = sdk.writer_queue.get_nowait() if frame is not None: # numpy array (H, W, 3) を PIL Image に変換 pil_frame = Image.fromarray(frame) frame_count += 1 elapsed = time.time() - start_time fps = frame_count / elapsed if elapsed > 0 else 0 yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}" frames_processed += 1 except queue.Empty: break # 少し待機(CPU負荷調整) time.sleep(0.01) # 残りのフレームを処理 print("音声チャンクの送信完了、残りフレームを処理中...") timeout_count = 0 while timeout_count < 50: # 最大5秒待機 if sdk.writer_queue.qsize() > 0: try: frame = sdk.writer_queue.get_nowait() if frame is not None: pil_frame = Image.fromarray(frame) frame_count += 1 elapsed = time.time() - start_time fps = frame_count / elapsed if elapsed > 0 else 0 yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}" timeout_count = 0 except queue.Empty: time.sleep(0.1) timeout_count += 1 else: time.sleep(0.1) timeout_count += 1 # SDKを閉じて最終的なMP4を生成 print("SDKを閉じて最終的なMP4を生成中...") sdk.close() # ワーカー join & mp4 結合 # 処理完了 elapsed_total = time.time() - start_time yield None, gr.Video(tmp_out), f"✅ 完了! 総フレーム数: {frame_count}, 処理時間: {elapsed_total:.1f}秒" except Exception as e: import traceback error_msg = f"❌ エラー: {str(e)}\n{traceback.format_exc()}" print(error_msg) yield None, None, error_msg # Gradio UI with gr.Blocks(title="DittoTalkingHead Streaming") as demo: gr.Markdown(""" # DittoTalkingHead - ストリーミング版 音声をリアルタイムで処理し、生成されたフレームを逐次表示します。 ## 使い方 1. **ソース画像**(PNG/JPG形式)をアップロード 2. **Start**ボタンをクリックしてマイク録音開始 3. 録音中、ライブフレームが更新されます 4. 録音停止後、最終的なMP4が生成されます """) with gr.Row(): with gr.Column(): img_in = gr.Image( type="filepath", label="ソース画像 / Source Image", value=str(EXAMPLES_DIR / "reference.png") if (EXAMPLES_DIR / "reference.png").exists() else None ) mic_in = gr.Audio( sources=["microphone"], streaming=True, label="マイク入力 (16 kHz)", format="wav" ) with gr.Column(): live_img = gr.Image(label="ライブフレーム", type="pil") final_mp4 = gr.Video(label="最終結果 (MP4)") status_text = gr.Textbox(label="ステータス", value="待機中...") btn = gr.Button("Start Streaming", variant="primary") # ストリーミング処理を開始 btn.click( fn=generator, inputs=[mic_in, img_in], outputs=[live_img, final_mp4, status_text], stream_every=0.1 # 100msごとに更新 ) # サンプル if EXAMPLES_DIR.exists(): gr.Examples( examples=[ [str(EXAMPLES_DIR / "reference.png")] ], inputs=[img_in], label="サンプル画像" ) # 起動設定 if __name__ == "__main__": # GPU最適化設定 if torch.cuda.is_available(): torch.cuda.empty_cache() torch.backends.cudnn.benchmark = True # 環境変数設定 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" print("=== DittoTalkingHead ストリーミング版 起動 ===") print(f"- チャンクサイズ: {CHUNK_SEC}秒") print(f"- 最大解像度: 1024px") print(f"- GPU: {'利用可能' if torch.cuda.is_available() else '利用不可'}") # モデルの事前ロード print("モデルを事前ロード中...") init_sdk() print("✅ モデルロード完了") demo.queue(concurrency_count=1, max_size=8).launch( server_name="0.0.0.0", server_port=7860, share=False, allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)] )