File size: 8,124 Bytes
910f2aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
from gradio_client import Client, handle_file
from datetime import datetime
import os
import shutil
import logging
import time
from typing import Tuple, Optional

class TalkingHeadAPIClient:
    """DittoTalkingHead API クライアント"""
    
    def __init__(self, space_name: str = "O-ken5481/talkingAvater_bgk", max_retries: int = 3, retry_delay: int = 5):
        """
        Args:
            space_name: Hugging Face SpaceのID(デフォルト: O-ken5481/talkingAvater_bgk)
            max_retries: 最大リトライ回数
            retry_delay: リトライ間隔(秒)
        """
        self.space_name = space_name
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.logger = self._setup_logger()
        self.client = None
        self._connect()
    
    def _setup_logger(self) -> logging.Logger:
        """ロガーの設定"""
        logger = logging.getLogger('TalkingHeadAPIClient')
        logger.setLevel(logging.INFO)
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', 
                                        datefmt='%Y-%m-%d %H:%M:%S')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        
        return logger
    
    def _connect(self) -> None:
        """APIへの接続"""
        for attempt in range(self.max_retries):
            try:
                self.logger.info(f"接続開始: {self.space_name} (試行 {attempt + 1}/{self.max_retries})")
                self.client = Client(self.space_name)
                self.logger.info("接続成功")
                return
            except Exception as e:
                self.logger.error(f"接続失敗: {e}")
                if attempt < self.max_retries - 1:
                    self.logger.info(f"{self.retry_delay}秒後にリトライします...")
                    time.sleep(self.retry_delay)
                else:
                    raise ConnectionError(f"APIへの接続に失敗しました: {e}")
    
    def generate_video(self, audio_path: str, image_path: str) -> Tuple[Optional[dict], str]:
        """
        API経由で動画生成
        
        Args:
            audio_path: 音声ファイルのパス
            image_path: 画像ファイルのパス
            
        Returns:
            tuple: (video_data, status_message)
        """
        # ファイルの存在確認
        if not os.path.exists(audio_path):
            error_msg = f"音声ファイルが見つかりません: {audio_path}"
            self.logger.error(error_msg)
            return None, error_msg
        
        if not os.path.exists(image_path):
            error_msg = f"画像ファイルが見つかりません: {image_path}"
            self.logger.error(error_msg)
            return None, error_msg
        
        # API呼び出し
        for attempt in range(self.max_retries):
            try:
                self.logger.info(f"ファイルアップロード: {audio_path}, {image_path}")
                self.logger.info("処理開始...")
                
                result = self.client.predict(
                    audio_file=handle_file(audio_path),
                    source_image=handle_file(image_path),
                    api_name="/process_talking_head"
                )
                
                self.logger.info("動画生成完了")
                return result
                
            except Exception as e:
                self.logger.error(f"処理エラー (試行 {attempt + 1}/{self.max_retries}): {e}")
                if attempt < self.max_retries - 1:
                    self.logger.info(f"{self.retry_delay}秒後にリトライします...")
                    time.sleep(self.retry_delay)
                else:
                    error_msg = f"動画生成に失敗しました: {e}"
                    return None, error_msg
    
    def save_with_timestamp(self, video_path: str, output_dir: str = "example") -> Optional[str]:
        """
        動画をタイムスタンプ付きで保存
        
        Args:
            video_path: 生成された動画のパス
            output_dir: 保存先ディレクトリ
            
        Returns:
            str: 保存されたファイルパス(エラー時はNone)
        """
        try:
            # 動画パスの確認
            if not video_path or not os.path.exists(video_path):
                self.logger.error(f"動画ファイルが見つかりません: {video_path}")
                return None
            
            # 出力ディレクトリの作成
            os.makedirs(output_dir, exist_ok=True)
            
            # YYYY-MM-DD_HH-MM-SS.mp4 形式で保存
            timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            output_path = os.path.join(output_dir, f"{timestamp}.mp4")
            
            # ファイルをコピー
            shutil.copy2(video_path, output_path)
            
            # ファイルサイズの確認
            file_size = os.path.getsize(output_path)
            self.logger.info(f"保存完了: {output_path} (サイズ: {file_size:,} bytes)")
            
            return output_path
            
        except Exception as e:
            self.logger.error(f"保存エラー: {e}")
            return None
    
    def process_with_save(self, audio_path: str, image_path: str, output_dir: str = "example") -> Tuple[Optional[str], str]:
        """
        動画生成と保存を一括実行
        
        Args:
            audio_path: 音声ファイルのパス
            image_path: 画像ファイルのパス
            output_dir: 保存先ディレクトリ
            
        Returns:
            tuple: (saved_path, status_message)
        """
        # 動画生成
        result = self.generate_video(audio_path, image_path)
        
        if result[0] is None:
            return None, result[1]
        
        video_data, status = result
        
        # 動画の保存
        if isinstance(video_data, dict) and 'video' in video_data:
            saved_path = self.save_with_timestamp(video_data['video'], output_dir)
            if saved_path:
                return saved_path, f"{status}\n保存先: {saved_path}"
            else:
                return None, f"{status}\n保存に失敗しました"
        else:
            return None, f"予期しないレスポンス形式: {video_data}"


def main():
    """テストスクリプトのメイン関数"""
    # ロギング設定
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    
    # クライアント初期化
    try:
        client = TalkingHeadAPIClient()
    except Exception as e:
        logging.error(f"クライアント初期化失敗: {e}")
        return
    
    # サンプルファイルを使用
    audio_path = "example/audio.wav"
    image_path = "example/image.png"
    
    # ファイルの存在確認
    if not os.path.exists(audio_path):
        logging.error(f"音声ファイルが見つかりません: {audio_path}")
        return
    
    if not os.path.exists(image_path):
        logging.error(f"画像ファイルが見つかりません: {image_path}")
        return
    
    try:
        # 動画生成と保存
        saved_path, status = client.process_with_save(audio_path, image_path)
        
        if saved_path:
            print(f"\n✅ 成功!")
            print(f"ステータス: {status}")
            print(f"動画を確認してください: {saved_path}")
        else:
            print(f"\n❌ 失敗")
            print(f"ステータス: {status}")
            
    except KeyboardInterrupt:
        logging.info("処理を中断しました")
    except Exception as e:
        logging.error(f"予期しないエラー: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()