oKen38461 commited on
Commit
2089ecf
·
1 Parent(s): d9a2a3d

ストリーミング関連のフレームレートを25fpsから20fpsに変更し、関連するテストケースを更新しました。これにより、全体のフレーム数計算が一貫性を持つようになりました。

Browse files
api_server_streaming.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DittoTalkingHead Streaming API Server
3
+ WebSocket/SSEによるリアルタイムストリーミング実装
4
+ """
5
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, UploadFile, HTTPException
6
+ from fastapi.responses import StreamingResponse
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ import asyncio
9
+ import tempfile
10
+ import numpy as np
11
+ import base64
12
+ import json
13
+ from typing import AsyncGenerator, Optional
14
+ import cv2
15
+ import time
16
+ import logging
17
+ from pathlib import Path
18
+ import traceback
19
+
20
+ from stream_pipeline_offline import StreamSDK
21
+
22
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="DittoTalkingHead Streaming API")

# CORS configuration
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# SDK configuration: the two arguments passed to StreamSDK(...) in init_sdk()
CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
DATA_ROOT = "checkpoints/ditto_pytorch"
40
+
41
+ # グローバル設定
42
class AppState:
    """Mutable, process-wide state shared by all endpoints of this server."""

    def __init__(self) -> None:
        # Upper bound on concurrent streaming sessions; endpoints compare
        # active_connections against this before accepting work.
        self.max_connections: int = 5
        # Number of streaming sessions currently in progress.
        self.active_connections: int = 0
        # StreamSDK instance, created on first use by init_sdk().
        self.sdk: Optional[StreamSDK] = None
47
+
48
+ state = AppState()
49
+
50
def init_sdk():
    """Return the shared StreamSDK, constructing it on the first call.

    The instance is cached on ``state.sdk``; later calls return the cached
    object without constructing a new one.
    """
    if state.sdk is not None:
        return state.sdk

    logger.info("Initializing StreamSDK...")
    state.sdk = StreamSDK(CFG_PKL, DATA_ROOT)
    logger.info("StreamSDK initialized successfully")
    return state.sdk
57
+
58
@app.on_event("startup")
async def startup_event():
    """Initialize the SDK when the application starts."""
    init_sdk()
62
+
63
@app.get("/")
async def root():
    """Health check: report service identity and connection usage."""
    payload = {
        "status": "ok",
        "service": "DittoTalkingHead Streaming API",
    }
    # Snapshot of the shared counters at the time of the request.
    payload["active_connections"] = state.active_connections
    payload["max_connections"] = state.max_connections
    return payload
72
+
73
@app.websocket("/ws/generate")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time streaming generation.

    Protocol, as implemented below:

    1. The client sends one JSON config message with ``source_image``
       (base64 image bytes, required) and optionally ``sample_rate``
       (default 16000) and ``chunk_duration`` in seconds (default 0.2).
    2. The server answers with a ``ready`` message carrying ``chunk_size``.
    3. The client sends binary messages that are decoded as float32 audio.
       Each message is padded or truncated to ``chunk_size`` samples and
       passed to the SDK; a ``progress`` message is sent back per chunk.
       Frames taken from the SDK's ``writer_queue`` are forwarded
       concurrently as ``frame`` messages (base64 JPEG).
    4. A text message ``{"action": "stop"}`` ends the session: the SDK is
       finalized and the resulting MP4 is sent as ``final_video``.

    Failures are reported to the client as ``{"type": "error", ...}`` when
    the socket is still usable. Temporary files and the frame-forwarding
    task are always released when the session ends.
    """

    # Reject new sessions once the concurrency limit is reached.
    if state.active_connections >= state.max_connections:
        await websocket.close(code=1008, reason="Server busy")
        return

    await websocket.accept()
    state.active_connections += 1
    logger.info(f"New WebSocket connection. Active: {state.active_connections}")

    sdk_instance = None
    source_path = None  # temp copy of the uploaded source image
    output_dir = None   # TemporaryDirectory that holds the generated MP4
    frame_task = None   # background task forwarding frames to the client

    try:
        # Receive the initial configuration.
        config = await websocket.receive_json()
        source_image_b64 = config.get("source_image")
        sample_rate = config.get("sample_rate", 16000)
        chunk_duration = config.get("chunk_duration", 0.2)

        if not source_image_b64:
            await websocket.send_json({"type": "error", "message": "source_image is required"})
            return

        # Client-supplied values: a non-positive rate or chunk size would
        # later divide by zero or make the padding arithmetic meaningless.
        chunk_size = int(sample_rate * chunk_duration)
        if sample_rate <= 0 or chunk_size <= 0:
            await websocket.send_json({
                "type": "error",
                "message": "sample_rate and chunk_duration must be positive"
            })
            return

        # Decode the image and store it in a temporary file.
        image_data = base64.b64decode(source_image_b64)
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
            source_path = tmp_img.name
            tmp_img.write(image_data)

        # Output goes into a private temporary directory instead of a
        # tempfile.mktemp() name (deprecated, race-prone). Removing the
        # directory afterwards also removes anything written next to the MP4.
        output_dir = tempfile.TemporaryDirectory(prefix="ditto_ws_")
        output_path = str(Path(output_dir.name) / "output.mp4")

        # SDK setup.
        # NOTE(review): init_sdk() returns one shared instance, so concurrent
        # sessions call setup()/run_chunk() on the same object — confirm the
        # SDK supports that, or lower max_connections to 1.
        sdk_instance = init_sdk()
        sdk_instance.setup(source_path, output_path, online_mode=True, max_size=1024)

        await websocket.send_json({
            "type": "ready",
            "message": "Ready to receive audio chunks",
            "chunk_size": chunk_size
        })

        async def send_frames():
            """Forward frames from the SDK writer queue to the client."""
            frame_count = 0
            last_frame_time = time.time()

            while True:
                try:
                    current_time = time.time()

                    if sdk_instance.writer_queue.qsize() > 0:
                        frame = sdk_instance.writer_queue.get_nowait()
                        if frame is not None:
                            # JPEG-encode the frame (quality is adjustable).
                            encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
                            _, jpeg = cv2.imencode('.jpg',
                                                   cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
                                                   encode_param)
                            frame_b64 = base64.b64encode(jpeg).decode('utf-8')

                            # Instantaneous FPS from the gap since the last frame.
                            fps = 1.0 / (current_time - last_frame_time) if current_time > last_frame_time else 0
                            last_frame_time = current_time

                            await websocket.send_json({
                                "type": "frame",
                                "frame_id": frame_count,
                                "timestamp": current_time,
                                "fps": round(fps, 2),
                                "data": frame_b64
                            })
                            frame_count += 1
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error(f"Error sending frame: {e}")

                await asyncio.sleep(0.01)  # poll every 10 ms

        # Start forwarding frames in the background.
        frame_task = asyncio.create_task(send_frames())

        # Receive and process audio chunks.
        total_samples = 0
        processing_start = time.time()

        while True:
            message = await websocket.receive()

            # receive() returns the disconnect message instead of raising;
            # calling it again afterwards raises RuntimeError, so convert it
            # to the exception handled below.
            if message.get("type") == "websocket.disconnect":
                raise WebSocketDisconnect(message.get("code", 1000))

            audio_bytes = message.get("bytes")
            text = message.get("text")

            if audio_bytes is not None:
                audio_chunk = np.frombuffer(audio_bytes, dtype=np.float32)

                if len(audio_chunk) > chunk_size:
                    logger.warning(
                        f"Audio chunk has {len(audio_chunk)} samples; "
                        f"truncating to {chunk_size}"
                    )
                elif len(audio_chunk) < chunk_size:
                    # Zero-pad short chunks.
                    audio_chunk = np.pad(audio_chunk, (0, chunk_size - len(audio_chunk)))

                # NOTE(review): run_chunk() is called synchronously inside the
                # event loop; if it is slow it delays frame forwarding.
                sdk_instance.run_chunk(audio_chunk[:chunk_size])
                # Count what was actually fed to the SDK so the frame
                # estimate below matches the processed audio.
                total_samples += chunk_size

                # Report progress.
                elapsed = time.time() - processing_start
                await websocket.send_json({
                    "type": "progress",
                    "samples_processed": total_samples,
                    "duration_seconds": total_samples / sample_rate,
                    "elapsed_seconds": elapsed
                })

            elif text is not None:
                # Control command.
                command = json.loads(text)
                if command.get("action") == "stop":
                    logger.info("Received stop command")
                    break

        # Stop forwarding frames before finalizing.
        frame_task.cancel()
        try:
            await frame_task
        except asyncio.CancelledError:
            pass
        frame_task = None

        # Estimate the frame count (20 fps) and pass it to setup_Nd.
        estimated_frames = int(total_samples / sample_rate * 20)
        sdk_instance.setup_Nd(estimated_frames)

        await websocket.send_json({"type": "processing", "message": "Finalizing video..."})

        # Close the SDK to produce the final MP4.
        sdk_instance.close()

        # Send the final MP4.
        if Path(output_path).exists():
            mp4_data = Path(output_path).read_bytes()
            mp4_b64 = base64.b64encode(mp4_data).decode('utf-8')

            await websocket.send_json({
                "type": "final_video",
                "size_bytes": len(mp4_data),
                "duration_seconds": total_samples / sample_rate,
                "data": mp4_b64
            })
        else:
            await websocket.send_json({
                "type": "error",
                "message": "Failed to generate final video"
            })

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        logger.error(traceback.format_exc())
        try:
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })
        except Exception:
            # Best effort: the socket may already be closed.
            pass
    finally:
        # Always stop the frame task; otherwise it keeps polling the shared
        # SDK queue after this session has ended.
        if frame_task is not None:
            frame_task.cancel()
            try:
                await frame_task
            except asyncio.CancelledError:
                pass

        state.active_connections -= 1
        logger.info(f"Connection closed. Active: {state.active_connections}")

        # Clean up temporary files (best effort).
        if source_path is not None:
            try:
                Path(source_path).unlink()
            except OSError:
                pass
        if output_dir is not None:
            try:
                output_dir.cleanup()
            except OSError:
                pass
