Spaces:
Runtime error
Runtime error
File size: 5,252 Bytes
43f5a2b f4998a2 146df9e 55535c7 0a00365 f4998a2 43f5a2b 4b17b0c 43f5a2b f4998a2 43f5a2b f4998a2 43f5a2b 55535c7 43f5a2b 146df9e 43f5a2b 0a00365 55535c7 43f5a2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import gradio as gr
import os
import tempfile
import shutil
from pathlib import Path
from model_manager import ModelManager
# サンプルファイルのディレクトリを定義(絶対パスに解決)
EXAMPLES_DIR = (Path(__file__).parent / "example").resolve()
OUTPUT_DIR = (Path(__file__).parent / "output").resolve()
# 出力ディレクトリの作成
OUTPUT_DIR.mkdir(exist_ok=True)
# インポートエラーのデバッグ情報を表示
try:
import filetype
print("✅ filetype module imported successfully")
except ImportError as e:
print(f"⚠️ filetype import failed: {e}")
print("Using fallback file type detection")
# モデルの初期化
print("=== モデルの初期化開始 ===")
# PyTorchモデルを使用(TensorRTモデルは非常に大きいため)
USE_PYTORCH = True
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
if not model_manager.setup_models():
raise RuntimeError("モデルのセットアップに失敗しました。")
# SDKの初期化
if USE_PYTORCH:
data_root = "./checkpoints/ditto_pytorch"
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
else:
data_root = "./checkpoints/ditto_trt_Ampere_Plus"
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"
# SDK初期化のためのグローバル変数
SDK = None
try:
# モジュールをインポート
from stream_pipeline_offline import StreamSDK
from inference import run, seed_everything
SDK = StreamSDK(cfg_pkl, data_root)
print("✅ SDK初期化成功")
except Exception as e:
print(f"❌ SDK初期化エラー: {e}")
import traceback
traceback.print_exc()
raise
def process_talking_head(audio_file, source_image):
"""音声とソース画像からTalking Headビデオを生成"""
if audio_file is None:
return None, "音声ファイルをアップロードしてください。"
if source_image is None:
return None, "ソース画像をアップロードしてください。"
try:
# 出力ファイルの作成(出力ディレクトリ内)
import uuid
output_filename = f"{uuid.uuid4()}.mp4"
output_path = str(OUTPUT_DIR / output_filename)
# 処理実行
print(f"処理開始: audio={audio_file}, image={source_image}")
seed_everything(1024)
run(SDK, audio_file, source_image, output_path)
# 結果の確認
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
return output_path, "✅ 処理が完了しました!"
else:
return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。"
except Exception as e:
import traceback
error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
return None, error_msg
# Gradio UI
with gr.Blocks(title="DittoTalkingHead") as demo:
gr.Markdown("""
# DittoTalkingHead - Talking Head Generation
音声とソース画像から、リアルなTalking Headビデオを生成します。
## 使い方
1. **音声ファイル**(WAV形式)をアップロード
2. **ソース画像**(PNG/JPG形式)をアップロード
3. **生成**ボタンをクリック
⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。
### 技術仕様
- **モデル**: DittoTalkingHead (PyTorch版)
- **GPU**: NVIDIA A100推奨
- **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead)
""")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
label="音声ファイル (WAV)",
type="filepath"
)
image_input = gr.Image(
label="ソース画像",
type="filepath"
)
generate_btn = gr.Button("生成", variant="primary")
with gr.Column():
video_output = gr.Video(
label="生成されたビデオ"
)
status_output = gr.Textbox(
label="ステータス",
lines=3
)
# サンプル
example_audio = EXAMPLES_DIR / "audio.wav"
example_image = EXAMPLES_DIR / "image.png"
if example_audio.exists() and example_image.exists():
gr.Examples(
examples=[
[str(example_audio), str(example_image)]
],
inputs=[audio_input, image_input],
outputs=[video_output, status_output],
fn=process_talking_head,
cache_examples=True
)
# イベントハンドラ
generate_btn.click(
fn=process_talking_head,
inputs=[audio_input, image_input],
outputs=[video_output, status_output]
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)]
) |