oKen38461 commited on
Commit
0f839d2
·
1 Parent(s): 07b71bb

`.gitignore`に`tests/`を追加し、`README.md`のAPIドキュメントセクションを更新しました。また、`test_api_client.py`、`test_api.py`、`test_performance_optimized.py`、`test_performance.py`のテストスクリプトを削除しました。

Browse files
Files changed (6) hide show
  1. .gitignore +1 -0
  2. README.md +7 -3
  3. test_api.py +0 -102
  4. test_api_client.py +0 -220
  5. test_performance.py +0 -175
  6. test_performance_optimized.py +0 -375
.gitignore CHANGED
@@ -38,6 +38,7 @@ log/*
38
  example/
39
  ToDo/
40
  docs/
 
41
 
42
 
43
  !example/audio.wav
 
38
  example/
39
  ToDo/
40
  docs/
41
+ tests/
42
 
43
 
44
  !example/audio.wav
README.md CHANGED
@@ -74,8 +74,12 @@ python test_api_client.py
74
  - **処理速度**: 16秒の音声を約15秒で処理(Phase 3最適化により50-65%高速化)
75
 
76
  ## ドキュメント
77
- - [APIドキュメント](docs/api_documentation.md) - 詳細なAPI仕様とサンプルコード
78
- - [Phase2実装仕様](ToDo/0717-2_Phase2_API_SOW.md) - API実装の詳細
79
- - [Phase3最適化ガイド](docs/phase3_optimization_guide.md) - パフォーマンス最適化の詳細
 
 
 
 
80
 
81
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
74
  - **処理速度**: 16秒の音声を約15秒で処理(Phase 3最適化により50-65%高速化)
75
 
76
  ## ドキュメント
77
+ - 📁 **[APIドキュメント](docs/api/)** - リアルタイムを超える動画生成APIの全ドキュメント
78
+ - 🚀 [統合ガイド](docs/api/integration_guide.md) - 完全なAPIインテグレーションガイド
79
+ - [クイックリファレンス](docs/api/quick_reference.md) - 5分で実装できるクイックスタート
80
+ - 📝 [API仕様書](docs/api/documentation.md) - 詳細なAPI仕様とサンプルコード
81
+ - 💻 [統合サンプル集](docs/api/integration_examples.py) - 実装例とベストプラクティス
82
+ - 📋 [Phase2実装仕様](ToDo/0717-2_Phase2_API_SOW.md) - API実装の詳細
83
+ - 🔧 [Phase3最適化ガイド](docs/phase3_optimization_guide.md) - パフォーマンス最適化の詳細
84
 
85
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
test_api.py DELETED
@@ -1,102 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- DittoTalkingHead API テストスクリプト
4
- 簡単なAPIテストを実行します
5
- """
6
-
7
- import logging
8
- import sys
9
- from test_api_client import TalkingHeadAPIClient
10
-
11
- # ロギング設定
12
- logging.basicConfig(
13
- level=logging.INFO,
14
- format='%(asctime)s - %(message)s',
15
- datefmt='%Y-%m-%d %H:%M:%S'
16
- )
17
-
18
- def test_basic_functionality():
19
- """基本機能のテスト"""
20
- logging.info("=== 基本機能テスト開始 ===")
21
-
22
- # クライアント初期化
23
- client = TalkingHeadAPIClient()
24
-
25
- # サンプルファイルを使用
26
- audio_path = "example/audio.wav"
27
- image_path = "example/image.png"
28
-
29
- try:
30
- # 動画生成
31
- logging.info(f"接続開始: O-ken5481/talkingAvater_bgk")
32
- logging.info(f"ファイルアップロード: {audio_path}, {image_path}")
33
- logging.info("処理開始...")
34
-
35
- result = client.generate_video(audio_path, image_path)
36
- video_path, status = result
37
-
38
- if video_path:
39
- logging.info("動画生成完了")
40
-
41
- # タイムスタンプ付きで保存
42
- if isinstance(video_path, dict) and 'video' in video_path:
43
- saved_path = client.save_with_timestamp(video_path['video'])
44
- if saved_path:
45
- logging.info(f"保存完了: {saved_path}")
46
- print(f"\n✅ テスト成功!")
47
- print(f"ステータス: {status}")
48
- print(f"保存先: {saved_path}")
49
- return True
50
-
51
- print(f"\n❌ テスト失敗")
52
- print(f"ステータス: {status}")
53
- return False
54
-
55
- except Exception as e:
56
- logging.error(f"エラー発生: {e}")
57
- return False
58
-
59
- def test_error_handling():
60
- """エラーハンドリングのテスト"""
61
- logging.info("\n=== エラーハンドリングテスト開始 ===")
62
-
63
- client = TalkingHeadAPIClient()
64
-
65
- # 存在しないファイルでテスト
66
- result = client.generate_video("nonexistent.wav", "nonexistent.png")
67
- video_path, status = result
68
-
69
- if video_path is None and "見つかりません" in status:
70
- logging.info("✅ ファイル不在エラーを正しく検出")
71
- return True
72
- else:
73
- logging.error("❌ エラーハンドリングが正しく動作していません")
74
- return False
75
-
76
- def main():
77
- """メイン関数"""
78
- print("DittoTalkingHead API テスト")
79
- print("=" * 50)
80
-
81
- # 基本機能テスト
82
- basic_test_passed = test_basic_functionality()
83
-
84
- # エラーハンドリングテスト
85
- error_test_passed = test_error_handling()
86
-
87
- # 結果サマリー
88
- print("\n" + "=" * 50)
89
- print("テスト結果:")
90
- print(f"- 基本機能テスト: {'✅ 成功' if basic_test_passed else '❌ 失敗'}")
91
- print(f"- エラーハンドリングテスト: {'✅ 成功' if error_test_passed else '❌ 失敗'}")
92
-
93
- # 終了コード
94
- if basic_test_passed and error_test_passed:
95
- print("\n全てのテストが成功しました! 🎉")
96
- sys.exit(0)
97
- else:
98
- print("\n一部のテストが失敗しました。")
99
- sys.exit(1)
100
-
101
- if __name__ == "__main__":
102
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_api_client.py DELETED
@@ -1,220 +0,0 @@
1
- from gradio_client import Client, handle_file
2
- from datetime import datetime
3
- import os
4
- import shutil
5
- import logging
6
- import time
7
- from typing import Tuple, Optional
8
-
9
- class TalkingHeadAPIClient:
10
- """DittoTalkingHead API クライアント"""
11
-
12
- def __init__(self, space_name: str = "O-ken5481/talkingAvater_bgk", max_retries: int = 3, retry_delay: int = 5):
13
- """
14
- Args:
15
- space_name: Hugging Face SpaceのID(デフォルト: O-ken5481/talkingAvater_bgk)
16
- max_retries: 最大リトライ回数
17
- retry_delay: リトライ間隔(秒)
18
- """
19
- self.space_name = space_name
20
- self.max_retries = max_retries
21
- self.retry_delay = retry_delay
22
- self.logger = self._setup_logger()
23
- self.client = None
24
- self._connect()
25
-
26
- def _setup_logger(self) -> logging.Logger:
27
- """ロガーの設定"""
28
- logger = logging.getLogger('TalkingHeadAPIClient')
29
- logger.setLevel(logging.INFO)
30
-
31
- if not logger.handlers:
32
- handler = logging.StreamHandler()
33
- formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s',
34
- datefmt='%Y-%m-%d %H:%M:%S')
35
- handler.setFormatter(formatter)
36
- logger.addHandler(handler)
37
-
38
- return logger
39
-
40
- def _connect(self) -> None:
41
- """APIへの接続"""
42
- for attempt in range(self.max_retries):
43
- try:
44
- self.logger.info(f"接続開始: {self.space_name} (試行 {attempt + 1}/{self.max_retries})")
45
- self.client = Client(self.space_name)
46
- self.logger.info("接続成功")
47
- return
48
- except Exception as e:
49
- self.logger.error(f"接続失敗: {e}")
50
- if attempt < self.max_retries - 1:
51
- self.logger.info(f"{self.retry_delay}秒後にリトライします...")
52
- time.sleep(self.retry_delay)
53
- else:
54
- raise ConnectionError(f"APIへの接続に失敗しました: {e}")
55
-
56
- def generate_video(self, audio_path: str, image_path: str) -> Tuple[Optional[dict], str]:
57
- """
58
- API経由で動画生成
59
-
60
- Args:
61
- audio_path: 音声ファイルのパス
62
- image_path: 画像ファイルのパス
63
-
64
- Returns:
65
- tuple: (video_data, status_message)
66
- """
67
- # ファイルの存在確認
68
- if not os.path.exists(audio_path):
69
- error_msg = f"音声ファイルが見つかりません: {audio_path}"
70
- self.logger.error(error_msg)
71
- return None, error_msg
72
-
73
- if not os.path.exists(image_path):
74
- error_msg = f"画像ファイルが見つかりません: {image_path}"
75
- self.logger.error(error_msg)
76
- return None, error_msg
77
-
78
- # API呼び出し
79
- for attempt in range(self.max_retries):
80
- try:
81
- self.logger.info(f"ファイルアップロード: {audio_path}, {image_path}")
82
- self.logger.info("処理開始...")
83
-
84
- result = self.client.predict(
85
- audio_file=handle_file(audio_path),
86
- source_image=handle_file(image_path),
87
- api_name="/process_talking_head"
88
- )
89
-
90
- self.logger.info("動画生成完了")
91
- return result
92
-
93
- except Exception as e:
94
- self.logger.error(f"処理エラー (試行 {attempt + 1}/{self.max_retries}): {e}")
95
- if attempt < self.max_retries - 1:
96
- self.logger.info(f"{self.retry_delay}秒後にリトライします...")
97
- time.sleep(self.retry_delay)
98
- else:
99
- error_msg = f"動画生成に失敗しました: {e}"
100
- return None, error_msg
101
-
102
- def save_with_timestamp(self, video_path: str, output_dir: str = "example") -> Optional[str]:
103
- """
104
- 動画をタイムスタンプ付きで保存
105
-
106
- Args:
107
- video_path: 生成された動画のパス
108
- output_dir: 保存先ディレクトリ
109
-
110
- Returns:
111
- str: 保存されたファイルパス(エラー時はNone)
112
- """
113
- try:
114
- # 動画パスの確認
115
- if not video_path or not os.path.exists(video_path):
116
- self.logger.error(f"動画ファイルが見つかりません: {video_path}")
117
- return None
118
-
119
- # 出力ディレクトリの作成
120
- os.makedirs(output_dir, exist_ok=True)
121
-
122
- # YYYY-MM-DD_HH-MM-SS.mp4 形式で保存
123
- timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
124
- output_path = os.path.join(output_dir, f"{timestamp}.mp4")
125
-
126
- # ファイルをコピー
127
- shutil.copy2(video_path, output_path)
128
-
129
- # ファイルサイズの確認
130
- file_size = os.path.getsize(output_path)
131
- self.logger.info(f"保存完了: {output_path} (サイズ: {file_size:,} bytes)")
132
-
133
- return output_path
134
-
135
- except Exception as e:
136
- self.logger.error(f"保存エラー: {e}")
137
- return None
138
-
139
- def process_with_save(self, audio_path: str, image_path: str, output_dir: str = "example") -> Tuple[Optional[str], str]:
140
- """
141
- 動画生成と保存を一括実行
142
-
143
- Args:
144
- audio_path: 音声ファイルのパス
145
- image_path: 画像ファイルのパス
146
- output_dir: 保存先ディレクトリ
147
-
148
- Returns:
149
- tuple: (saved_path, status_message)
150
- """
151
- # 動画生成
152
- result = self.generate_video(audio_path, image_path)
153
-
154
- if result[0] is None:
155
- return None, result[1]
156
-
157
- video_data, status = result
158
-
159
- # 動画の保存
160
- if isinstance(video_data, dict) and 'video' in video_data:
161
- saved_path = self.save_with_timestamp(video_data['video'], output_dir)
162
- if saved_path:
163
- return saved_path, f"{status}\n保存先: {saved_path}"
164
- else:
165
- return None, f"{status}\n保存に失敗しました"
166
- else:
167
- return None, f"予期しないレスポンス形式: {video_data}"
168
-
169
-
170
- def main():
171
- """テストスクリプトのメイン関数"""
172
- # ロギング設定
173
- logging.basicConfig(
174
- level=logging.INFO,
175
- format='%(asctime)s - %(message)s',
176
- datefmt='%Y-%m-%d %H:%M:%S'
177
- )
178
-
179
- # クライアント初期化
180
- try:
181
- client = TalkingHeadAPIClient()
182
- except Exception as e:
183
- logging.error(f"クライアント初期化失敗: {e}")
184
- return
185
-
186
- # サンプルファイルを使用
187
- audio_path = "example/audio.wav"
188
- image_path = "example/image.png"
189
-
190
- # ファイルの存在確認
191
- if not os.path.exists(audio_path):
192
- logging.error(f"音声ファイルが見つかりません: {audio_path}")
193
- return
194
-
195
- if not os.path.exists(image_path):
196
- logging.error(f"画像ファイルが見つかりません: {image_path}")
197
- return
198
-
199
- try:
200
- # 動画生成と保存
201
- saved_path, status = client.process_with_save(audio_path, image_path)
202
-
203
- if saved_path:
204
- print(f"\n✅ 成功!")
205
- print(f"ステータス: {status}")
206
- print(f"動画を確認してください: {saved_path}")
207
- else:
208
- print(f"\n❌ 失敗")
209
- print(f"ステータス: {status}")
210
-
211
- except KeyboardInterrupt:
212
- logging.info("処理を中断しました")
213
- except Exception as e:
214
- logging.error(f"予期しないエラー: {e}")
215
- import traceback
216
- traceback.print_exc()
217
-
218
-
219
- if __name__ == "__main__":
220
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_performance.py DELETED
@@ -1,175 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- パフォーマンステストスクリプト
4
- 動画生成の各ステップの実行時間を計測
5
- """
6
-
7
- import time
8
- import logging
9
- from test_api_client import TalkingHeadAPIClient
10
- import os
11
-
12
- # ロギング設定
13
- logging.basicConfig(
14
- level=logging.INFO,
15
- format='%(asctime)s - %(message)s',
16
- datefmt='%Y-%m-%d %H:%M:%S'
17
- )
18
-
19
- class TimingStats:
20
- def __init__(self):
21
- self.stats = {}
22
- self.start_times = {}
23
-
24
- def start(self, name):
25
- self.start_times[name] = time.time()
26
-
27
- def end(self, name):
28
- if name in self.start_times:
29
- duration = time.time() - self.start_times[name]
30
- self.stats[name] = duration
31
- return duration
32
- return None
33
-
34
- def report(self):
35
- print("\n=== パフォーマンス計測結果 ===")
36
- total_time = sum(self.stats.values())
37
- for name, duration in self.stats.items():
38
- percentage = (duration / total_time) * 100 if total_time > 0 else 0
39
- print(f"{name}: {duration:.2f}秒 ({percentage:.1f}%)")
40
- print(f"\n合計時間: {total_time:.2f}秒")
41
-
42
- # 音声ファイルの長さを取得
43
- try:
44
- import librosa
45
- audio_path = "example/audio.wav"
46
- y, sr = librosa.load(audio_path, sr=None)
47
- audio_duration = len(y) / sr
48
- print(f"音声ファイルの長さ: {audio_duration:.2f}秒")
49
- print(f"処理時間比率: {total_time/audio_duration:.2f}x")
50
- except Exception as e:
51
- print(f"音声長さの取得失敗: {e}")
52
-
53
- def test_performance():
54
- """パフォーマンステストを実行"""
55
- timer = TimingStats()
56
-
57
- # 全体の開始時間
58
- timer.start("全体処理")
59
-
60
- # クライアント初期化
61
- timer.start("API接続")
62
- try:
63
- client = TalkingHeadAPIClient()
64
- timer.end("API接続")
65
- except Exception as e:
66
- logging.error(f"クライアント初期化失敗: {e}")
67
- return
68
-
69
- # サンプルファイル
70
- audio_path = "example/audio.wav"
71
- image_path = "example/image.png"
72
-
73
- # ファイル情報を表示
74
- audio_size = os.path.getsize(audio_path) / 1024 / 1024 # MB
75
- image_size = os.path.getsize(image_path) / 1024 / 1024 # MB
76
- print(f"\n入力ファイル情報:")
77
- print(f"- 音声: {audio_path} ({audio_size:.2f} MB)")
78
- print(f"- 画像: {image_path} ({image_size:.2f} MB)")
79
-
80
- # 動画生成
81
- timer.start("動画生成(API呼び出し)")
82
- try:
83
- result = client.generate_video(audio_path, image_path)
84
- video_data, status = result
85
- timer.end("動画生成(API呼び出し)")
86
-
87
- if video_data:
88
- # 保存処理
89
- timer.start("動画保存")
90
- if isinstance(video_data, dict) and 'video' in video_data:
91
- saved_path = client.save_with_timestamp(video_data['video'])
92
- timer.end("動画保存")
93
-
94
- # 出力ファイル情報
95
- output_size = os.path.getsize(saved_path) / 1024 / 1024 # MB
96
- print(f"\n出力ファイル情報:")
97
- print(f"- 動画: {saved_path} ({output_size:.2f} MB)")
98
-
99
- timer.end("全体処理")
100
- timer.report()
101
-
102
- print(f"\n✅ テスト成功!")
103
- print(f"ステータス: {status}")
104
- else:
105
- print(f"\n❌ テスト失敗")
106
- print(f"ステータス: {status}")
107
-
108
- except Exception as e:
109
- logging.error(f"エラー発生: {e}")
110
- import traceback
111
- traceback.print_exc()
112
-
113
- def test_multiple_runs(runs=3):
114
- """複数回実行して平均時間を計測"""
115
- print(f"\n=== {runs}回連続実行テスト ===")
116
-
117
- times = []
118
- for i in range(runs):
119
- print(f"\n--- 実行 {i+1}/{runs} ---")
120
- start = time.time()
121
-
122
- try:
123
- client = TalkingHeadAPIClient()
124
- result = client.generate_video("example/audio.wav", "example/image.png")
125
- if result[0]:
126
- duration = time.time() - start
127
- times.append(duration)
128
- print(f"実行時間: {duration:.2f}秒")
129
- except Exception as e:
130
- print(f"エラー: {e}")
131
-
132
- if times:
133
- avg_time = sum(times) / len(times)
134
- min_time = min(times)
135
- max_time = max(times)
136
- print(f"\n=== 統計 ===")
137
- print(f"平均時間: {avg_time:.2f}秒")
138
- print(f"最小時間: {min_time:.2f}秒")
139
- print(f"最大時間: {max_time:.2f}秒")
140
-
141
- def analyze_bottlenecks():
142
- """ボトルネック分析のための詳細テスト"""
143
- print("\n=== ボトルネック分析 ===")
144
-
145
- # ローカルファイルの読み込み時間
146
- start = time.time()
147
- with open("example/audio.wav", "rb") as f:
148
- audio_data = f.read()
149
- with open("example/image.png", "rb") as f:
150
- image_data = f.read()
151
- local_read_time = time.time() - start
152
- print(f"ローカルファイル読み込み: {local_read_time:.3f}秒")
153
-
154
- # ネットワーク遅延の推定(Hugging Face Spaceへのping相当)
155
- import requests
156
- start = time.time()
157
- try:
158
- response = requests.get("https://o-ken5481-talkingavater-bgk.hf.space", timeout=10)
159
- network_time = time.time() - start
160
- print(f"ネットワーク遅延(推定): {network_time:.3f}秒")
161
- except:
162
- print("ネットワーク遅延の測定失敗")
163
-
164
- if __name__ == "__main__":
165
- print("DittoTalkingHead パフォーマンステスト")
166
- print("=" * 50)
167
-
168
- # 1. 詳細な時間計測
169
- test_performance()
170
-
171
- # 2. 複数回実行テスト
172
- # test_multiple_runs(3)
173
-
174
- # 3. ボトルネック分析
175
- analyze_bottlenecks()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_performance_optimized.py DELETED
@@ -1,375 +0,0 @@
1
- """
2
- Performance test script for Phase 3 optimizations
3
- Tests various optimization strategies and measures performance improvements
4
- """
5
-
6
- import time
7
- import os
8
- import sys
9
- import numpy as np
10
- from pathlib import Path
11
- import torch
12
- from typing import Dict, List, Tuple
13
- import json
14
- from datetime import datetime
15
-
16
- # Add project root to path
17
- sys.path.append(str(Path(__file__).parent))
18
-
19
- from model_manager import ModelManager
20
- from core.optimization import (
21
- FixedResolutionProcessor,
22
- GPUOptimizer,
23
- AvatarCache,
24
- AvatarTokenManager,
25
- ColdStartOptimizer
26
- )
27
-
28
-
29
- class PerformanceTester:
30
- """Performance testing framework for DittoTalkingHead optimizations"""
31
-
32
- def __init__(self):
33
- self.results = []
34
- self.resolution_optimizer = FixedResolutionProcessor()
35
- self.gpu_optimizer = GPUOptimizer()
36
- self.cold_start_optimizer = ColdStartOptimizer()
37
- self.avatar_cache = AvatarCache()
38
-
39
- # Test configurations
40
- self.test_configs = {
41
- "audio_durations": [4, 8, 16, 32], # seconds
42
- "resolutions": [256, 320, 512], # will test 320 fixed vs others
43
- "optimization_levels": ["none", "gpu_only", "resolution_only", "full"]
44
- }
45
-
46
- def setup_test_environment(self):
47
- """Set up test environment"""
48
- print("=== Setting up test environment ===")
49
-
50
- # Initialize models
51
- USE_PYTORCH = True
52
- model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
53
-
54
- if not model_manager.setup_models():
55
- raise RuntimeError("Failed to setup models")
56
-
57
- # Initialize SDK
58
- if USE_PYTORCH:
59
- data_root = "./checkpoints/ditto_pytorch"
60
- cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
61
- else:
62
- data_root = "./checkpoints/ditto_trt_Ampere_Plus"
63
- cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"
64
-
65
- from stream_pipeline_offline import StreamSDK
66
- self.sdk = StreamSDK(cfg_pkl, data_root)
67
-
68
- print("✅ Test environment ready")
69
-
70
- def generate_test_data(self, duration: int) -> Tuple[str, str]:
71
- """
72
- Generate test audio and image files
73
-
74
- Args:
75
- duration: Audio duration in seconds
76
-
77
- Returns:
78
- Tuple of (audio_path, image_path)
79
- """
80
- import tempfile
81
- from scipy.io import wavfile
82
- from PIL import Image
83
-
84
- # Generate test audio (sine wave)
85
- sample_rate = 16000
86
- t = np.linspace(0, duration, duration * sample_rate)
87
- audio_data = np.sin(2 * np.pi * 440 * t).astype(np.float32) * 0.5
88
-
89
- with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
90
- wavfile.write(tmp.name, sample_rate, audio_data)
91
- audio_path = tmp.name
92
-
93
- # Generate test image
94
- img = Image.new('RGB', (512, 512), color='white')
95
- # Add some features
96
- from PIL import ImageDraw
97
- draw = ImageDraw.Draw(img)
98
- draw.ellipse([156, 156, 356, 356], fill='lightblue') # Face
99
- draw.ellipse([200, 200, 220, 220], fill='black') # Left eye
100
- draw.ellipse([292, 200, 312, 220], fill='black') # Right eye
101
- draw.arc([220, 250, 292, 300], 0, 180, fill='red', width=3) # Mouth
102
-
103
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
104
- img.save(tmp.name)
105
- image_path = tmp.name
106
-
107
- return audio_path, image_path
108
-
109
- def test_baseline(self, audio_duration: int) -> Dict[str, float]:
110
- """
111
- Test baseline performance without optimizations
112
-
113
- Args:
114
- audio_duration: Test audio duration in seconds
115
-
116
- Returns:
117
- Performance metrics
118
- """
119
- print(f"\n--- Testing baseline (no optimizations, {audio_duration}s audio) ---")
120
-
121
- audio_path, image_path = self.generate_test_data(audio_duration)
122
-
123
- try:
124
- # Disable optimizations
125
- torch.backends.cudnn.benchmark = False
126
-
127
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
128
- output_path = tmp.name
129
-
130
- # Run without optimizations
131
- from inference import run, seed_everything
132
- seed_everything(1024)
133
-
134
- start_time = time.time()
135
- run(self.sdk, audio_path, image_path, output_path)
136
- process_time = time.time() - start_time
137
-
138
- # Clean up
139
- for path in [audio_path, image_path, output_path]:
140
- if os.path.exists(path):
141
- os.unlink(path)
142
-
143
- return {
144
- "audio_duration": audio_duration,
145
- "process_time": process_time,
146
- "realtime_factor": process_time / audio_duration,
147
- "optimization": "none"
148
- }
149
-
150
- except Exception as e:
151
- print(f"Error in baseline test: {e}")
152
- return None
153
-
154
- def test_gpu_optimization(self, audio_duration: int) -> Dict[str, float]:
155
- """Test with GPU optimizations only"""
156
- print(f"\n--- Testing GPU optimization ({audio_duration}s audio) ---")
157
-
158
- audio_path, image_path = self.generate_test_data(audio_duration)
159
-
160
- try:
161
- # Apply GPU optimizations
162
- self.gpu_optimizer._setup_cuda_optimizations()
163
-
164
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
165
- output_path = tmp.name
166
-
167
- from inference import run, seed_everything
168
- seed_everything(1024)
169
-
170
- start_time = time.time()
171
- run(self.sdk, audio_path, image_path, output_path)
172
- process_time = time.time() - start_time
173
-
174
- # Clean up
175
- for path in [audio_path, image_path, output_path]:
176
- if os.path.exists(path):
177
- os.unlink(path)
178
-
179
- return {
180
- "audio_duration": audio_duration,
181
- "process_time": process_time,
182
- "realtime_factor": process_time / audio_duration,
183
- "optimization": "gpu_only"
184
- }
185
-
186
- except Exception as e:
187
- print(f"Error in GPU optimization test: {e}")
188
- return None
189
-
190
- def test_resolution_optimization(self, audio_duration: int) -> Dict[str, float]:
191
- """Test with resolution optimization (320x320)"""
192
- print(f"\n--- Testing resolution optimization ({audio_duration}s audio) ---")
193
-
194
- audio_path, image_path = self.generate_test_data(audio_duration)
195
-
196
- try:
197
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
198
- output_path = tmp.name
199
-
200
- # Apply resolution optimization
201
- setup_kwargs = {
202
- "max_size": self.resolution_optimizer.get_max_dim(), # 320
203
- "sampling_timesteps": self.resolution_optimizer.get_diffusion_steps() # 25
204
- }
205
-
206
- from inference import run, seed_everything
207
- seed_everything(1024)
208
-
209
- start_time = time.time()
210
- run(self.sdk, audio_path, image_path, output_path,
211
- more_kwargs={"setup_kwargs": setup_kwargs})
212
- process_time = time.time() - start_time
213
-
214
- # Clean up
215
- for path in [audio_path, image_path, output_path]:
216
- if os.path.exists(path):
217
- os.unlink(path)
218
-
219
- return {
220
- "audio_duration": audio_duration,
221
- "process_time": process_time,
222
- "realtime_factor": process_time / audio_duration,
223
- "optimization": "resolution_only",
224
- "resolution": f"{self.resolution_optimizer.get_max_dim()}x{self.resolution_optimizer.get_max_dim()}"
225
- }
226
-
227
- except Exception as e:
228
- print(f"Error in resolution optimization test: {e}")
229
- return None
230
-
231
- def test_full_optimization(self, audio_duration: int) -> Dict[str, float]:
232
- """Test with all optimizations enabled"""
233
- print(f"\n--- Testing full optimization ({audio_duration}s audio) ---")
234
-
235
- audio_path, image_path = self.generate_test_data(audio_duration)
236
-
237
- try:
238
- # Apply all optimizations
239
- self.gpu_optimizer._setup_cuda_optimizations()
240
-
241
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
242
- output_path = tmp.name
243
-
244
- setup_kwargs = {
245
- "max_size": self.resolution_optimizer.get_max_dim(),
246
- "sampling_timesteps": self.resolution_optimizer.get_diffusion_steps()
247
- }
248
-
249
- from inference import run, seed_everything
250
- seed_everything(1024)
251
-
252
- start_time = time.time()
253
- run(self.sdk, audio_path, image_path, output_path,
254
- more_kwargs={"setup_kwargs": setup_kwargs})
255
- process_time = time.time() - start_time
256
-
257
- # Clean up
258
- for path in [audio_path, image_path, output_path]:
259
- if os.path.exists(path):
260
- os.unlink(path)
261
-
262
- return {
263
- "audio_duration": audio_duration,
264
- "process_time": process_time,
265
- "realtime_factor": process_time / audio_duration,
266
- "optimization": "full",
267
- "resolution": f"{self.resolution_optimizer.get_max_dim()}x{self.resolution_optimizer.get_max_dim()}",
268
- "gpu_optimized": True
269
- }
270
-
271
- except Exception as e:
272
- print(f"Error in full optimization test: {e}")
273
- return None
274
-
275
- def run_comprehensive_test(self):
276
- """Run comprehensive performance tests"""
277
- print("\n" + "="*60)
278
- print("Starting comprehensive performance test")
279
- print("="*60)
280
-
281
- self.setup_test_environment()
282
-
283
- # Test different audio durations and optimization levels
284
- for duration in self.test_configs["audio_durations"]:
285
- print(f"\n{'='*60}")
286
- print(f"Testing with {duration}s audio")
287
- print(f"{'='*60}")
288
-
289
- # Run tests with different optimization levels
290
- tests = [
291
- ("Baseline", self.test_baseline),
292
- ("GPU Only", self.test_gpu_optimization),
293
- ("Resolution Only", self.test_resolution_optimization),
294
- ("Full Optimization", self.test_full_optimization)
295
- ]
296
-
297
- duration_results = []
298
-
299
- for test_name, test_func in tests:
300
- result = test_func(duration)
301
- if result:
302
- duration_results.append(result)
303
- print(f"{test_name}: {result['process_time']:.2f}s (RT factor: {result['realtime_factor']:.2f}x)")
304
-
305
- # Clear GPU cache between tests
306
- self.gpu_optimizer.clear_cache()
307
- time.sleep(1) # Brief pause
308
-
309
- self.results.extend(duration_results)
310
-
311
- # Generate report
312
- self.generate_report()
313
-
314
- def generate_report(self):
315
- """Generate performance test report"""
316
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
317
- report_file = f"performance_report_{timestamp}.json"
318
-
319
- # Calculate improvements
320
- summary = {
321
- "test_date": timestamp,
322
- "gpu_info": self.gpu_optimizer.get_memory_stats(),
323
- "optimization_config": self.resolution_optimizer.get_performance_config(),
324
- "results": self.results
325
- }
326
-
327
- # Calculate average improvements by optimization type
328
- avg_improvements = {}
329
- for opt_type in ["gpu_only", "resolution_only", "full"]:
330
- opt_results = [r for r in self.results if r.get("optimization") == opt_type]
331
- baseline_results = [r for r in self.results if r.get("optimization") == "none"
332
- and r["audio_duration"] == opt_results[0]["audio_duration"]]
333
-
334
- if opt_results and baseline_results:
335
- avg_improvement = 0
336
- for opt_r in opt_results:
337
- baseline_r = next((b for b in baseline_results
338
- if b["audio_duration"] == opt_r["audio_duration"]), None)
339
- if baseline_r:
340
- improvement = (baseline_r["process_time"] - opt_r["process_time"]) / baseline_r["process_time"] * 100
341
- avg_improvement += improvement
342
-
343
- avg_improvements[opt_type] = avg_improvement / len(opt_results)
344
-
345
- summary["average_improvements"] = avg_improvements
346
-
347
- # Save report
348
- with open(report_file, 'w') as f:
349
- json.dump(summary, f, indent=2)
350
-
351
- # Print summary
352
- print("\n" + "="*60)
353
- print("PERFORMANCE TEST SUMMARY")
354
- print("="*60)
355
-
356
- print("\nAverage Performance Improvements:")
357
- for opt_type, improvement in avg_improvements.items():
358
- print(f"- {opt_type}: {improvement:.1f}% faster")
359
-
360
- print(f"\nDetailed results saved to: {report_file}")
361
-
362
- # Check if we meet the target (16s audio in <10s)
363
- target_results = [r for r in self.results
364
- if r.get("optimization") == "full" and r["audio_duration"] == 16]
365
- if target_results:
366
- meets_target = target_results[0]["process_time"] <= 10.0
367
- print(f"\n✅ Target Achievement (16s audio < 10s): {'YES' if meets_target else 'NO'}")
368
- print(f" Actual time: {target_results[0]['process_time']:.2f}s")
369
-
370
-
371
- if __name__ == "__main__":
372
- import tempfile
373
-
374
- tester = PerformanceTester()
375
- tester.run_comprehensive_test()