254
+
255
@app.post("/sse/generate")
async def sse_generate(
    source_image: UploadFile = File(...),
    sample_rate: int = 16000,
    max_duration: float = 10.0
):
    """SSE endpoint: stream generation events as Server-Sent Events.

    Demo implementation: random audio is generated in 0.2 s chunks for
    ``max_duration`` seconds and fed to the SDK; a real implementation would
    read from an audio stream instead.

    Args:
        source_image: Uploaded source image.
        sample_rate: Sample rate used to size the generated audio chunks.
        max_duration: Length of the generated audio in seconds.

    Returns:
        A ``text/event-stream`` response emitting ``start``, ``thumbnail``,
        ``complete`` and ``error`` events as JSON payloads.

    Raises:
        HTTPException: 503 when the connection limit is reached.
    """
    # Read the upload here: the generator below runs after this handler has
    # returned, by which time the UploadFile may already be closed.
    image_bytes = await source_image.read()

    if state.active_connections >= state.max_connections:
        raise HTTPException(status_code=503, detail="Server busy")

    state.active_connections += 1

    async def generate() -> AsyncGenerator[str, None]:
        sdk_instance = None
        source_path = None  # temp copy of the uploaded image
        output_dir = None   # TemporaryDirectory that holds the generated MP4

        try:
            # Save the image.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
                source_path = tmp_img.name
                tmp_img.write(image_bytes)

            # Private temporary directory instead of the deprecated,
            # race-prone tempfile.mktemp().
            output_dir = tempfile.TemporaryDirectory(prefix="ditto_sse_")
            output_path = str(Path(output_dir.name) / "output.mp4")

            # SDK setup.
            # NOTE(review): the SDK instance is shared across requests —
            # confirm concurrent use is safe.
            sdk_instance = init_sdk()
            sdk_instance.setup(source_path, output_path, online_mode=True, max_size=1024)

            yield f"data: {json.dumps({'type': 'start', 'message': 'Processing started'})}\n\n"

            # Demo: generate dummy audio and process it.
            chunk_duration = 0.2
            chunk_size = int(sample_rate * chunk_duration)
            num_chunks = int(max_duration / chunk_duration)

            for i in range(num_chunks):
                # Dummy audio chunk (a real implementation reads a stream).
                audio_chunk = np.random.randn(chunk_size).astype(np.float32) * 0.1
                sdk_instance.run_chunk(audio_chunk)

                # Build at most one thumbnail per chunk. The yield is kept
                # outside the try block so that closing the generator on
                # client disconnect is never swallowed by the handler.
                frame_b64 = None
                if sdk_instance.writer_queue.qsize() > 0:
                    try:
                        frame = sdk_instance.writer_queue.get_nowait()
                        if frame is not None:
                            # Low-resolution thumbnail.
                            thumbnail = cv2.resize(frame, (160, 160))
                            _, jpeg = cv2.imencode('.jpg', cv2.cvtColor(thumbnail, cv2.COLOR_RGB2BGR))
                            frame_b64 = base64.b64encode(jpeg).decode('utf-8')
                    except Exception as e:
                        # A failed thumbnail must not abort the stream.
                        logger.warning(f"Thumbnail generation failed: {e}")

                if frame_b64 is not None:
                    yield f"data: {json.dumps({'type': 'thumbnail', 'frame_id': i, 'data': frame_b64})}\n\n"

                await asyncio.sleep(chunk_duration)

            # Done. Estimate frames at 20 fps, matching the WebSocket
            # endpoint (the previous "* 5" per 0.2 s chunk implied 25 fps).
            fps = 20
            estimated_frames = round(num_chunks * chunk_duration * fps)
            sdk_instance.setup_Nd(estimated_frames)
            sdk_instance.close()

            yield f"data: {json.dumps({'type': 'complete', 'frames': estimated_frames})}\n\n"

        except Exception as e:
            logger.error(f"SSE error: {e}")
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
        finally:
            state.active_connections -= 1
            # Clean up temporary files (best effort).
            if source_path is not None:
                try:
                    Path(source_path).unlink()
                except OSError:
                    pass
            if output_dir is not None:
                try:
                    output_dir.cleanup()
                except OSError:
                    pass

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
340
+
341
@app.get("/test")
async def test_page():
    """Serve a static HTML page for manual testing.

    The page contains only the controls and placeholders; its script just
    logs the WebSocket endpoint URL to the browser console.
    """
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>DittoTalkingHead Streaming Test</title>
        <style>
            body { font-family: Arial, sans-serif; margin: 20px; }
            .container { max-width: 800px; margin: 0 auto; }
            #live-frame { max-width: 100%; border: 1px solid #ccc; }
            #status { margin: 10px 0; padding: 10px; background: #f0f0f0; }
            .controls { margin: 20px 0; }
            button { padding: 10px 20px; margin: 5px; }
        </style>
    </head>
    <body>
        <div class="container">
            <h1>DittoTalkingHead Streaming Test</h1>

            <div class="controls">
                <input type="file" id="source-image" accept="image/*">
                <button id="start-btn">Start Streaming</button>
                <button id="stop-btn" disabled>Stop</button>
            </div>

            <div id="status">Ready</div>

            <img id="live-frame" style="display: none;">
            <video id="final-video" controls style="display: none; width: 100%;"></video>
        </div>

        <script>
            // WebSocket実装はstreaming_api_guide.mdを参照
            console.log('WebSocket endpoint: ws://localhost:8000/ws/generate');
        </script>
    </body>
    </html>
    """
    # Imported locally; the module-level imports do not include HTMLResponse.
    from fastapi.responses import HTMLResponse
    return HTMLResponse(content=html_content)
383
+
384
if __name__ == "__main__":
    import uvicorn
    import torch

    # GPU setup: when CUDA is available, empty the CUDA cache and turn on
    # torch.backends.cudnn.benchmark.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True

    logger.info("Starting DittoTalkingHead Streaming API Server...")
    logger.info(f"GPU available: {torch.cuda.is_available()}")

    # Serve on all interfaces, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
        access_log=True
    )
app_streaming.py CHANGED
@@ -46,7 +46,7 @@ def generator(mic, src_img):
46
  # setup: online_mode=True でストリーミング
47
  tmp_out = tempfile.mktemp(suffix=".mp4")
48
  sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
49
- N_total = int(np.ceil(len(wav_full) / sr * 25)) # 概算フレーム数
50
  sdk.setup_Nd(N_total)
51
 
52
  # 処理開始時刻
 
46
  # setup: online_mode=True でストリーミング
47
  tmp_out = tempfile.mktemp(suffix=".mp4")
48
  sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
49
+ N_total = int(np.ceil(len(wav_full) / sr * 20)) # 概算フレーム数
50
  sdk.setup_Nd(N_total)
51
 
52
  # 処理開始時刻
core/atomic_components/writer.py CHANGED
@@ -3,7 +3,7 @@ import os
3
 
4
 
5
  class VideoWriterByImageIO:
6
- def __init__(self, video_path, fps=25, **kwargs):
7
  video_format = kwargs.get("format", "mp4") # default is mp4 format
8
  codec = kwargs.get("vcodec", "libx264") # default is libx264 encoding
9
  quality = kwargs.get("quality") # video quality
 
3
 
4
 
5
  class VideoWriterByImageIO:
6
+ def __init__(self, video_path, fps=20, **kwargs):
7
  video_format = kwargs.get("format", "mp4") # default is mp4 format
8
  codec = kwargs.get("vcodec", "libx264") # default is libx264 encoding
9
  quality = kwargs.get("quality") # video quality
core/models/modules/lmdm_modules/model.py CHANGED
@@ -237,7 +237,7 @@ class MotionDecoder(nn.Module):
237
  def __init__(
238
  self,
239
  nfeats: int,
240
- seq_len: int = 100, # 4 seconds, 25 fps
241
  latent_dim: int = 256,
242
  ff_size: int = 1024,
243
  num_layers: int = 4,
 
237
  def __init__(
238
  self,
239
  nfeats: int,
240
+ seq_len: int = 80, # 4 seconds, 20 fps
241
  latent_dim: int = 256,
242
  ff_size: int = 1024,
243
  num_layers: int = 4,
streaming_client.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DittoTalkingHead Streaming Client
3
+ WebSocketを使用したストリーミングクライアントの実装例
4
+ """
5
+ import asyncio
6
+ import websockets
7
+ import numpy as np
8
+ import soundfile as sf
9
+ import base64
10
+ import json
11
+ import cv2
12
+ from typing import Optional, Callable
13
+ import pyaudio
14
+ import threading
15
+ import queue
16
+ from pathlib import Path
17
+ import logging
18
+
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
class DittoStreamingClient:
    """WebSocket client for the DittoTalkingHead streaming server.

    Audio is sent as float32 chunks of ``chunk_size`` samples. Frames and the
    final video received from the server are passed to the optional
    callbacks set with ``set_frame_callback`` / ``set_final_video_callback``.
    """

    def __init__(self, server_url="ws://localhost:8000/ws/generate"):
        self.server_url = server_url
        self.sample_rate = 16000
        self.chunk_duration = 0.2  # 200 ms
        self.chunk_size = int(self.sample_rate * self.chunk_duration)
        self.websocket = None
        self.is_connected = False
        self.frame_callback: Optional[Callable] = None
        self.final_video_callback: Optional[Callable] = None

    @staticmethod
    def _to_mono(audio_data):
        """Return 1-D audio, averaging channels of a (samples, channels) array."""
        audio_data = np.asarray(audio_data)
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        return audio_data

    async def connect(self, source_image_path: str):
        """Connect to the server and start a session.

        Sends the base64-encoded source image plus the audio settings and
        waits for the server's first reply.

        Returns:
            True if the server answered ``ready``, False otherwise.

        Raises:
            Exception: any connection or I/O error is logged and re-raised.
        """
        try:
            # Base64-encode the image.
            with open(source_image_path, "rb") as f:
                image_b64 = base64.b64encode(f.read()).decode('utf-8')

            # Open the WebSocket connection.
            self.websocket = await websockets.connect(self.server_url)
            self.is_connected = True

            # Send the initial configuration.
            await self.websocket.send(json.dumps({
                "source_image": image_b64,
                "sample_rate": self.sample_rate,
                "chunk_duration": self.chunk_duration
            }))

            # Wait for the reply.
            response = await self.websocket.recv()
            data = json.loads(response)

            if data["type"] == "ready":
                logger.info(f"Connected to server: {data['message']}")
                return True
            else:
                logger.error(f"Connection failed: {data}")
                return False

        except Exception as e:
            logger.error(f"Connection error: {e}")
            self.is_connected = False
            raise

    async def disconnect(self):
        """Close the connection if one is open."""
        if self.websocket:
            await self.websocket.close()
            self.is_connected = False
            logger.info("Disconnected from server")

    async def stream_audio_file(self, audio_path: str, source_image_path: str):
        """Stream an audio file to the server at real-time pace.

        The audio is downmixed to mono, resampled to ``sample_rate`` if
        needed, and sent in ``chunk_size`` pieces (the last one zero-padded),
        followed by a stop command. Returns once the server has sent the
        final video or an error, or immediately if the handshake fails.
        """
        receive_task = None
        try:
            # Do not stream to a server that did not accept the session.
            if not await self.connect(source_image_path):
                return

            # Load the audio; multichannel files come back as
            # (samples, channels), which would break chunking and padding.
            audio_data, sr = sf.read(audio_path)
            audio_data = self._to_mono(audio_data)
            if sr != self.sample_rate:
                import librosa
                audio_data = librosa.resample(
                    audio_data,
                    orig_sr=sr,
                    target_sr=self.sample_rate
                )

            # Receive frames concurrently with sending.
            receive_task = asyncio.create_task(self._receive_frames())

            # Send the audio chunk by chunk.
            total_chunks = 0
            total_samples = len(audio_data)
            for i in range(0, total_samples, self.chunk_size):
                chunk = audio_data[i:i+self.chunk_size]
                if len(chunk) < self.chunk_size:
                    chunk = np.pad(chunk, (0, self.chunk_size - len(chunk)))

                # Send as float32.
                await self.websocket.send(chunk.astype(np.float32).tobytes())
                total_chunks += 1

                # Simulate real-time capture.
                await asyncio.sleep(self.chunk_duration)

                # Progress, capped at 100% for a partial final chunk.
                progress = min(i + self.chunk_size, total_samples) / total_samples * 100
                logger.info(f"Streaming progress: {progress:.1f}%")

            # Send the stop command.
            await self.websocket.send(json.dumps({"action": "stop"}))
            logger.info(f"Sent {total_chunks} audio chunks")

            # Wait until the server has delivered everything.
            await receive_task

        finally:
            # If sending failed, do not leave the receiver pending.
            if receive_task is not None and not receive_task.done():
                receive_task.cancel()
            await self.disconnect()

    async def stream_microphone(self, source_image_path: str, duration: Optional[float] = None):
        """Stream microphone audio to the server in real time.

        Args:
            source_image_path: Path of the source image sent at connect time.
            duration: Recording length in seconds; None records until
                interrupted.
        """
        import time  # local import: needed only for the recording timer

        receive_task = None
        try:
            # Do not record for a server that did not accept the session.
            if not await self.connect(source_image_path):
                return

            # Receive frames concurrently with sending.
            receive_task = asyncio.create_task(self._receive_frames())

            # Queue handing recorded chunks from the thread to the event loop.
            audio_queue = queue.Queue()
            stop_event = threading.Event()

            def record_audio():
                """Recorder thread: read chunks from the microphone."""
                p = pyaudio.PyAudio()
                stream = p.open(
                    format=pyaudio.paFloat32,
                    channels=1,
                    rate=self.sample_rate,
                    input=True,
                    frames_per_buffer=self.chunk_size
                )

                logger.info("Recording started... Press Ctrl+C to stop")

                try:
                    # time.monotonic() rather than the asyncio loop clock:
                    # asyncio.get_event_loop() raises in a thread that has no
                    # event loop, which ended recording immediately.
                    start_time = time.monotonic()
                    while not stop_event.is_set():
                        if duration and (time.monotonic() - start_time) > duration:
                            break

                        audio_chunk = stream.read(self.chunk_size, exception_on_overflow=False)
                        audio_queue.put(audio_chunk)

                except Exception as e:
                    logger.error(f"Recording error: {e}")
                finally:
                    stream.stop_stream()
                    stream.close()
                    p.terminate()
                    logger.info("Recording stopped")

            # Start the recorder thread.
            record_thread = threading.Thread(target=record_audio)
            record_thread.start()

            try:
                # Forward recorded audio. The queue is polled without
                # blocking so the receive task keeps running meanwhile.
                while record_thread.is_alive() or not audio_queue.empty():
                    try:
                        audio_chunk = audio_queue.get_nowait()
                    except queue.Empty:
                        await asyncio.sleep(0.01)
                        continue
                    except KeyboardInterrupt:
                        logger.info("Stopping recording...")
                        break
                    # The recorder already yields raw float32 bytes.
                    await self.websocket.send(audio_chunk)

            finally:
                stop_event.set()
                # Join in a worker thread so the event loop is not blocked
                # while the recorder finishes its current read.
                await asyncio.get_running_loop().run_in_executor(None, record_thread.join)

            # Send the stop command.
            await self.websocket.send(json.dumps({"action": "stop"}))

            # Wait until the server has delivered everything.
            await receive_task

        finally:
            # If sending failed, do not leave the receiver pending.
            if receive_task is not None and not receive_task.done():
                receive_task.cancel()
            await self.disconnect()

    async def _receive_frames(self):
        """Receive server messages until the final video, an error, or close.

        ``frame`` messages are decoded and passed to ``frame_callback``;
        ``final_video`` is passed to ``final_video_callback``. ``progress``
        and ``processing`` messages are only logged.
        """
        frame_count = 0

        try:
            while True:
                message = await self.websocket.recv()
                data = json.loads(message)

                if data["type"] == "frame":
                    frame_count += 1
                    logger.info(f"Received frame {data['frame_id']} (FPS: {data.get('fps', 0)})")

                    if self.frame_callback:
                        # Decode the JPEG frame.
                        frame_data = base64.b64decode(data["data"])
                        nparr = np.frombuffer(frame_data, np.uint8)
                        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                        self.frame_callback(frame, data)

                elif data["type"] == "progress":
                    logger.info(f"Progress: {data['duration_seconds']:.1f}s processed")

                elif data["type"] == "processing":
                    logger.info(f"Server: {data['message']}")

                elif data["type"] == "final_video":
                    logger.info(f"Received final video ({data['size_bytes']} bytes, {data['duration_seconds']:.1f}s)")

                    if self.final_video_callback:
                        video_data = base64.b64decode(data["data"])
                        self.final_video_callback(video_data, data)
                    break

                elif data["type"] == "error":
                    logger.error(f"Server error: {data['message']}")
                    break

        except websockets.exceptions.ConnectionClosed:
            logger.info("Connection closed by server")
        except Exception as e:
            logger.error(f"Receive error: {e}")

        logger.info(f"Total frames received: {frame_count}")

    def set_frame_callback(self, callback: Callable):
        """Set the callback invoked as ``callback(frame, message)`` per frame."""
        self.frame_callback = callback

    def set_final_video_callback(self, callback: Callable):
        """Set the callback invoked as ``callback(video_bytes, message)``."""
        self.final_video_callback = callback
