Spaces:
Runtime error
Runtime error
File size: 4,278 Bytes
43f5a2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import gradio as gr
import os
import tempfile
import shutil
from pathlib import Path
from model_manager import ModelManager
from stream_pipeline_offline import StreamSDK
from inference import run, seed_everything
# モデルの初期化
print("=== モデルの初期化開始 ===")
# PyTorchモデルを使用(TensorRTモデルは非常に大きいため)
USE_PYTORCH = True
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
if not model_manager.setup_models():
raise RuntimeError("モデルのセットアップに失敗しました。")
# SDKの初期化
if USE_PYTORCH:
data_root = "./checkpoints/ditto_pytorch"
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
else:
data_root = "./checkpoints/ditto_trt_Ampere_Plus"
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"
try:
SDK = StreamSDK(cfg_pkl, data_root)
print("✅ SDK初期化成功")
except Exception as e:
print(f"❌ SDK初期化エラー: {e}")
raise
def process_talking_head(audio_file, source_image):
"""音声とソース画像からTalking Headビデオを生成"""
if audio_file is None:
return None, "音声ファイルをアップロードしてください。"
if source_image is None:
return None, "ソース画像をアップロードしてください。"
try:
# 一時ファイルの作成
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_output:
output_path = tmp_output.name
# 処理実行
print(f"処理開始: audio={audio_file}, image={source_image}")
seed_everything(1024)
run(SDK, audio_file, source_image, output_path)
# 結果の確認
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
return output_path, "✅ 処理が完了しました!"
else:
return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。"
except Exception as e:
import traceback
error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
return None, error_msg
# Gradio UI
with gr.Blocks(title="DittoTalkingHead") as demo:
gr.Markdown("""
# DittoTalkingHead - Talking Head Generation
音声とソース画像から、リアルなTalking Headビデオを生成します。
## 使い方
1. **音声ファイル**(WAV形式)をアップロード
2. **ソース画像**(PNG/JPG形式)をアップロード
3. **生成**ボタンをクリック
⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。
### 技術仕様
- **モデル**: DittoTalkingHead (PyTorch版)
- **GPU**: NVIDIA A100推奨
- **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead)
""")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
label="音声ファイル (WAV)",
type="filepath"
)
image_input = gr.Image(
label="ソース画像",
type="filepath"
)
generate_btn = gr.Button("生成", variant="primary")
with gr.Column():
video_output = gr.Video(
label="生成されたビデオ"
)
status_output = gr.Textbox(
label="ステータス",
lines=3
)
# サンプル
gr.Examples(
examples=[
["example/audio.wav", "example/image.png"]
],
inputs=[audio_input, image_input],
outputs=[video_output, status_output],
fn=process_talking_head,
cache_examples=True
)
# イベントハンドラ
generate_btn.click(
fn=process_talking_head,
inputs=[audio_input, image_input],
outputs=[video_output, status_output]
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
) |