Spaces:
Runtime error
Runtime error
File size: 8,124 Bytes
910f2aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
from gradio_client import Client, handle_file
from datetime import datetime
import os
import shutil
import logging
import time
from typing import Tuple, Optional
class TalkingHeadAPIClient:
"""DittoTalkingHead API クライアント"""
def __init__(self, space_name: str = "O-ken5481/talkingAvater_bgk", max_retries: int = 3, retry_delay: int = 5):
"""
Args:
space_name: Hugging Face SpaceのID(デフォルト: O-ken5481/talkingAvater_bgk)
max_retries: 最大リトライ回数
retry_delay: リトライ間隔(秒)
"""
self.space_name = space_name
self.max_retries = max_retries
self.retry_delay = retry_delay
self.logger = self._setup_logger()
self.client = None
self._connect()
def _setup_logger(self) -> logging.Logger:
"""ロガーの設定"""
logger = logging.getLogger('TalkingHeadAPIClient')
logger.setLevel(logging.INFO)
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def _connect(self) -> None:
"""APIへの接続"""
for attempt in range(self.max_retries):
try:
self.logger.info(f"接続開始: {self.space_name} (試行 {attempt + 1}/{self.max_retries})")
self.client = Client(self.space_name)
self.logger.info("接続成功")
return
except Exception as e:
self.logger.error(f"接続失敗: {e}")
if attempt < self.max_retries - 1:
self.logger.info(f"{self.retry_delay}秒後にリトライします...")
time.sleep(self.retry_delay)
else:
raise ConnectionError(f"APIへの接続に失敗しました: {e}")
def generate_video(self, audio_path: str, image_path: str) -> Tuple[Optional[dict], str]:
"""
API経由で動画生成
Args:
audio_path: 音声ファイルのパス
image_path: 画像ファイルのパス
Returns:
tuple: (video_data, status_message)
"""
# ファイルの存在確認
if not os.path.exists(audio_path):
error_msg = f"音声ファイルが見つかりません: {audio_path}"
self.logger.error(error_msg)
return None, error_msg
if not os.path.exists(image_path):
error_msg = f"画像ファイルが見つかりません: {image_path}"
self.logger.error(error_msg)
return None, error_msg
# API呼び出し
for attempt in range(self.max_retries):
try:
self.logger.info(f"ファイルアップロード: {audio_path}, {image_path}")
self.logger.info("処理開始...")
result = self.client.predict(
audio_file=handle_file(audio_path),
source_image=handle_file(image_path),
api_name="/process_talking_head"
)
self.logger.info("動画生成完了")
return result
except Exception as e:
self.logger.error(f"処理エラー (試行 {attempt + 1}/{self.max_retries}): {e}")
if attempt < self.max_retries - 1:
self.logger.info(f"{self.retry_delay}秒後にリトライします...")
time.sleep(self.retry_delay)
else:
error_msg = f"動画生成に失敗しました: {e}"
return None, error_msg
def save_with_timestamp(self, video_path: str, output_dir: str = "example") -> Optional[str]:
"""
動画をタイムスタンプ付きで保存
Args:
video_path: 生成された動画のパス
output_dir: 保存先ディレクトリ
Returns:
str: 保存されたファイルパス(エラー時はNone)
"""
try:
# 動画パスの確認
if not video_path or not os.path.exists(video_path):
self.logger.error(f"動画ファイルが見つかりません: {video_path}")
return None
# 出力ディレクトリの作成
os.makedirs(output_dir, exist_ok=True)
# YYYY-MM-DD_HH-MM-SS.mp4 形式で保存
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_path = os.path.join(output_dir, f"{timestamp}.mp4")
# ファイルをコピー
shutil.copy2(video_path, output_path)
# ファイルサイズの確認
file_size = os.path.getsize(output_path)
self.logger.info(f"保存完了: {output_path} (サイズ: {file_size:,} bytes)")
return output_path
except Exception as e:
self.logger.error(f"保存エラー: {e}")
return None
def process_with_save(self, audio_path: str, image_path: str, output_dir: str = "example") -> Tuple[Optional[str], str]:
"""
動画生成と保存を一括実行
Args:
audio_path: 音声ファイルのパス
image_path: 画像ファイルのパス
output_dir: 保存先ディレクトリ
Returns:
tuple: (saved_path, status_message)
"""
# 動画生成
result = self.generate_video(audio_path, image_path)
if result[0] is None:
return None, result[1]
video_data, status = result
# 動画の保存
if isinstance(video_data, dict) and 'video' in video_data:
saved_path = self.save_with_timestamp(video_data['video'], output_dir)
if saved_path:
return saved_path, f"{status}\n保存先: {saved_path}"
else:
return None, f"{status}\n保存に失敗しました"
else:
return None, f"予期しないレスポンス形式: {video_data}"
def main():
"""テストスクリプトのメイン関数"""
# ロギング設定
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# クライアント初期化
try:
client = TalkingHeadAPIClient()
except Exception as e:
logging.error(f"クライアント初期化失敗: {e}")
return
# サンプルファイルを使用
audio_path = "example/audio.wav"
image_path = "example/image.png"
# ファイルの存在確認
if not os.path.exists(audio_path):
logging.error(f"音声ファイルが見つかりません: {audio_path}")
return
if not os.path.exists(image_path):
logging.error(f"画像ファイルが見つかりません: {image_path}")
return
try:
# 動画生成と保存
saved_path, status = client.process_with_save(audio_path, image_path)
if saved_path:
print(f"\n✅ 成功!")
print(f"ステータス: {status}")
print(f"動画を確認してください: {saved_path}")
else:
print(f"\n❌ 失敗")
print(f"ステータス: {status}")
except KeyboardInterrupt:
logging.info("処理を中断しました")
except Exception as e:
logging.error(f"予期しないエラー: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main() |