talkingAvater_bgk / test_api_client.py
oKen38461's picture
Phase2クリア_README.mdにTalking Head生成の使い方と技術仕様を追加し、requirements.txtにgradioとgradio_clientのバージョンを指定しました。
910f2aa
raw
history blame
8.12 kB
from gradio_client import Client, handle_file
from datetime import datetime
import os
import shutil
import logging
import time
from typing import Tuple, Optional
class TalkingHeadAPIClient:
"""DittoTalkingHead API クライアント"""
def __init__(self, space_name: str = "O-ken5481/talkingAvater_bgk", max_retries: int = 3, retry_delay: int = 5):
"""
Args:
space_name: Hugging Face SpaceのID(デフォルト: O-ken5481/talkingAvater_bgk)
max_retries: 最大リトライ回数
retry_delay: リトライ間隔(秒)
"""
self.space_name = space_name
self.max_retries = max_retries
self.retry_delay = retry_delay
self.logger = self._setup_logger()
self.client = None
self._connect()
def _setup_logger(self) -> logging.Logger:
"""ロガーの設定"""
logger = logging.getLogger('TalkingHeadAPIClient')
logger.setLevel(logging.INFO)
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def _connect(self) -> None:
"""APIへの接続"""
for attempt in range(self.max_retries):
try:
self.logger.info(f"接続開始: {self.space_name} (試行 {attempt + 1}/{self.max_retries})")
self.client = Client(self.space_name)
self.logger.info("接続成功")
return
except Exception as e:
self.logger.error(f"接続失敗: {e}")
if attempt < self.max_retries - 1:
self.logger.info(f"{self.retry_delay}秒後にリトライします...")
time.sleep(self.retry_delay)
else:
raise ConnectionError(f"APIへの接続に失敗しました: {e}")
def generate_video(self, audio_path: str, image_path: str) -> Tuple[Optional[dict], str]:
"""
API経由で動画生成
Args:
audio_path: 音声ファイルのパス
image_path: 画像ファイルのパス
Returns:
tuple: (video_data, status_message)
"""
# ファイルの存在確認
if not os.path.exists(audio_path):
error_msg = f"音声ファイルが見つかりません: {audio_path}"
self.logger.error(error_msg)
return None, error_msg
if not os.path.exists(image_path):
error_msg = f"画像ファイルが見つかりません: {image_path}"
self.logger.error(error_msg)
return None, error_msg
# API呼び出し
for attempt in range(self.max_retries):
try:
self.logger.info(f"ファイルアップロード: {audio_path}, {image_path}")
self.logger.info("処理開始...")
result = self.client.predict(
audio_file=handle_file(audio_path),
source_image=handle_file(image_path),
api_name="/process_talking_head"
)
self.logger.info("動画生成完了")
return result
except Exception as e:
self.logger.error(f"処理エラー (試行 {attempt + 1}/{self.max_retries}): {e}")
if attempt < self.max_retries - 1:
self.logger.info(f"{self.retry_delay}秒後にリトライします...")
time.sleep(self.retry_delay)
else:
error_msg = f"動画生成に失敗しました: {e}"
return None, error_msg
def save_with_timestamp(self, video_path: str, output_dir: str = "example") -> Optional[str]:
"""
動画をタイムスタンプ付きで保存
Args:
video_path: 生成された動画のパス
output_dir: 保存先ディレクトリ
Returns:
str: 保存されたファイルパス(エラー時はNone)
"""
try:
# 動画パスの確認
if not video_path or not os.path.exists(video_path):
self.logger.error(f"動画ファイルが見つかりません: {video_path}")
return None
# 出力ディレクトリの作成
os.makedirs(output_dir, exist_ok=True)
# YYYY-MM-DD_HH-MM-SS.mp4 形式で保存
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_path = os.path.join(output_dir, f"{timestamp}.mp4")
# ファイルをコピー
shutil.copy2(video_path, output_path)
# ファイルサイズの確認
file_size = os.path.getsize(output_path)
self.logger.info(f"保存完了: {output_path} (サイズ: {file_size:,} bytes)")
return output_path
except Exception as e:
self.logger.error(f"保存エラー: {e}")
return None
def process_with_save(self, audio_path: str, image_path: str, output_dir: str = "example") -> Tuple[Optional[str], str]:
"""
動画生成と保存を一括実行
Args:
audio_path: 音声ファイルのパス
image_path: 画像ファイルのパス
output_dir: 保存先ディレクトリ
Returns:
tuple: (saved_path, status_message)
"""
# 動画生成
result = self.generate_video(audio_path, image_path)
if result[0] is None:
return None, result[1]
video_data, status = result
# 動画の保存
if isinstance(video_data, dict) and 'video' in video_data:
saved_path = self.save_with_timestamp(video_data['video'], output_dir)
if saved_path:
return saved_path, f"{status}\n保存先: {saved_path}"
else:
return None, f"{status}\n保存に失敗しました"
else:
return None, f"予期しないレスポンス形式: {video_data}"
def main():
"""テストスクリプトのメイン関数"""
# ロギング設定
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# クライアント初期化
try:
client = TalkingHeadAPIClient()
except Exception as e:
logging.error(f"クライアント初期化失敗: {e}")
return
# サンプルファイルを使用
audio_path = "example/audio.wav"
image_path = "example/image.png"
# ファイルの存在確認
if not os.path.exists(audio_path):
logging.error(f"音声ファイルが見つかりません: {audio_path}")
return
if not os.path.exists(image_path):
logging.error(f"画像ファイルが見つかりません: {image_path}")
return
try:
# 動画生成と保存
saved_path, status = client.process_with_save(audio_path, image_path)
if saved_path:
print(f"\n✅ 成功!")
print(f"ステータス: {status}")
print(f"動画を確認してください: {saved_path}")
else:
print(f"\n❌ 失敗")
print(f"ステータス: {status}")
except KeyboardInterrupt:
logging.info("処理を中断しました")
except Exception as e:
logging.error(f"予期しないエラー: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()