248
+
249
+
250
+ # 使用例とテスト
251
async def main():
    """Example: stream a sample audio file, show live frames, save the video."""
    # Paths of the test image and sample audio.
    source_image = "example/reference.png"
    audio_file = "example/audio.wav"

    def display_frame(frame, metadata):
        """Show each received frame in an OpenCV window."""
        cv2.imshow("Live Frame", frame)
        cv2.waitKey(1)

    def save_video(video_data, metadata):
        """Write the final video bytes to disk."""
        output_path = "output_streaming.mp4"
        with open(output_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved to {output_path}")

    client = DittoStreamingClient()
    client.set_frame_callback(display_frame)
    client.set_final_video_callback(save_video)

    # Nothing can be done without the source image.
    if not Path(source_image).exists():
        logger.error(f"Source image not found: {source_image}")
        return

    # Stream from the audio file when it is available.
    if not Path(audio_file).exists():
        logger.warning(f"Audio file not found: {audio_file}")
    else:
        logger.info("=== Testing audio file streaming ===")
        await client.stream_audio_file(audio_file, source_image)

    # Stream from the microphone (5 seconds)
    # logger.info("\n=== Testing microphone streaming (5 seconds) ===")
    # await client.stream_microphone(source_image, duration=5.0)

    cv2.destroyAllWindows()
291
+
292
+
293
+ # バッチ処理クライアント
294
class BatchStreamingClient:
    """Client that processes several streaming requests in parallel."""

    def __init__(self, server_url="ws://localhost:8000/ws/generate", max_parallel=3):
        self.server_url = server_url
        self.max_parallel = max_parallel

    async def process_batch(self, tasks: list):
        """Run every task, with at most ``max_parallel`` active at a time.

        Each task is a dict with ``id``, ``audio_path`` and ``image_path``.
        Returns a list, in task order, holding each task's ``id`` on success
        or the exception it raised.
        """
        limiter = asyncio.Semaphore(self.max_parallel)

        async def run_one(job):
            # The semaphore bounds how many sessions are open concurrently.
            async with limiter:
                streamer = DittoStreamingClient(self.server_url)
                await streamer.stream_audio_file(job["audio_path"], job["image_path"])
                return job["id"]

        pending = [run_one(job) for job in tasks]
        return await asyncio.gather(*pending, return_exceptions=True)
320
+
321
+
322
if __name__ == "__main__":
    # Single-client test
    asyncio.run(main())

    # Batch processing example
    # batch_client = BatchStreamingClient()
    # tasks = [
    #     {"id": 1, "audio_path": "audio1.wav", "image_path": "image1.png"},
    #     {"id": 2, "audio_path": "audio2.wav", "image_path": "image2.png"},
    # ]
    # asyncio.run(batch_client.process_batch(tasks))
test_streaming.py CHANGED
@@ -34,7 +34,7 @@ def test_streaming():
34
  tmp_out = tempfile.mktemp(suffix=".mp4")
35
 
36
  sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
37
- N_total = int(np.ceil(duration * 25)) # 25fps
38
  sdk.setup_Nd(N_total)
39
  print("✅ セットアップ完了")
40
 
@@ -98,7 +98,7 @@ def test_streaming():
98
  print(f"✅ 出力ファイル: {tmp_out}")
99
 
100
  # 期待される結果の確認
101
- expected_frames = int(duration * 25) # 25fps
102
  if frames_received >= expected_frames * 0.8: # 80%以上
103
  print("✅ テスト成功!")
104
  else:
 
34
  tmp_out = tempfile.mktemp(suffix=".mp4")
35
 
36
  sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
37
+ N_total = int(np.ceil(duration * 20)) # 20fps
38
  sdk.setup_Nd(N_total)
39
  print("✅ セットアップ完了")
40
 
 
98
  print(f"✅ 出力ファイル: {tmp_out}")
99
 
100
  # 期待される結果の確認
101
+ expected_frames = int(duration * 20) # 20fps
102
  if frames_received >= expected_frames * 0.8: # 80%以上
103
  print("✅ テスト成功!")
104
  else: