Spaces:
Runtime error
Runtime error
import os, tempfile, queue, threading, time, numpy as np, soundfile as sf | |
import gradio as gr | |
from stream_pipeline_offline import StreamSDK | |
import torch | |
from PIL import Image | |
from pathlib import Path | |
import cv2 | |
# モデル設定 | |
CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" | |
DATA_ROOT = "checkpoints/ditto_pytorch" | |
# サンプルファイルのディレクトリ | |
EXAMPLES_DIR = (Path(__file__).parent / "example").resolve() | |
OUTPUT_DIR = (Path(__file__).parent / "output").resolve() | |
# 出力ディレクトリの作成 | |
OUTPUT_DIR.mkdir(exist_ok=True) | |
# グローバルで一度だけロード(concurrency_count=1 前提) | |
sdk: StreamSDK | None = None | |
def init_sdk(): | |
global sdk | |
if sdk is None: | |
sdk = StreamSDK(CFG_PKL, DATA_ROOT) | |
return sdk | |
# 音声チャンクサイズ(秒) | |
CHUNK_SEC = 0.20 # 16000*0.20 = 3200 sample ≒ 5 フレーム | |
def generator(mic, src_img): | |
""" | |
Gradio 生成関数 | |
mic : (sr, np.ndarray) 形式 (Gradio Audio streaming=True) | |
src_img : 画像ファイルパス | |
Yields : PIL.Image (現在フレーム) または (最後に mp4) | |
""" | |
if mic is None: | |
yield None, None, "マイク入力を開始してください" | |
return | |
if src_img is None: | |
yield None, None, "ソース画像をアップロードしてください" | |
return | |
try: | |
sr, wav_full = mic | |
sdk = init_sdk() | |
# setup: online_mode=True でストリーミング | |
import uuid | |
output_filename = f"{uuid.uuid4()}.mp4" | |
tmp_out = str(OUTPUT_DIR / output_filename) | |
sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024) | |
N_total = int(np.ceil(len(wav_full) / sr * 20)) # 概算フレーム数 | |
sdk.setup_Nd(N_total) | |
# 処理開始時刻 | |
start_time = time.time() | |
frame_count = 0 | |
# 音声を CHUNK_SEC ごとに送り込む | |
hop = int(sr * CHUNK_SEC) | |
for start_idx in range(0, len(wav_full), hop): | |
chunk = wav_full[start_idx : start_idx + hop] | |
if len(chunk) < hop: | |
chunk = np.pad(chunk, (0, hop - len(chunk))) | |
sdk.run_chunk(chunk) | |
# 直近で書き込まれたフレームをキューから取得 | |
frames_processed = 0 | |
while sdk.writer_queue.qsize() > 0 and frames_processed < 5: | |
try: | |
frame = sdk.writer_queue.get_nowait() | |
if frame is not None: | |
# numpy array (H, W, 3) を PIL Image に変換 | |
pil_frame = Image.fromarray(frame) | |
frame_count += 1 | |
elapsed = time.time() - start_time | |
fps = frame_count / elapsed if elapsed > 0 else 0 | |
yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}" | |
frames_processed += 1 | |
except queue.Empty: | |
break | |
# 少し待機(CPU負荷調整) | |
time.sleep(0.01) | |
# 残りのフレームを処理 | |
print("音声チャンクの送信完了、残りフレームを処理中...") | |
timeout_count = 0 | |
while timeout_count < 50: # 最大5秒待機 | |
if sdk.writer_queue.qsize() > 0: | |
try: | |
frame = sdk.writer_queue.get_nowait() | |
if frame is not None: | |
pil_frame = Image.fromarray(frame) | |
frame_count += 1 | |
elapsed = time.time() - start_time | |
fps = frame_count / elapsed if elapsed > 0 else 0 | |
yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}" | |
timeout_count = 0 | |
except queue.Empty: | |
time.sleep(0.1) | |
timeout_count += 1 | |
else: | |
time.sleep(0.1) | |
timeout_count += 1 | |
# SDKを閉じて最終的なMP4を生成 | |
print("SDKを閉じて最終的なMP4を生成中...") | |
sdk.close() # ワーカー join & mp4 結合 | |
# 処理完了 | |
elapsed_total = time.time() - start_time | |
yield None, gr.Video(tmp_out), f"✅ 完了! 総フレーム数: {frame_count}, 処理時間: {elapsed_total:.1f}秒" | |
except Exception as e: | |
import traceback | |
error_msg = f"❌ エラー: {str(e)}\n{traceback.format_exc()}" | |
print(error_msg) | |
yield None, None, error_msg | |
# Gradio UI | |
with gr.Blocks(title="DittoTalkingHead Streaming") as demo: | |
gr.Markdown(""" | |
# DittoTalkingHead - ストリーミング版 | |
音声をリアルタイムで処理し、生成されたフレームを逐次表示します。 | |
## 使い方 | |
1. **ソース画像**(PNG/JPG形式)をアップロード | |
2. **Start**ボタンをクリックしてマイク録音開始 | |
3. 録音中、ライブフレームが更新されます | |
4. 録音停止後、最終的なMP4が生成されます | |
""") | |
with gr.Row(): | |
with gr.Column(): | |
img_in = gr.Image( | |
type="filepath", | |
label="ソース画像 / Source Image", | |
value=str(EXAMPLES_DIR / "reference.png") if (EXAMPLES_DIR / "reference.png").exists() else None | |
) | |
mic_in = gr.Audio( | |
sources=["microphone"], | |
streaming=True, | |
label="マイク入力 (16 kHz)", | |
format="wav" | |
) | |
with gr.Column(): | |
live_img = gr.Image(label="ライブフレーム", type="pil") | |
final_mp4 = gr.Video(label="最終結果 (MP4)") | |
status_text = gr.Textbox(label="ステータス", value="待機中...") | |
btn = gr.Button("Start Streaming", variant="primary") | |
# ストリーミング処理を開始 | |
btn.click( | |
fn=generator, | |
inputs=[mic_in, img_in], | |
outputs=[live_img, final_mp4, status_text], | |
stream_every=0.1 # 100msごとに更新 | |
) | |
# サンプル | |
if EXAMPLES_DIR.exists(): | |
gr.Examples( | |
examples=[ | |
[str(EXAMPLES_DIR / "reference.png")] | |
], | |
inputs=[img_in], | |
label="サンプル画像" | |
) | |
# 起動設定 | |
if __name__ == "__main__": | |
# GPU最適化設定 | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
torch.backends.cudnn.benchmark = True | |
# 環境変数設定 | |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" | |
print("=== DittoTalkingHead ストリーミング版 起動 ===") | |
print(f"- チャンクサイズ: {CHUNK_SEC}秒") | |
print(f"- 最大解像度: 1024px") | |
print(f"- GPU: {'利用可能' if torch.cuda.is_available() else '利用不可'}") | |
# モデルの事前ロード | |
print("モデルを事前ロード中...") | |
init_sdk() | |
print("✅ モデルロード完了") | |
demo.queue(concurrency_count=1, max_size=8).launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=False, | |
allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)] | |
) |