oKen38461's picture
`.gitignore`に`docs/`フォルダを追加して、無視するファイルを更新
43f5a2b
raw
history blame
4.28 kB
import gradio as gr
import os
import tempfile
import shutil
from pathlib import Path
from model_manager import ModelManager
from stream_pipeline_offline import StreamSDK
from inference import run, seed_everything
# モデルの初期化
print("=== モデルの初期化開始 ===")
# PyTorchモデルを使用(TensorRTモデルは非常に大きいため)
USE_PYTORCH = True
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
if not model_manager.setup_models():
raise RuntimeError("モデルのセットアップに失敗しました。")
# SDKの初期化
if USE_PYTORCH:
data_root = "./checkpoints/ditto_pytorch"
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
else:
data_root = "./checkpoints/ditto_trt_Ampere_Plus"
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"
try:
SDK = StreamSDK(cfg_pkl, data_root)
print("✅ SDK初期化成功")
except Exception as e:
print(f"❌ SDK初期化エラー: {e}")
raise
def process_talking_head(audio_file, source_image):
"""音声とソース画像からTalking Headビデオを生成"""
if audio_file is None:
return None, "音声ファイルをアップロードしてください。"
if source_image is None:
return None, "ソース画像をアップロードしてください。"
try:
# 一時ファイルの作成
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_output:
output_path = tmp_output.name
# 処理実行
print(f"処理開始: audio={audio_file}, image={source_image}")
seed_everything(1024)
run(SDK, audio_file, source_image, output_path)
# 結果の確認
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
return output_path, "✅ 処理が完了しました!"
else:
return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。"
except Exception as e:
import traceback
error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
return None, error_msg
# Gradio UI
with gr.Blocks(title="DittoTalkingHead") as demo:
gr.Markdown("""
# DittoTalkingHead - Talking Head Generation
音声とソース画像から、リアルなTalking Headビデオを生成します。
## 使い方
1. **音声ファイル**(WAV形式)をアップロード
2. **ソース画像**(PNG/JPG形式)をアップロード
3. **生成**ボタンをクリック
⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。
### 技術仕様
- **モデル**: DittoTalkingHead (PyTorch版)
- **GPU**: NVIDIA A100推奨
- **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead)
""")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
label="音声ファイル (WAV)",
type="filepath"
)
image_input = gr.Image(
label="ソース画像",
type="filepath"
)
generate_btn = gr.Button("生成", variant="primary")
with gr.Column():
video_output = gr.Video(
label="生成されたビデオ"
)
status_output = gr.Textbox(
label="ステータス",
lines=3
)
# サンプル
gr.Examples(
examples=[
["example/audio.wav", "example/image.png"]
],
inputs=[audio_input, image_input],
outputs=[video_output, status_output],
fn=process_talking_head,
cache_examples=True
)
# イベントハンドラ
generate_btn.click(
fn=process_talking_head,
inputs=[audio_input, image_input],
outputs=[video_output, status_output]
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)