Spaces:
Runtime error
Runtime error
推論キャッシュと並列処理の機能を追加し、`process_talking_head_optimized`関数をキャッシュと並列処理に対応させました。また、Gradioインターフェースにキャッシュ管理機能を追加しました。
Browse files- app_optimized.py +175 -20
- core/optimization/__init__.py +10 -1
- core/optimization/inference_cache.py +386 -0
- core/optimization/parallel_inference.py +268 -0
- core/optimization/parallel_processing.py +400 -0
app_optimized.py
CHANGED
@@ -18,7 +18,12 @@ from core.optimization import (
|
|
18 |
GPUOptimizer,
|
19 |
AvatarCache,
|
20 |
AvatarTokenManager,
|
21 |
-
ColdStartOptimizer
|
|
|
|
|
|
|
|
|
|
|
22 |
)
|
23 |
|
24 |
# サンプルファイルのディレクトリを定義
|
@@ -44,6 +49,18 @@ avatar_cache = AvatarCache(cache_dir="/tmp/avatar_cache", ttl_days=14)
|
|
44 |
token_manager = AvatarTokenManager(avatar_cache)
|
45 |
print(f"✅ アバターキャッシュ初期化: {avatar_cache.get_cache_info()}")
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
# モデルの初期化(最適化版)
|
48 |
USE_PYTORCH = True
|
49 |
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
|
@@ -92,6 +109,17 @@ except Exception as e:
|
|
92 |
traceback.print_exc()
|
93 |
raise
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
def prepare_avatar(image_file) -> Dict[str, Any]:
|
96 |
"""
|
97 |
画像を事前処理してアバタートークンを生成
|
@@ -150,16 +178,19 @@ def process_talking_head_optimized(
|
|
150 |
audio_file,
|
151 |
source_image,
|
152 |
avatar_token: Optional[str] = None,
|
153 |
-
use_resolution_optimization: bool = True
|
|
|
|
|
154 |
):
|
155 |
"""
|
156 |
-
最適化されたTalking Head
|
157 |
|
158 |
Args:
|
159 |
audio_file: 音声ファイル
|
160 |
source_image: ソース画像(avatar_tokenがない場合に使用)
|
161 |
avatar_token: 事前生成されたアバタートークン
|
162 |
use_resolution_optimization: 解像度最適化を使用するか
|
|
|
163 |
"""
|
164 |
|
165 |
if audio_file is None:
|
@@ -184,7 +215,6 @@ def process_talking_head_optimized(
|
|
184 |
|
185 |
# 解像度最適化設定を適用
|
186 |
if use_resolution_optimization:
|
187 |
-
# SDKに解像度設定を適用
|
188 |
setup_kwargs = {
|
189 |
"max_size": FIXED_RESOLUTION, # 320固定
|
190 |
"sampling_timesteps": resolution_optimizer.get_diffusion_steps() # 25
|
@@ -193,15 +223,68 @@ def process_talking_head_optimized(
|
|
193 |
else:
|
194 |
setup_kwargs = {}
|
195 |
|
196 |
-
#
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
# 結果の確認
|
207 |
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
|
@@ -210,8 +293,12 @@ def process_talking_head_optimized(
|
|
210 |
✅ 処理完了!
|
211 |
処理時間: {process_time:.2f}秒
|
212 |
解像度: {FIXED_RESOLUTION}×{FIXED_RESOLUTION}
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
215 |
"""
|
216 |
return output_path, perf_info
|
217 |
else:
|
@@ -233,6 +320,8 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
|
|
233 |
- 🎯 画像事前アップロード&キャッシュ機能
|
234 |
- ⚡ GPU最適化(Mixed Precision, torch.compile)
|
235 |
- 💾 Cold Start最適化
|
|
|
|
|
236 |
|
237 |
## 使い方
|
238 |
### 方法1: 通常の使用
|
@@ -271,6 +360,16 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
|
|
271 |
value=True
|
272 |
)
|
273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
generate_btn = gr.Button("🎬 生成", variant="primary")
|
275 |
|
276 |
with gr.Column():
|
@@ -305,16 +404,29 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
|
|
305 |
|
306 |
# タブ3: 最適化情報
|
307 |
with gr.TabItem("📊 最適化情報"):
|
308 |
-
gr.
|
|
|
|
|
|
|
309 |
### 現在の最適化設定
|
310 |
|
311 |
{resolution_optimizer.get_optimization_summary()}
|
312 |
|
313 |
{gpu_optimizer.get_optimization_summary()}
|
314 |
|
315 |
-
###
|
316 |
{avatar_cache.get_cache_info()}
|
|
|
|
|
|
|
317 |
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
|
319 |
# サンプル
|
320 |
example_audio = EXAMPLES_DIR / "audio.wav"
|
@@ -323,9 +435,9 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
|
|
323 |
if example_audio.exists() and example_image.exists():
|
324 |
gr.Examples(
|
325 |
examples=[
|
326 |
-
[str(example_audio), str(example_image), None, True]
|
327 |
],
|
328 |
-
inputs=[audio_input, image_input, token_input, use_optimization],
|
329 |
outputs=[video_output, status_output],
|
330 |
fn=process_talking_head_optimized
|
331 |
)
|
@@ -333,7 +445,7 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
|
|
333 |
# イベントハンドラ
|
334 |
generate_btn.click(
|
335 |
fn=process_talking_head_optimized,
|
336 |
-
inputs=[audio_input, image_input, token_input, use_optimization],
|
337 |
outputs=[video_output, status_output]
|
338 |
)
|
339 |
|
@@ -342,6 +454,49 @@ with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo:
|
|
342 |
inputs=[avatar_image_input],
|
343 |
outputs=[prepare_output]
|
344 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
|
346 |
if __name__ == "__main__":
|
347 |
# Cold Start最適化設定でGradioを起動
|
|
|
18 |
GPUOptimizer,
|
19 |
AvatarCache,
|
20 |
AvatarTokenManager,
|
21 |
+
ColdStartOptimizer,
|
22 |
+
InferenceCache,
|
23 |
+
CachedInference,
|
24 |
+
ParallelProcessor,
|
25 |
+
ParallelInference,
|
26 |
+
OptimizedInferenceWrapper
|
27 |
)
|
28 |
|
29 |
# サンプルファイルのディレクトリを定義
|
|
|
49 |
token_manager = AvatarTokenManager(avatar_cache)
|
50 |
print(f"✅ アバターキャッシュ初期化: {avatar_cache.get_cache_info()}")
|
51 |
|
52 |
+
# 5. 推論キャッシュの初期化
|
53 |
+
inference_cache = InferenceCache(
|
54 |
+
cache_dir="/tmp/inference_cache",
|
55 |
+
memory_cache_size=50,
|
56 |
+
file_cache_size_gb=5.0,
|
57 |
+
ttl_hours=24
|
58 |
+
)
|
59 |
+
cached_inference = CachedInference(inference_cache)
|
60 |
+
print(f"✅ 推論キャッシュ初期化: {inference_cache.get_cache_stats()}")
|
61 |
+
|
62 |
+
# 6. 並列処理の初期化(SDK初期化後に移動)
|
63 |
+
|
64 |
# モデルの初期化(最適化版)
|
65 |
USE_PYTORCH = True
|
66 |
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
|
|
|
109 |
traceback.print_exc()
|
110 |
raise
|
111 |
|
112 |
+
# 並列処理の初期化(SDK初期化成功後)
|
113 |
+
parallel_processor = ParallelProcessor(num_threads=4, num_processes=2)
|
114 |
+
parallel_inference = ParallelInference(SDK, parallel_processor)
|
115 |
+
optimized_wrapper = OptimizedInferenceWrapper(
|
116 |
+
SDK,
|
117 |
+
use_parallel=True,
|
118 |
+
use_cache=True,
|
119 |
+
use_gpu_opt=True
|
120 |
+
)
|
121 |
+
print(f"✅ 並列処理初期化: {parallel_inference.get_performance_stats()}")
|
122 |
+
|
123 |
def prepare_avatar(image_file) -> Dict[str, Any]:
|
124 |
"""
|
125 |
画像を事前処理してアバタートークンを生成
|
|
|
178 |
audio_file,
|
179 |
source_image,
|
180 |
avatar_token: Optional[str] = None,
|
181 |
+
use_resolution_optimization: bool = True,
|
182 |
+
use_inference_cache: bool = True,
|
183 |
+
use_parallel_processing: bool = True
|
184 |
):
|
185 |
"""
|
186 |
+
最適化されたTalking Head生成処理(キャッシュ対応)
|
187 |
|
188 |
Args:
|
189 |
audio_file: 音声ファイル
|
190 |
source_image: ソース画像(avatar_tokenがない場合に使用)
|
191 |
avatar_token: 事前生成されたアバタートークン
|
192 |
use_resolution_optimization: 解像度最適化を使用するか
|
193 |
+
use_inference_cache: 推論キャッシュを使用するか
|
194 |
"""
|
195 |
|
196 |
if audio_file is None:
|
|
|
215 |
|
216 |
# 解像度最適化設定を適用
|
217 |
if use_resolution_optimization:
|
|
|
218 |
setup_kwargs = {
|
219 |
"max_size": FIXED_RESOLUTION, # 320固定
|
220 |
"sampling_timesteps": resolution_optimizer.get_diffusion_steps() # 25
|
|
|
223 |
else:
|
224 |
setup_kwargs = {}
|
225 |
|
226 |
+
# 処理方法の選択
|
227 |
+
if use_parallel_processing and source_image:
|
228 |
+
# 並列処理を使用
|
229 |
+
print("🔄 並列処理モードで実行...")
|
230 |
+
|
231 |
+
if use_inference_cache:
|
232 |
+
# キャッシュ + 並列処理
|
233 |
+
def inference_func(audio_path, image_path, out_path, **kwargs):
|
234 |
+
# 並列処理ラッパーを使用
|
235 |
+
optimized_wrapper.process(
|
236 |
+
audio_path, image_path, out_path,
|
237 |
+
seed=1024,
|
238 |
+
more_kwargs={"setup_kwargs": kwargs.get('setup_kwargs', {})}
|
239 |
+
)
|
240 |
+
|
241 |
+
# キャッシュシステムを通じて処理
|
242 |
+
result_path, cache_hit, process_time = cached_inference.process_with_cache(
|
243 |
+
inference_func,
|
244 |
+
audio_file,
|
245 |
+
source_image,
|
246 |
+
output_path,
|
247 |
+
resolution=f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}" if use_resolution_optimization else "default",
|
248 |
+
steps=setup_kwargs.get('sampling_timesteps', 50),
|
249 |
+
setup_kwargs=setup_kwargs
|
250 |
+
)
|
251 |
+
cache_status = "キャッシュヒット(並列)" if cache_hit else "新規生成(並列)"
|
252 |
+
else:
|
253 |
+
# 並列処理のみ
|
254 |
+
_, process_time, stats = optimized_wrapper.process(
|
255 |
+
audio_file, source_image, output_path,
|
256 |
+
seed=1024,
|
257 |
+
more_kwargs={"setup_kwargs": setup_kwargs}
|
258 |
+
)
|
259 |
+
cache_hit = False
|
260 |
+
cache_status = "並列処理(キャッシュ未使用)"
|
261 |
+
|
262 |
+
elif use_inference_cache and source_image:
|
263 |
+
# キャッシュのみ(並列処理なし)
|
264 |
+
def inference_func(audio_path, image_path, out_path, **kwargs):
|
265 |
+
seed_everything(1024)
|
266 |
+
run(SDK, audio_path, image_path, out_path,
|
267 |
+
more_kwargs={"setup_kwargs": kwargs.get('setup_kwargs', {})})
|
268 |
+
|
269 |
+
# キャッシュシステムを通じて処理
|
270 |
+
result_path, cache_hit, process_time = cached_inference.process_with_cache(
|
271 |
+
inference_func,
|
272 |
+
audio_file,
|
273 |
+
source_image,
|
274 |
+
output_path,
|
275 |
+
resolution=f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}" if use_resolution_optimization else "default",
|
276 |
+
steps=setup_kwargs.get('sampling_timesteps', 50),
|
277 |
+
setup_kwargs=setup_kwargs
|
278 |
+
)
|
279 |
+
cache_status = "キャッシュヒット" if cache_hit else "新規生成"
|
280 |
+
else:
|
281 |
+
# 通常処理(並列処理もキャッシュもなし)
|
282 |
+
print(f"処理開始: audio={audio_file}, image={source_image}, token={avatar_token is not None}")
|
283 |
+
seed_everything(1024)
|
284 |
+
run(SDK, audio_file, source_image, output_path, more_kwargs={"setup_kwargs": setup_kwargs})
|
285 |
+
process_time = time.time() - start_time
|
286 |
+
cache_hit = False
|
287 |
+
cache_status = "通常処理"
|
288 |
|
289 |
# 結果の確認
|
290 |
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
|
|
|
293 |
✅ 処理完了!
|
294 |
処理時間: {process_time:.2f}秒
|
295 |
解像度: {FIXED_RESOLUTION}×{FIXED_RESOLUTION}
|
296 |
+
最適化設定:
|
297 |
+
- 解像度最適化: {'有効' if use_resolution_optimization else '無効'}
|
298 |
+
- 並列処理: {'有効' if use_parallel_processing else '無効'}
|
299 |
+
- アバターキャッシュ: {'使用' if avatar_token else '未使用'}
|
300 |
+
- 推論キャッシュ: {cache_status}
|
301 |
+
キャッシュ統計: {inference_cache.get_cache_stats()['memory_cache_entries']}件(メモリ), {inference_cache.get_cache_stats()['file_cache_entries']}件(ファイル)
|
302 |
"""
|
303 |
return output_path, perf_info
|
304 |
else:
|
|
|
320 |
- 🎯 画像事前アップロード&キャッシュ機能
|
321 |
- ⚡ GPU最適化(Mixed Precision, torch.compile)
|
322 |
- 💾 Cold Start最適化
|
323 |
+
- 🔄 推論キャッシュ(同じ入力で即座に結果を返す)
|
324 |
+
- 🚀 並列処理(音声・画像の前処理を並列化)
|
325 |
|
326 |
## 使い方
|
327 |
### 方法1: 通常の使用
|
|
|
360 |
value=True
|
361 |
)
|
362 |
|
363 |
+
use_cache = gr.Checkbox(
|
364 |
+
label="推論キャッシュを使用(同じ入力で高速化)",
|
365 |
+
value=True
|
366 |
+
)
|
367 |
+
|
368 |
+
use_parallel = gr.Checkbox(
|
369 |
+
label="並列処理を使用(前処理を高速化)",
|
370 |
+
value=True
|
371 |
+
)
|
372 |
+
|
373 |
generate_btn = gr.Button("🎬 生成", variant="primary")
|
374 |
|
375 |
with gr.Column():
|
|
|
404 |
|
405 |
# タブ3: 最適化情報
|
406 |
with gr.TabItem("📊 最適化情報"):
|
407 |
+
with gr.Row():
|
408 |
+
refresh_btn = gr.Button("🔄 情報を更新", scale=1)
|
409 |
+
|
410 |
+
info_display = gr.Markdown(f"""
|
411 |
### 現在の最適化設定
|
412 |
|
413 |
{resolution_optimizer.get_optimization_summary()}
|
414 |
|
415 |
{gpu_optimizer.get_optimization_summary()}
|
416 |
|
417 |
+
### アバターキャッシュ情報
|
418 |
{avatar_cache.get_cache_info()}
|
419 |
+
|
420 |
+
### 推論キャッシュ情報
|
421 |
+
{inference_cache.get_cache_stats()}
|
422 |
""")
|
423 |
+
|
424 |
+
# キャッシュ管理ボタン
|
425 |
+
with gr.Row():
|
426 |
+
clear_inference_cache_btn = gr.Button("🗑️ 推論キャッシュをクリア", variant="secondary")
|
427 |
+
clear_avatar_cache_btn = gr.Button("🗑️ アバターキャッシュをクリア", variant="secondary")
|
428 |
+
|
429 |
+
cache_status = gr.Textbox(label="キャッシュ操作ステータス", lines=2)
|
430 |
|
431 |
# サンプル
|
432 |
example_audio = EXAMPLES_DIR / "audio.wav"
|
|
|
435 |
if example_audio.exists() and example_image.exists():
|
436 |
gr.Examples(
|
437 |
examples=[
|
438 |
+
[str(example_audio), str(example_image), None, True, True, True]
|
439 |
],
|
440 |
+
inputs=[audio_input, image_input, token_input, use_optimization, use_cache, use_parallel],
|
441 |
outputs=[video_output, status_output],
|
442 |
fn=process_talking_head_optimized
|
443 |
)
|
|
|
445 |
# イベントハンドラ
|
446 |
generate_btn.click(
|
447 |
fn=process_talking_head_optimized,
|
448 |
+
inputs=[audio_input, image_input, token_input, use_optimization, use_cache, use_parallel],
|
449 |
outputs=[video_output, status_output]
|
450 |
)
|
451 |
|
|
|
454 |
inputs=[avatar_image_input],
|
455 |
outputs=[prepare_output]
|
456 |
)
|
457 |
+
|
458 |
+
# キャッシュ管理関数
|
459 |
+
def refresh_info():
|
460 |
+
return f"""
|
461 |
+
### 現在の最適化設定
|
462 |
+
|
463 |
+
{resolution_optimizer.get_optimization_summary()}
|
464 |
+
|
465 |
+
{gpu_optimizer.get_optimization_summary()}
|
466 |
+
|
467 |
+
### アバターキャッシュ情報
|
468 |
+
{avatar_cache.get_cache_info()}
|
469 |
+
|
470 |
+
### 推論キャッシュ情報
|
471 |
+
{inference_cache.get_cache_stats()}
|
472 |
+
|
473 |
+
### 並列処理情報
|
474 |
+
{parallel_inference.get_performance_stats()}
|
475 |
+
"""
|
476 |
+
|
477 |
+
def clear_inference_cache():
|
478 |
+
inference_cache.clear_cache()
|
479 |
+
return "✅ 推論キャッシュをクリアしました"
|
480 |
+
|
481 |
+
def clear_avatar_cache():
|
482 |
+
avatar_cache.clear_cache()
|
483 |
+
return "✅ アバターキャッシュをクリアしました"
|
484 |
+
|
485 |
+
# キャッシュ管理イベント
|
486 |
+
refresh_btn.click(
|
487 |
+
fn=refresh_info,
|
488 |
+
outputs=[info_display]
|
489 |
+
)
|
490 |
+
|
491 |
+
clear_inference_cache_btn.click(
|
492 |
+
fn=clear_inference_cache,
|
493 |
+
outputs=[cache_status]
|
494 |
+
)
|
495 |
+
|
496 |
+
clear_avatar_cache_btn.click(
|
497 |
+
fn=clear_avatar_cache,
|
498 |
+
outputs=[cache_status]
|
499 |
+
)
|
500 |
|
501 |
if __name__ == "__main__":
|
502 |
# Cold Start最適化設定でGradioを起動
|
core/optimization/__init__.py
CHANGED
@@ -6,6 +6,9 @@ from .resolution_optimization import FixedResolutionProcessor
|
|
6 |
from .gpu_optimization import GPUOptimizer, OptimizedInference
|
7 |
from .avatar_cache import AvatarCache, AvatarTokenManager
|
8 |
from .cold_start_optimization import ColdStartOptimizer
|
|
|
|
|
|
|
9 |
|
10 |
__all__ = [
|
11 |
'FixedResolutionProcessor',
|
@@ -13,5 +16,11 @@ __all__ = [
|
|
13 |
'OptimizedInference',
|
14 |
'AvatarCache',
|
15 |
'AvatarTokenManager',
|
16 |
-
'ColdStartOptimizer'
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
]
|
|
|
6 |
from .gpu_optimization import GPUOptimizer, OptimizedInference
|
7 |
from .avatar_cache import AvatarCache, AvatarTokenManager
|
8 |
from .cold_start_optimization import ColdStartOptimizer
|
9 |
+
from .inference_cache import InferenceCache, CachedInference
|
10 |
+
from .parallel_processing import ParallelProcessor, PipelineProcessor
|
11 |
+
from .parallel_inference import ParallelInference, OptimizedInferenceWrapper
|
12 |
|
13 |
__all__ = [
|
14 |
'FixedResolutionProcessor',
|
|
|
16 |
'OptimizedInference',
|
17 |
'AvatarCache',
|
18 |
'AvatarTokenManager',
|
19 |
+
'ColdStartOptimizer',
|
20 |
+
'InferenceCache',
|
21 |
+
'CachedInference',
|
22 |
+
'ParallelProcessor',
|
23 |
+
'PipelineProcessor',
|
24 |
+
'ParallelInference',
|
25 |
+
'OptimizedInferenceWrapper'
|
26 |
]
|
core/optimization/inference_cache.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Inference Cache System for DittoTalkingHead
|
3 |
+
Caches video generation results for faster repeated processing
|
4 |
+
"""
|
5 |
+
|
6 |
+
import hashlib
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import pickle
|
10 |
+
import time
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import Optional, Dict, Any, Tuple, Union
|
13 |
+
from functools import lru_cache
|
14 |
+
import shutil
|
15 |
+
from datetime import datetime, timedelta
|
16 |
+
|
17 |
+
|
18 |
+
class InferenceCache:
|
19 |
+
"""
|
20 |
+
Cache system for video generation results
|
21 |
+
Supports both memory and file-based caching
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
cache_dir: str = "/tmp/inference_cache",
|
27 |
+
memory_cache_size: int = 100,
|
28 |
+
file_cache_size_gb: float = 10.0,
|
29 |
+
ttl_hours: int = 24
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
Initialize inference cache
|
33 |
+
|
34 |
+
Args:
|
35 |
+
cache_dir: Directory for file-based cache
|
36 |
+
memory_cache_size: Maximum number of items in memory cache
|
37 |
+
file_cache_size_gb: Maximum size of file cache in GB
|
38 |
+
ttl_hours: Time to live for cache entries in hours
|
39 |
+
"""
|
40 |
+
self.cache_dir = Path(cache_dir)
|
41 |
+
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
42 |
+
|
43 |
+
self.memory_cache_size = memory_cache_size
|
44 |
+
self.file_cache_size_bytes = int(file_cache_size_gb * 1024 * 1024 * 1024)
|
45 |
+
self.ttl_seconds = ttl_hours * 3600
|
46 |
+
|
47 |
+
# Metadata file for managing cache
|
48 |
+
self.metadata_file = self.cache_dir / "cache_metadata.json"
|
49 |
+
self.metadata = self._load_metadata()
|
50 |
+
|
51 |
+
# In-memory cache
|
52 |
+
self._memory_cache = {}
|
53 |
+
self._access_times = {}
|
54 |
+
|
55 |
+
# Clean up expired entries on initialization
|
56 |
+
self._cleanup_expired()
|
57 |
+
|
58 |
+
def _load_metadata(self) -> Dict[str, Any]:
|
59 |
+
"""Load cache metadata"""
|
60 |
+
if self.metadata_file.exists():
|
61 |
+
try:
|
62 |
+
with open(self.metadata_file, 'r') as f:
|
63 |
+
return json.load(f)
|
64 |
+
except:
|
65 |
+
return {}
|
66 |
+
return {}
|
67 |
+
|
68 |
+
def _save_metadata(self):
|
69 |
+
"""Save cache metadata"""
|
70 |
+
with open(self.metadata_file, 'w') as f:
|
71 |
+
json.dump(self.metadata, f, indent=2)
|
72 |
+
|
73 |
+
def generate_cache_key(
|
74 |
+
self,
|
75 |
+
audio_path: str,
|
76 |
+
image_path: str,
|
77 |
+
**kwargs
|
78 |
+
) -> str:
|
79 |
+
"""
|
80 |
+
Generate unique cache key based on input parameters
|
81 |
+
|
82 |
+
Args:
|
83 |
+
audio_path: Path to audio file
|
84 |
+
image_path: Path to image file
|
85 |
+
**kwargs: Additional parameters affecting output
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
SHA-256 hash as cache key
|
89 |
+
"""
|
90 |
+
# Read file contents for hashing
|
91 |
+
with open(audio_path, 'rb') as f:
|
92 |
+
audio_hash = hashlib.sha256(f.read()).hexdigest()
|
93 |
+
|
94 |
+
with open(image_path, 'rb') as f:
|
95 |
+
image_hash = hashlib.sha256(f.read()).hexdigest()
|
96 |
+
|
97 |
+
# Include relevant parameters in key
|
98 |
+
key_data = {
|
99 |
+
'audio': audio_hash,
|
100 |
+
'image': image_hash,
|
101 |
+
'resolution': kwargs.get('resolution', '320x320'),
|
102 |
+
'steps': kwargs.get('steps', 25),
|
103 |
+
'seed': kwargs.get('seed', None)
|
104 |
+
}
|
105 |
+
|
106 |
+
# Generate final key
|
107 |
+
key_str = json.dumps(key_data, sort_keys=True)
|
108 |
+
return hashlib.sha256(key_str.encode()).hexdigest()
|
109 |
+
|
110 |
+
def get_from_memory(self, cache_key: str) -> Optional[str]:
|
111 |
+
"""
|
112 |
+
Get video path from memory cache
|
113 |
+
|
114 |
+
Args:
|
115 |
+
cache_key: Cache key
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
Video file path if found, None otherwise
|
119 |
+
"""
|
120 |
+
if cache_key in self._memory_cache:
|
121 |
+
self._access_times[cache_key] = time.time()
|
122 |
+
return self._memory_cache[cache_key]
|
123 |
+
return None
|
124 |
+
|
125 |
+
def get_from_file(self, cache_key: str) -> Optional[str]:
|
126 |
+
"""
|
127 |
+
Get video path from file cache
|
128 |
+
|
129 |
+
Args:
|
130 |
+
cache_key: Cache key
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
Video file path if found, None otherwise
|
134 |
+
"""
|
135 |
+
if cache_key not in self.metadata:
|
136 |
+
return None
|
137 |
+
|
138 |
+
entry = self.metadata[cache_key]
|
139 |
+
|
140 |
+
# Check expiration
|
141 |
+
if time.time() > entry['expires_at']:
|
142 |
+
self._remove_cache_entry(cache_key)
|
143 |
+
return None
|
144 |
+
|
145 |
+
# Check if file exists
|
146 |
+
video_path = self.cache_dir / entry['filename']
|
147 |
+
if not video_path.exists():
|
148 |
+
self._remove_cache_entry(cache_key)
|
149 |
+
return None
|
150 |
+
|
151 |
+
# Update access time
|
152 |
+
self.metadata[cache_key]['last_access'] = time.time()
|
153 |
+
self._save_metadata()
|
154 |
+
|
155 |
+
# Add to memory cache
|
156 |
+
self._add_to_memory_cache(cache_key, str(video_path))
|
157 |
+
|
158 |
+
return str(video_path)
|
159 |
+
|
160 |
+
def get(self, cache_key: str) -> Optional[str]:
|
161 |
+
"""
|
162 |
+
Get video from cache (memory first, then file)
|
163 |
+
|
164 |
+
Args:
|
165 |
+
cache_key: Cache key
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
Video file path if found, None otherwise
|
169 |
+
"""
|
170 |
+
# Try memory cache first
|
171 |
+
result = self.get_from_memory(cache_key)
|
172 |
+
if result:
|
173 |
+
return result
|
174 |
+
|
175 |
+
# Try file cache
|
176 |
+
return self.get_from_file(cache_key)
|
177 |
+
|
178 |
+
def put(
|
179 |
+
self,
|
180 |
+
cache_key: str,
|
181 |
+
video_path: str,
|
182 |
+
**metadata
|
183 |
+
) -> bool:
|
184 |
+
"""
|
185 |
+
Store video in cache
|
186 |
+
|
187 |
+
Args:
|
188 |
+
cache_key: Cache key
|
189 |
+
video_path: Path to generated video
|
190 |
+
**metadata: Additional metadata to store
|
191 |
+
|
192 |
+
Returns:
|
193 |
+
True if stored successfully
|
194 |
+
"""
|
195 |
+
try:
|
196 |
+
# Copy video to cache directory
|
197 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
198 |
+
cache_filename = f"{cache_key[:8]}_{timestamp}.mp4"
|
199 |
+
cache_video_path = self.cache_dir / cache_filename
|
200 |
+
|
201 |
+
shutil.copy2(video_path, cache_video_path)
|
202 |
+
|
203 |
+
# Store metadata
|
204 |
+
self.metadata[cache_key] = {
|
205 |
+
'filename': cache_filename,
|
206 |
+
'created_at': time.time(),
|
207 |
+
'expires_at': time.time() + self.ttl_seconds,
|
208 |
+
'last_access': time.time(),
|
209 |
+
'size_bytes': os.path.getsize(cache_video_path),
|
210 |
+
'metadata': metadata
|
211 |
+
}
|
212 |
+
|
213 |
+
# Check cache size and clean if needed
|
214 |
+
self._check_cache_size()
|
215 |
+
|
216 |
+
# Save metadata
|
217 |
+
self._save_metadata()
|
218 |
+
|
219 |
+
# Add to memory cache
|
220 |
+
self._add_to_memory_cache(cache_key, str(cache_video_path))
|
221 |
+
|
222 |
+
return True
|
223 |
+
|
224 |
+
except Exception as e:
|
225 |
+
print(f"Error storing cache: {e}")
|
226 |
+
return False
|
227 |
+
|
228 |
+
def _add_to_memory_cache(self, cache_key: str, video_path: str):
|
229 |
+
"""Add item to memory cache with LRU eviction"""
|
230 |
+
# Check if we need to evict
|
231 |
+
if len(self._memory_cache) >= self.memory_cache_size:
|
232 |
+
# Find least recently used
|
233 |
+
lru_key = min(self._access_times, key=self._access_times.get)
|
234 |
+
del self._memory_cache[lru_key]
|
235 |
+
del self._access_times[lru_key]
|
236 |
+
|
237 |
+
self._memory_cache[cache_key] = video_path
|
238 |
+
self._access_times[cache_key] = time.time()
|
239 |
+
|
240 |
+
def _check_cache_size(self):
|
241 |
+
"""Check and maintain cache size limit"""
|
242 |
+
total_size = sum(
|
243 |
+
entry['size_bytes']
|
244 |
+
for entry in self.metadata.values()
|
245 |
+
)
|
246 |
+
|
247 |
+
if total_size > self.file_cache_size_bytes:
|
248 |
+
# Remove oldest entries until under limit
|
249 |
+
sorted_entries = sorted(
|
250 |
+
self.metadata.items(),
|
251 |
+
key=lambda x: x[1]['last_access']
|
252 |
+
)
|
253 |
+
|
254 |
+
while total_size > self.file_cache_size_bytes and sorted_entries:
|
255 |
+
key_to_remove, entry = sorted_entries.pop(0)
|
256 |
+
total_size -= entry['size_bytes']
|
257 |
+
self._remove_cache_entry(key_to_remove)
|
258 |
+
|
259 |
+
def _cleanup_expired(self):
|
260 |
+
"""Remove expired cache entries"""
|
261 |
+
current_time = time.time()
|
262 |
+
expired_keys = [
|
263 |
+
key for key, entry in self.metadata.items()
|
264 |
+
if current_time > entry['expires_at']
|
265 |
+
]
|
266 |
+
|
267 |
+
for key in expired_keys:
|
268 |
+
self._remove_cache_entry(key)
|
269 |
+
|
270 |
+
if expired_keys:
|
271 |
+
print(f"Cleaned up {len(expired_keys)} expired cache entries")
|
272 |
+
|
273 |
+
def _remove_cache_entry(self, cache_key: str):
|
274 |
+
"""Remove a cache entry"""
|
275 |
+
if cache_key in self.metadata:
|
276 |
+
# Remove file
|
277 |
+
video_file = self.cache_dir / self.metadata[cache_key]['filename']
|
278 |
+
if video_file.exists():
|
279 |
+
video_file.unlink()
|
280 |
+
|
281 |
+
# Remove from metadata
|
282 |
+
del self.metadata[cache_key]
|
283 |
+
|
284 |
+
# Remove from memory cache
|
285 |
+
if cache_key in self._memory_cache:
|
286 |
+
del self._memory_cache[cache_key]
|
287 |
+
del self._access_times[cache_key]
|
288 |
+
|
289 |
+
def clear_cache(self):
|
290 |
+
"""Clear all cache entries"""
|
291 |
+
# Remove all video files
|
292 |
+
for file in self.cache_dir.glob("*.mp4"):
|
293 |
+
file.unlink()
|
294 |
+
|
295 |
+
# Clear metadata
|
296 |
+
self.metadata = {}
|
297 |
+
self._save_metadata()
|
298 |
+
|
299 |
+
# Clear memory cache
|
300 |
+
self._memory_cache.clear()
|
301 |
+
self._access_times.clear()
|
302 |
+
|
303 |
+
print("Inference cache cleared")
|
304 |
+
|
305 |
+
def get_cache_stats(self) -> Dict[str, Any]:
|
306 |
+
"""Get cache statistics"""
|
307 |
+
total_size = sum(
|
308 |
+
entry['size_bytes']
|
309 |
+
for entry in self.metadata.values()
|
310 |
+
)
|
311 |
+
|
312 |
+
memory_hits = len(self._memory_cache)
|
313 |
+
file_entries = len(self.metadata)
|
314 |
+
|
315 |
+
return {
|
316 |
+
'memory_cache_entries': memory_hits,
|
317 |
+
'file_cache_entries': file_entries,
|
318 |
+
'total_cache_size_mb': total_size / (1024 * 1024),
|
319 |
+
'cache_size_limit_gb': self.file_cache_size_bytes / (1024 * 1024 * 1024),
|
320 |
+
'ttl_hours': self.ttl_seconds / 3600,
|
321 |
+
'cache_directory': str(self.cache_dir)
|
322 |
+
}
|
323 |
+
|
324 |
+
|
325 |
+
class CachedInference:
|
326 |
+
"""
|
327 |
+
Wrapper for cached inference execution
|
328 |
+
"""
|
329 |
+
|
330 |
+
def __init__(self, cache: InferenceCache):
|
331 |
+
"""
|
332 |
+
Initialize cached inference
|
333 |
+
|
334 |
+
Args:
|
335 |
+
cache: InferenceCache instance
|
336 |
+
"""
|
337 |
+
self.cache = cache
|
338 |
+
|
339 |
+
def process_with_cache(
|
340 |
+
self,
|
341 |
+
inference_func: callable,
|
342 |
+
audio_path: str,
|
343 |
+
image_path: str,
|
344 |
+
output_path: str,
|
345 |
+
**kwargs
|
346 |
+
) -> Tuple[str, bool, float]:
|
347 |
+
"""
|
348 |
+
Process with caching
|
349 |
+
|
350 |
+
Args:
|
351 |
+
inference_func: Function to generate video
|
352 |
+
audio_path: Path to audio file
|
353 |
+
image_path: Path to image file
|
354 |
+
output_path: Desired output path
|
355 |
+
**kwargs: Additional parameters
|
356 |
+
|
357 |
+
Returns:
|
358 |
+
Tuple of (output_path, cache_hit, process_time)
|
359 |
+
"""
|
360 |
+
start_time = time.time()
|
361 |
+
|
362 |
+
# Generate cache key
|
363 |
+
cache_key = self.cache.generate_cache_key(
|
364 |
+
audio_path, image_path, **kwargs
|
365 |
+
)
|
366 |
+
|
367 |
+
# Check cache
|
368 |
+
cached_video = self.cache.get(cache_key)
|
369 |
+
|
370 |
+
if cached_video:
|
371 |
+
# Cache hit - copy to output path
|
372 |
+
shutil.copy2(cached_video, output_path)
|
373 |
+
process_time = time.time() - start_time
|
374 |
+
print(f"✅ Cache hit! Retrieved in {process_time:.2f}s")
|
375 |
+
return output_path, True, process_time
|
376 |
+
|
377 |
+
# Cache miss - generate video
|
378 |
+
print("Cache miss - generating video...")
|
379 |
+
inference_func(audio_path, image_path, output_path, **kwargs)
|
380 |
+
|
381 |
+
# Store in cache
|
382 |
+
if os.path.exists(output_path):
|
383 |
+
self.cache.put(cache_key, output_path, **kwargs)
|
384 |
+
|
385 |
+
process_time = time.time() - start_time
|
386 |
+
return output_path, False, process_time
|
core/optimization/parallel_inference.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Parallel Inference Integration for DittoTalkingHead
|
3 |
+
Integrates parallel processing into the inference pipeline
|
4 |
+
"""
|
5 |
+
|
6 |
+
import asyncio
|
7 |
+
import time
|
8 |
+
from typing import Dict, Any, Tuple, Optional
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
from .parallel_processing import ParallelProcessor, PipelineProcessor
|
14 |
+
|
15 |
+
|
16 |
+
class ParallelInference:
|
17 |
+
"""
|
18 |
+
Parallel inference wrapper for DittoTalkingHead
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, sdk, parallel_processor: Optional[ParallelProcessor] = None):
|
22 |
+
"""
|
23 |
+
Initialize parallel inference
|
24 |
+
|
25 |
+
Args:
|
26 |
+
sdk: StreamSDK instance
|
27 |
+
parallel_processor: ParallelProcessor instance
|
28 |
+
"""
|
29 |
+
self.sdk = sdk
|
30 |
+
self.parallel_processor = parallel_processor or ParallelProcessor(num_threads=4)
|
31 |
+
|
32 |
+
# Setup pipeline stages
|
33 |
+
self.pipeline_stages = {
|
34 |
+
'load': self._load_files,
|
35 |
+
'preprocess': self._preprocess,
|
36 |
+
'inference': self._inference,
|
37 |
+
'postprocess': self._postprocess
|
38 |
+
}
|
39 |
+
|
40 |
+
def _load_files(self, paths: Dict[str, str]) -> Dict[str, Any]:
|
41 |
+
"""Load audio and image files"""
|
42 |
+
audio_path = paths['audio']
|
43 |
+
image_path = paths['image']
|
44 |
+
|
45 |
+
# Parallel loading
|
46 |
+
audio_data, image_data = self.parallel_processor.preprocess_parallel_sync(
|
47 |
+
audio_path, image_path
|
48 |
+
)
|
49 |
+
|
50 |
+
return {
|
51 |
+
'audio_data': audio_data,
|
52 |
+
'image_data': image_data,
|
53 |
+
'paths': paths
|
54 |
+
}
|
55 |
+
|
56 |
+
def _preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
57 |
+
"""Preprocess loaded data"""
|
58 |
+
# Extract audio features
|
59 |
+
audio = data['audio_data']['audio']
|
60 |
+
sr = data['audio_data']['sample_rate']
|
61 |
+
|
62 |
+
# Prepare for SDK
|
63 |
+
import librosa
|
64 |
+
import math
|
65 |
+
|
66 |
+
# Calculate number of frames
|
67 |
+
num_frames = math.ceil(len(audio) / 16000 * 25)
|
68 |
+
|
69 |
+
# Prepare image
|
70 |
+
image = data['image_data']['image']
|
71 |
+
|
72 |
+
return {
|
73 |
+
'audio': audio,
|
74 |
+
'image': image,
|
75 |
+
'num_frames': num_frames,
|
76 |
+
'paths': data['paths']
|
77 |
+
}
|
78 |
+
|
79 |
+
def _inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
80 |
+
"""Run inference"""
|
81 |
+
# This would integrate with the actual SDK inference
|
82 |
+
# For now, placeholder
|
83 |
+
return {
|
84 |
+
'result': 'inference_result',
|
85 |
+
'paths': data['paths']
|
86 |
+
}
|
87 |
+
|
88 |
+
def _postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
89 |
+
"""Postprocess results"""
|
90 |
+
return data
|
91 |
+
|
92 |
+
async def process_parallel_async(
|
93 |
+
self,
|
94 |
+
audio_path: str,
|
95 |
+
image_path: str,
|
96 |
+
output_path: str,
|
97 |
+
**kwargs
|
98 |
+
) -> Tuple[str, float]:
|
99 |
+
"""
|
100 |
+
Process with full parallelization (async)
|
101 |
+
|
102 |
+
Args:
|
103 |
+
audio_path: Path to audio file
|
104 |
+
image_path: Path to image file
|
105 |
+
output_path: Output video path
|
106 |
+
**kwargs: Additional parameters
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
Tuple of (output_path, process_time)
|
110 |
+
"""
|
111 |
+
start_time = time.time()
|
112 |
+
|
113 |
+
# Parallel preprocessing
|
114 |
+
audio_data, image_data = await self.parallel_processor.preprocess_parallel_async(
|
115 |
+
audio_path, image_path, kwargs.get('target_size', 320)
|
116 |
+
)
|
117 |
+
|
118 |
+
# Run inference (simplified for integration)
|
119 |
+
# In real implementation, this would call SDK methods
|
120 |
+
|
121 |
+
process_time = time.time() - start_time
|
122 |
+
return output_path, process_time
|
123 |
+
|
124 |
+
def process_parallel_sync(
|
125 |
+
self,
|
126 |
+
audio_path: str,
|
127 |
+
image_path: str,
|
128 |
+
output_path: str,
|
129 |
+
**kwargs
|
130 |
+
) -> Tuple[str, float]:
|
131 |
+
"""
|
132 |
+
Process with parallelization (sync)
|
133 |
+
|
134 |
+
Args:
|
135 |
+
audio_path: Path to audio file
|
136 |
+
image_path: Path to image file
|
137 |
+
output_path: Output video path
|
138 |
+
**kwargs: Additional parameters
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
Tuple of (output_path, process_time)
|
142 |
+
"""
|
143 |
+
start_time = time.time()
|
144 |
+
|
145 |
+
try:
|
146 |
+
# Parallel preprocessing
|
147 |
+
print("🔄 Starting parallel preprocessing...")
|
148 |
+
preprocess_start = time.time()
|
149 |
+
|
150 |
+
audio_data, image_data = self.parallel_processor.preprocess_parallel_sync(
|
151 |
+
audio_path, image_path, kwargs.get('target_size', 320)
|
152 |
+
)
|
153 |
+
|
154 |
+
preprocess_time = time.time() - preprocess_start
|
155 |
+
print(f"✅ Parallel preprocessing completed in {preprocess_time:.2f}s")
|
156 |
+
|
157 |
+
# Run actual SDK inference
|
158 |
+
# This integrates with the existing SDK
|
159 |
+
from inference import run, seed_everything
|
160 |
+
|
161 |
+
seed_everything(kwargs.get('seed', 1024))
|
162 |
+
|
163 |
+
inference_start = time.time()
|
164 |
+
run(self.sdk, audio_path, image_path, output_path, more_kwargs=kwargs.get('more_kwargs', {}))
|
165 |
+
inference_time = time.time() - inference_start
|
166 |
+
|
167 |
+
print(f"✅ Inference completed in {inference_time:.2f}s")
|
168 |
+
|
169 |
+
total_time = time.time() - start_time
|
170 |
+
|
171 |
+
# Performance breakdown
|
172 |
+
print(f"""
|
173 |
+
🎯 Performance Breakdown:
|
174 |
+
- Preprocessing (parallel): {preprocess_time:.2f}s
|
175 |
+
- Inference: {inference_time:.2f}s
|
176 |
+
- Total: {total_time:.2f}s
|
177 |
+
""")
|
178 |
+
|
179 |
+
return output_path, total_time
|
180 |
+
|
181 |
+
except Exception as e:
|
182 |
+
print(f"❌ Error in parallel processing: {e}")
|
183 |
+
raise
|
184 |
+
|
185 |
+
def get_performance_stats(self) -> Dict[str, Any]:
|
186 |
+
"""Get performance statistics"""
|
187 |
+
return {
|
188 |
+
'num_threads': self.parallel_processor.num_threads,
|
189 |
+
'num_processes': self.parallel_processor.num_processes,
|
190 |
+
'cuda_streams_enabled': self.parallel_processor.use_cuda_streams
|
191 |
+
}
|
192 |
+
|
193 |
+
|
194 |
+
class OptimizedInferenceWrapper:
|
195 |
+
"""
|
196 |
+
Wrapper that combines all optimizations
|
197 |
+
"""
|
198 |
+
|
199 |
+
def __init__(
|
200 |
+
self,
|
201 |
+
sdk,
|
202 |
+
use_parallel: bool = True,
|
203 |
+
use_cache: bool = True,
|
204 |
+
use_gpu_opt: bool = True
|
205 |
+
):
|
206 |
+
"""
|
207 |
+
Initialize optimized inference wrapper
|
208 |
+
|
209 |
+
Args:
|
210 |
+
sdk: StreamSDK instance
|
211 |
+
use_parallel: Enable parallel processing
|
212 |
+
use_cache: Enable caching
|
213 |
+
use_gpu_opt: Enable GPU optimizations
|
214 |
+
"""
|
215 |
+
self.sdk = sdk
|
216 |
+
self.use_parallel = use_parallel
|
217 |
+
self.use_cache = use_cache
|
218 |
+
self.use_gpu_opt = use_gpu_opt
|
219 |
+
|
220 |
+
# Initialize components
|
221 |
+
if use_parallel:
|
222 |
+
self.parallel_processor = ParallelProcessor(num_threads=4)
|
223 |
+
self.parallel_inference = ParallelInference(sdk, self.parallel_processor)
|
224 |
+
else:
|
225 |
+
self.parallel_processor = None
|
226 |
+
self.parallel_inference = None
|
227 |
+
|
228 |
+
def process(
|
229 |
+
self,
|
230 |
+
audio_path: str,
|
231 |
+
image_path: str,
|
232 |
+
output_path: str,
|
233 |
+
**kwargs
|
234 |
+
) -> Tuple[str, float, Dict[str, Any]]:
|
235 |
+
"""
|
236 |
+
Process with all optimizations
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
Tuple of (output_path, process_time, stats)
|
240 |
+
"""
|
241 |
+
stats = {
|
242 |
+
'parallel_enabled': self.use_parallel,
|
243 |
+
'cache_enabled': self.use_cache,
|
244 |
+
'gpu_opt_enabled': self.use_gpu_opt
|
245 |
+
}
|
246 |
+
|
247 |
+
if self.use_parallel and self.parallel_inference:
|
248 |
+
output_path, process_time = self.parallel_inference.process_parallel_sync(
|
249 |
+
audio_path, image_path, output_path, **kwargs
|
250 |
+
)
|
251 |
+
stats['preprocessing'] = 'parallel'
|
252 |
+
else:
|
253 |
+
# Fallback to sequential
|
254 |
+
from inference import run, seed_everything
|
255 |
+
start_time = time.time()
|
256 |
+
seed_everything(kwargs.get('seed', 1024))
|
257 |
+
run(self.sdk, audio_path, image_path, output_path, more_kwargs=kwargs.get('more_kwargs', {}))
|
258 |
+
process_time = time.time() - start_time
|
259 |
+
stats['preprocessing'] = 'sequential'
|
260 |
+
|
261 |
+
stats['process_time'] = process_time
|
262 |
+
|
263 |
+
return output_path, process_time, stats
|
264 |
+
|
265 |
+
def shutdown(self):
|
266 |
+
"""Cleanup resources"""
|
267 |
+
if self.parallel_processor:
|
268 |
+
self.parallel_processor.shutdown()
|
core/optimization/parallel_processing.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Parallel Processing Module for DittoTalkingHead
|
3 |
+
Implements concurrent audio and image preprocessing
|
4 |
+
"""
|
5 |
+
|
6 |
+
import asyncio
|
7 |
+
import concurrent.futures
|
8 |
+
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
9 |
+
import time
|
10 |
+
from typing import Tuple, Dict, Any, Optional, Callable
|
11 |
+
import numpy as np
|
12 |
+
from pathlib import Path
|
13 |
+
import threading
|
14 |
+
import queue
|
15 |
+
import torch
|
16 |
+
from functools import partial
|
17 |
+
|
18 |
+
|
19 |
+
class ParallelProcessor:
|
20 |
+
"""
|
21 |
+
Parallel processing for audio and image preprocessing
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
num_threads: int = 4,
|
27 |
+
num_processes: int = 2,
|
28 |
+
use_cuda_streams: bool = True
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
Initialize parallel processor
|
32 |
+
|
33 |
+
Args:
|
34 |
+
num_threads: Number of threads for I/O operations
|
35 |
+
num_processes: Number of processes for CPU-intensive tasks
|
36 |
+
use_cuda_streams: Use CUDA streams for GPU operations
|
37 |
+
"""
|
38 |
+
self.num_threads = num_threads
|
39 |
+
self.num_processes = num_processes
|
40 |
+
self.use_cuda_streams = use_cuda_streams and torch.cuda.is_available()
|
41 |
+
|
42 |
+
# Thread pool for I/O operations
|
43 |
+
self.thread_executor = ThreadPoolExecutor(max_workers=num_threads)
|
44 |
+
|
45 |
+
# Process pool for CPU-intensive operations
|
46 |
+
self.process_executor = ProcessPoolExecutor(max_workers=num_processes)
|
47 |
+
|
48 |
+
# CUDA streams for GPU operations
|
49 |
+
if self.use_cuda_streams:
|
50 |
+
self.cuda_streams = [torch.cuda.Stream() for _ in range(2)]
|
51 |
+
else:
|
52 |
+
self.cuda_streams = None
|
53 |
+
|
54 |
+
print(f"✅ ParallelProcessor initialized: {num_threads} threads, {num_processes} processes")
|
55 |
+
if self.use_cuda_streams:
|
56 |
+
print("✅ CUDA streams enabled for GPU parallelism")
|
57 |
+
|
58 |
+
def preprocess_audio_parallel(self, audio_path: str) -> Dict[str, Any]:
|
59 |
+
"""
|
60 |
+
Preprocess audio file in parallel
|
61 |
+
|
62 |
+
Args:
|
63 |
+
audio_path: Path to audio file
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
Preprocessed audio data
|
67 |
+
"""
|
68 |
+
import librosa
|
69 |
+
|
70 |
+
# Define subtasks
|
71 |
+
def load_audio():
|
72 |
+
return librosa.load(audio_path, sr=16000)
|
73 |
+
|
74 |
+
def extract_features(audio, sr):
|
75 |
+
# Extract various audio features in parallel
|
76 |
+
features = {}
|
77 |
+
|
78 |
+
# MFCC features
|
79 |
+
features['mfcc'] = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
|
80 |
+
|
81 |
+
# Spectral features
|
82 |
+
features['spectral_centroid'] = librosa.feature.spectral_centroid(y=audio, sr=sr)
|
83 |
+
features['spectral_rolloff'] = librosa.feature.spectral_rolloff(y=audio, sr=sr)
|
84 |
+
|
85 |
+
return features
|
86 |
+
|
87 |
+
# Load audio
|
88 |
+
audio, sr = load_audio()
|
89 |
+
|
90 |
+
# Extract features in parallel (if needed)
|
91 |
+
features = extract_features(audio, sr)
|
92 |
+
|
93 |
+
return {
|
94 |
+
'audio': audio,
|
95 |
+
'sample_rate': sr,
|
96 |
+
'features': features,
|
97 |
+
'duration': len(audio) / sr
|
98 |
+
}
|
99 |
+
|
100 |
+
def preprocess_image_parallel(self, image_path: str, target_size: int = 320) -> Dict[str, Any]:
|
101 |
+
"""
|
102 |
+
Preprocess image file in parallel
|
103 |
+
|
104 |
+
Args:
|
105 |
+
image_path: Path to image file
|
106 |
+
target_size: Target resolution
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
Preprocessed image data
|
110 |
+
"""
|
111 |
+
from PIL import Image
|
112 |
+
import cv2
|
113 |
+
|
114 |
+
# Define subtasks
|
115 |
+
def load_and_resize():
|
116 |
+
# Load image
|
117 |
+
img = Image.open(image_path).convert('RGB')
|
118 |
+
|
119 |
+
# Resize
|
120 |
+
img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)
|
121 |
+
|
122 |
+
return np.array(img)
|
123 |
+
|
124 |
+
def extract_face_landmarks(img_array):
|
125 |
+
# Face detection and landmark extraction
|
126 |
+
# Simplified version - in production, use MediaPipe or similar
|
127 |
+
return {
|
128 |
+
'has_face': True,
|
129 |
+
'landmarks': None # Placeholder
|
130 |
+
}
|
131 |
+
|
132 |
+
# Execute in parallel
|
133 |
+
future_img = self.thread_executor.submit(load_and_resize)
|
134 |
+
|
135 |
+
# Get results
|
136 |
+
img_array = future_img.result()
|
137 |
+
|
138 |
+
# Extract landmarks
|
139 |
+
landmarks = extract_face_landmarks(img_array)
|
140 |
+
|
141 |
+
return {
|
142 |
+
'image': img_array,
|
143 |
+
'shape': img_array.shape,
|
144 |
+
'landmarks': landmarks
|
145 |
+
}
|
146 |
+
|
147 |
+
async def preprocess_parallel_async(
|
148 |
+
self,
|
149 |
+
audio_path: str,
|
150 |
+
image_path: str,
|
151 |
+
target_size: int = 320
|
152 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
153 |
+
"""
|
154 |
+
Asynchronously preprocess audio and image in parallel
|
155 |
+
|
156 |
+
Args:
|
157 |
+
audio_path: Path to audio file
|
158 |
+
image_path: Path to image file
|
159 |
+
target_size: Target image resolution
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
Tuple of (audio_data, image_data)
|
163 |
+
"""
|
164 |
+
loop = asyncio.get_event_loop()
|
165 |
+
|
166 |
+
# Create tasks for parallel execution
|
167 |
+
audio_task = loop.run_in_executor(
|
168 |
+
self.thread_executor,
|
169 |
+
self.preprocess_audio_parallel,
|
170 |
+
audio_path
|
171 |
+
)
|
172 |
+
|
173 |
+
image_task = loop.run_in_executor(
|
174 |
+
self.thread_executor,
|
175 |
+
partial(self.preprocess_image_parallel, target_size=target_size),
|
176 |
+
image_path
|
177 |
+
)
|
178 |
+
|
179 |
+
# Wait for both tasks to complete
|
180 |
+
audio_data, image_data = await asyncio.gather(audio_task, image_task)
|
181 |
+
|
182 |
+
return audio_data, image_data
|
183 |
+
|
184 |
+
def preprocess_parallel_sync(
|
185 |
+
self,
|
186 |
+
audio_path: str,
|
187 |
+
image_path: str,
|
188 |
+
target_size: int = 320
|
189 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
190 |
+
"""
|
191 |
+
Synchronously preprocess audio and image in parallel
|
192 |
+
|
193 |
+
Args:
|
194 |
+
audio_path: Path to audio file
|
195 |
+
image_path: Path to image file
|
196 |
+
target_size: Target image resolution
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
Tuple of (audio_data, image_data)
|
200 |
+
"""
|
201 |
+
# Submit tasks to thread pool
|
202 |
+
audio_future = self.thread_executor.submit(
|
203 |
+
self.preprocess_audio_parallel,
|
204 |
+
audio_path
|
205 |
+
)
|
206 |
+
|
207 |
+
image_future = self.thread_executor.submit(
|
208 |
+
self.preprocess_image_parallel,
|
209 |
+
image_path,
|
210 |
+
target_size
|
211 |
+
)
|
212 |
+
|
213 |
+
# Wait for results
|
214 |
+
audio_data = audio_future.result()
|
215 |
+
image_data = image_future.result()
|
216 |
+
|
217 |
+
return audio_data, image_data
|
218 |
+
|
219 |
+
def process_gpu_parallel(
|
220 |
+
self,
|
221 |
+
audio_tensor: torch.Tensor,
|
222 |
+
image_tensor: torch.Tensor,
|
223 |
+
model_audio: torch.nn.Module,
|
224 |
+
model_image: torch.nn.Module
|
225 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
226 |
+
"""
|
227 |
+
Process audio and image through models using CUDA streams
|
228 |
+
|
229 |
+
Args:
|
230 |
+
audio_tensor: Audio tensor
|
231 |
+
image_tensor: Image tensor
|
232 |
+
model_audio: Audio processing model
|
233 |
+
model_image: Image processing model
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
Tuple of processed tensors
|
237 |
+
"""
|
238 |
+
if not self.use_cuda_streams:
|
239 |
+
# Fallback to sequential processing
|
240 |
+
audio_out = model_audio(audio_tensor)
|
241 |
+
image_out = model_image(image_tensor)
|
242 |
+
return audio_out, image_out
|
243 |
+
|
244 |
+
# Use CUDA streams for parallel GPU processing
|
245 |
+
with torch.cuda.stream(self.cuda_streams[0]):
|
246 |
+
audio_out = model_audio(audio_tensor)
|
247 |
+
|
248 |
+
with torch.cuda.stream(self.cuda_streams[1]):
|
249 |
+
image_out = model_image(image_tensor)
|
250 |
+
|
251 |
+
# Synchronize streams
|
252 |
+
torch.cuda.synchronize()
|
253 |
+
|
254 |
+
return audio_out, image_out
|
255 |
+
|
256 |
+
def shutdown(self):
|
257 |
+
"""Shutdown executors"""
|
258 |
+
self.thread_executor.shutdown(wait=True)
|
259 |
+
self.process_executor.shutdown(wait=True)
|
260 |
+
print("✅ ParallelProcessor shutdown complete")
|
261 |
+
|
262 |
+
|
263 |
+
class PipelineProcessor:
|
264 |
+
"""
|
265 |
+
Pipeline-based processing for continuous operations
|
266 |
+
"""
|
267 |
+
|
268 |
+
def __init__(self, stages: Dict[str, Callable], buffer_size: int = 10):
|
269 |
+
"""
|
270 |
+
Initialize pipeline processor
|
271 |
+
|
272 |
+
Args:
|
273 |
+
stages: Dictionary of stage_name -> processing_function
|
274 |
+
buffer_size: Size of queues between stages
|
275 |
+
"""
|
276 |
+
self.stages = stages
|
277 |
+
self.buffer_size = buffer_size
|
278 |
+
|
279 |
+
# Create queues between stages
|
280 |
+
self.queues = {}
|
281 |
+
stage_names = list(stages.keys())
|
282 |
+
for i in range(len(stage_names) - 1):
|
283 |
+
queue_name = f"{stage_names[i]}_to_{stage_names[i+1]}"
|
284 |
+
self.queues[queue_name] = queue.Queue(maxsize=buffer_size)
|
285 |
+
|
286 |
+
# Input and output queues
|
287 |
+
self.input_queue = queue.Queue(maxsize=buffer_size)
|
288 |
+
self.output_queue = queue.Queue(maxsize=buffer_size)
|
289 |
+
|
290 |
+
# Worker threads
|
291 |
+
self.workers = []
|
292 |
+
self.stop_event = threading.Event()
|
293 |
+
|
294 |
+
def _worker(self, stage_name: str, process_func: Callable, input_q: queue.Queue, output_q: queue.Queue):
|
295 |
+
"""Worker thread for a pipeline stage"""
|
296 |
+
while not self.stop_event.is_set():
|
297 |
+
try:
|
298 |
+
# Get input with timeout
|
299 |
+
item = input_q.get(timeout=0.1)
|
300 |
+
|
301 |
+
if item is None: # Poison pill
|
302 |
+
output_q.put(None)
|
303 |
+
break
|
304 |
+
|
305 |
+
# Process item
|
306 |
+
result = process_func(item)
|
307 |
+
|
308 |
+
# Put result
|
309 |
+
output_q.put(result)
|
310 |
+
|
311 |
+
except queue.Empty:
|
312 |
+
continue
|
313 |
+
except Exception as e:
|
314 |
+
print(f"Error in stage {stage_name}: {e}")
|
315 |
+
output_q.put(None)
|
316 |
+
|
317 |
+
def start(self):
|
318 |
+
"""Start pipeline processing"""
|
319 |
+
stage_names = list(self.stages.keys())
|
320 |
+
|
321 |
+
# Create worker threads
|
322 |
+
for i, (stage_name, process_func) in enumerate(self.stages.items()):
|
323 |
+
# Determine input and output queues
|
324 |
+
if i == 0:
|
325 |
+
input_q = self.input_queue
|
326 |
+
else:
|
327 |
+
queue_name = f"{stage_names[i-1]}_to_{stage_names[i]}"
|
328 |
+
input_q = self.queues[queue_name]
|
329 |
+
|
330 |
+
if i == len(stage_names) - 1:
|
331 |
+
output_q = self.output_queue
|
332 |
+
else:
|
333 |
+
queue_name = f"{stage_names[i]}_to_{stage_names[i+1]}"
|
334 |
+
output_q = self.queues[queue_name]
|
335 |
+
|
336 |
+
# Create and start worker
|
337 |
+
worker = threading.Thread(
|
338 |
+
target=self._worker,
|
339 |
+
args=(stage_name, process_func, input_q, output_q)
|
340 |
+
)
|
341 |
+
worker.start()
|
342 |
+
self.workers.append(worker)
|
343 |
+
|
344 |
+
print(f"✅ Pipeline started with {len(self.workers)} stages")
|
345 |
+
|
346 |
+
def process(self, item: Any) -> Any:
|
347 |
+
"""Process an item through the pipeline"""
|
348 |
+
self.input_queue.put(item)
|
349 |
+
return self.output_queue.get()
|
350 |
+
|
351 |
+
def stop(self):
|
352 |
+
"""Stop pipeline processing"""
|
353 |
+
self.stop_event.set()
|
354 |
+
|
355 |
+
# Send poison pills
|
356 |
+
self.input_queue.put(None)
|
357 |
+
|
358 |
+
# Wait for workers
|
359 |
+
for worker in self.workers:
|
360 |
+
worker.join()
|
361 |
+
|
362 |
+
print("✅ Pipeline stopped")
|
363 |
+
|
364 |
+
|
365 |
+
def benchmark_parallel_processing():
|
366 |
+
"""Benchmark parallel vs sequential processing"""
|
367 |
+
import time
|
368 |
+
|
369 |
+
print("\n=== Parallel Processing Benchmark ===")
|
370 |
+
|
371 |
+
# Create processor
|
372 |
+
processor = ParallelProcessor(num_threads=4)
|
373 |
+
|
374 |
+
# Test files (using example files)
|
375 |
+
audio_path = "example/audio.wav"
|
376 |
+
image_path = "example/image.png"
|
377 |
+
|
378 |
+
# Sequential processing
|
379 |
+
start_seq = time.time()
|
380 |
+
audio_data_seq = processor.preprocess_audio_parallel(audio_path)
|
381 |
+
image_data_seq = processor.preprocess_image_parallel(image_path)
|
382 |
+
time_seq = time.time() - start_seq
|
383 |
+
|
384 |
+
# Parallel processing
|
385 |
+
start_par = time.time()
|
386 |
+
audio_data_par, image_data_par = processor.preprocess_parallel_sync(audio_path, image_path)
|
387 |
+
time_par = time.time() - start_par
|
388 |
+
|
389 |
+
# Results
|
390 |
+
print(f"Sequential processing: {time_seq:.3f}s")
|
391 |
+
print(f"Parallel processing: {time_par:.3f}s")
|
392 |
+
print(f"Speedup: {time_seq/time_par:.2f}x")
|
393 |
+
|
394 |
+
processor.shutdown()
|
395 |
+
|
396 |
+
return {
|
397 |
+
'sequential_time': time_seq,
|
398 |
+
'parallel_time': time_par,
|
399 |
+
'speedup': time_seq / time_par
|
400 |
+
}
|