Spaces:
Runtime error
Runtime error
File size: 7,395 Bytes
d9a2a3d 55535c7 d9a2a3d 55535c7 d9a2a3d 2089ecf d9a2a3d 55535c7 d9a2a3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import os, tempfile, queue, threading, time, numpy as np, soundfile as sf
import gradio as gr
from stream_pipeline_offline import StreamSDK
import torch
from PIL import Image
from pathlib import Path
import cv2
# モデル設定
CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
DATA_ROOT = "checkpoints/ditto_pytorch"
# サンプルファイルのディレクトリ
EXAMPLES_DIR = (Path(__file__).parent / "example").resolve()
OUTPUT_DIR = (Path(__file__).parent / "output").resolve()
# 出力ディレクトリの作成
OUTPUT_DIR.mkdir(exist_ok=True)
# グローバルで一度だけロード(concurrency_count=1 前提)
sdk: StreamSDK | None = None
def init_sdk():
global sdk
if sdk is None:
sdk = StreamSDK(CFG_PKL, DATA_ROOT)
return sdk
# 音声チャンクサイズ(秒)
CHUNK_SEC = 0.20 # 16000*0.20 = 3200 sample ≒ 5 フレーム
def generator(mic, src_img):
"""
Gradio 生成関数
mic : (sr, np.ndarray) 形式 (Gradio Audio streaming=True)
src_img : 画像ファイルパス
Yields : PIL.Image (現在フレーム) または (最後に mp4)
"""
if mic is None:
yield None, None, "マイク入力を開始してください"
return
if src_img is None:
yield None, None, "ソース画像をアップロードしてください"
return
try:
sr, wav_full = mic
sdk = init_sdk()
# setup: online_mode=True でストリーミング
import uuid
output_filename = f"{uuid.uuid4()}.mp4"
tmp_out = str(OUTPUT_DIR / output_filename)
sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
N_total = int(np.ceil(len(wav_full) / sr * 20)) # 概算フレーム数
sdk.setup_Nd(N_total)
# 処理開始時刻
start_time = time.time()
frame_count = 0
# 音声を CHUNK_SEC ごとに送り込む
hop = int(sr * CHUNK_SEC)
for start_idx in range(0, len(wav_full), hop):
chunk = wav_full[start_idx : start_idx + hop]
if len(chunk) < hop:
chunk = np.pad(chunk, (0, hop - len(chunk)))
sdk.run_chunk(chunk)
# 直近で書き込まれたフレームをキューから取得
frames_processed = 0
while sdk.writer_queue.qsize() > 0 and frames_processed < 5:
try:
frame = sdk.writer_queue.get_nowait()
if frame is not None:
# numpy array (H, W, 3) を PIL Image に変換
pil_frame = Image.fromarray(frame)
frame_count += 1
elapsed = time.time() - start_time
fps = frame_count / elapsed if elapsed > 0 else 0
yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}"
frames_processed += 1
except queue.Empty:
break
# 少し待機(CPU負荷調整)
time.sleep(0.01)
# 残りのフレームを処理
print("音声チャンクの送信完了、残りフレームを処理中...")
timeout_count = 0
while timeout_count < 50: # 最大5秒待機
if sdk.writer_queue.qsize() > 0:
try:
frame = sdk.writer_queue.get_nowait()
if frame is not None:
pil_frame = Image.fromarray(frame)
frame_count += 1
elapsed = time.time() - start_time
fps = frame_count / elapsed if elapsed > 0 else 0
yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}"
timeout_count = 0
except queue.Empty:
time.sleep(0.1)
timeout_count += 1
else:
time.sleep(0.1)
timeout_count += 1
# SDKを閉じて最終的なMP4を生成
print("SDKを閉じて最終的なMP4を生成中...")
sdk.close() # ワーカー join & mp4 結合
# 処理完了
elapsed_total = time.time() - start_time
yield None, gr.Video(tmp_out), f"✅ 完了! 総フレーム数: {frame_count}, 処理時間: {elapsed_total:.1f}秒"
except Exception as e:
import traceback
error_msg = f"❌ エラー: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
yield None, None, error_msg
# Gradio UI
with gr.Blocks(title="DittoTalkingHead Streaming") as demo:
gr.Markdown("""
# DittoTalkingHead - ストリーミング版
音声をリアルタイムで処理し、生成されたフレームを逐次表示します。
## 使い方
1. **ソース画像**(PNG/JPG形式)をアップロード
2. **Start**ボタンをクリックしてマイク録音開始
3. 録音中、ライブフレームが更新されます
4. 録音停止後、最終的なMP4が生成されます
""")
with gr.Row():
with gr.Column():
img_in = gr.Image(
type="filepath",
label="ソース画像 / Source Image",
value=str(EXAMPLES_DIR / "reference.png") if (EXAMPLES_DIR / "reference.png").exists() else None
)
mic_in = gr.Audio(
sources=["microphone"],
streaming=True,
label="マイク入力 (16 kHz)",
format="wav"
)
with gr.Column():
live_img = gr.Image(label="ライブフレーム", type="pil")
final_mp4 = gr.Video(label="最終結果 (MP4)")
status_text = gr.Textbox(label="ステータス", value="待機中...")
btn = gr.Button("Start Streaming", variant="primary")
# ストリーミング処理を開始
btn.click(
fn=generator,
inputs=[mic_in, img_in],
outputs=[live_img, final_mp4, status_text],
stream_every=0.1 # 100msごとに更新
)
# サンプル
if EXAMPLES_DIR.exists():
gr.Examples(
examples=[
[str(EXAMPLES_DIR / "reference.png")]
],
inputs=[img_in],
label="サンプル画像"
)
# 起動設定
if __name__ == "__main__":
# GPU最適化設定
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
# 環境変数設定
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
print("=== DittoTalkingHead ストリーミング版 起動 ===")
print(f"- チャンクサイズ: {CHUNK_SEC}秒")
print(f"- 最大解像度: 1024px")
print(f"- GPU: {'利用可能' if torch.cuda.is_available() else '利用不可'}")
# モデルの事前ロード
print("モデルを事前ロード中...")
init_sdk()
print("✅ モデルロード完了")
demo.queue(concurrency_count=1, max_size=8).launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)]
) |