Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import tempfile | |
import shutil | |
from pathlib import Path | |
from model_manager import ModelManager | |
# サンプルファイルのディレクトリを定義(絶対パスに解決) | |
EXAMPLES_DIR = (Path(__file__).parent / "example").resolve() | |
OUTPUT_DIR = (Path(__file__).parent / "output").resolve() | |
# 出力ディレクトリの作成 | |
OUTPUT_DIR.mkdir(exist_ok=True) | |
# インポートエラーのデバッグ情報を表示 | |
try: | |
import filetype | |
print("✅ filetype module imported successfully") | |
except ImportError as e: | |
print(f"⚠️ filetype import failed: {e}") | |
print("Using fallback file type detection") | |
# モデルの初期化 | |
print("=== モデルの初期化開始 ===") | |
# PyTorchモデルを使用(TensorRTモデルは非常に大きいため) | |
USE_PYTORCH = True | |
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH) | |
if not model_manager.setup_models(): | |
raise RuntimeError("モデルのセットアップに失敗しました。") | |
# SDKの初期化 | |
if USE_PYTORCH: | |
data_root = "./checkpoints/ditto_pytorch" | |
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" | |
else: | |
data_root = "./checkpoints/ditto_trt_Ampere_Plus" | |
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" | |
# SDK初期化のためのグローバル変数 | |
SDK = None | |
try: | |
# モジュールをインポート | |
from stream_pipeline_offline import StreamSDK | |
from inference import run, seed_everything | |
SDK = StreamSDK(cfg_pkl, data_root) | |
print("✅ SDK初期化成功") | |
except Exception as e: | |
print(f"❌ SDK初期化エラー: {e}") | |
import traceback | |
traceback.print_exc() | |
raise | |
def process_talking_head(audio_file, source_image): | |
"""音声とソース画像からTalking Headビデオを生成""" | |
if audio_file is None: | |
return None, "音声ファイルをアップロードしてください。" | |
if source_image is None: | |
return None, "ソース画像をアップロードしてください。" | |
try: | |
# 出力ファイルの作成(出力ディレクトリ内) | |
import uuid | |
output_filename = f"{uuid.uuid4()}.mp4" | |
output_path = str(OUTPUT_DIR / output_filename) | |
# 処理実行 | |
print(f"処理開始: audio={audio_file}, image={source_image}") | |
seed_everything(1024) | |
run(SDK, audio_file, source_image, output_path) | |
# 結果の確認 | |
if os.path.exists(output_path) and os.path.getsize(output_path) > 0: | |
return output_path, "✅ 処理が完了しました!" | |
else: | |
return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。" | |
except Exception as e: | |
import traceback | |
error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}" | |
print(error_msg) | |
return None, error_msg | |
# Gradio UI | |
with gr.Blocks(title="DittoTalkingHead") as demo: | |
gr.Markdown(""" | |
# DittoTalkingHead - Talking Head Generation | |
音声とソース画像から、リアルなTalking Headビデオを生成します。 | |
## 使い方 | |
1. **音声ファイル**(WAV形式)をアップロード | |
2. **ソース画像**(PNG/JPG形式)をアップロード | |
3. **生成**ボタンをクリック | |
⚠️ 初回実行時は、モデルのダウンロードのため時間がかかります(約2.5GB)。 | |
### 技術仕様 | |
- **モデル**: DittoTalkingHead (PyTorch版) | |
- **GPU**: NVIDIA A100推奨 | |
- **モデル提供**: [digital-avatar/ditto-talkinghead](https://huggingface.co/digital-avatar/ditto-talkinghead) | |
""") | |
with gr.Row(): | |
with gr.Column(): | |
audio_input = gr.Audio( | |
label="音声ファイル (WAV)", | |
type="filepath" | |
) | |
image_input = gr.Image( | |
label="ソース画像", | |
type="filepath" | |
) | |
generate_btn = gr.Button("生成", variant="primary") | |
with gr.Column(): | |
video_output = gr.Video( | |
label="生成されたビデオ" | |
) | |
status_output = gr.Textbox( | |
label="ステータス", | |
lines=3 | |
) | |
# サンプル | |
example_audio = EXAMPLES_DIR / "audio.wav" | |
example_image = EXAMPLES_DIR / "image.png" | |
if example_audio.exists() and example_image.exists(): | |
gr.Examples( | |
examples=[ | |
[str(example_audio), str(example_image)] | |
], | |
inputs=[audio_input, image_input], | |
outputs=[video_output, status_output], | |
fn=process_talking_head, | |
cache_examples=True | |
) | |
# イベントハンドラ | |
generate_btn.click( | |
fn=process_talking_head, | |
inputs=[audio_input, image_input], | |
outputs=[video_output, status_output] | |
) | |
if __name__ == "__main__": | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=False, | |
allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)] | |
) |