Spaces:
Runtime error
Runtime error
ストリーミング関連のフレームレートを25fpsから20fpsに変更し、関連するテストケースを更新しました。これにより、全体のフレーム数計算が一貫性を持つようになりました。
Browse files- api_server_streaming.py +402 -0
- app_streaming.py +1 -1
- core/atomic_components/writer.py +1 -1
- core/models/modules/lmdm_modules/model.py +1 -1
- streaming_client.py +332 -0
- test_streaming.py +2 -2
api_server_streaming.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
DittoTalkingHead Streaming API Server
|
3 |
+
WebSocket/SSEによるリアルタイムストリーミング実装
|
4 |
+
"""
|
5 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, UploadFile, HTTPException
|
6 |
+
from fastapi.responses import StreamingResponse
|
7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
8 |
+
import asyncio
|
9 |
+
import tempfile
|
10 |
+
import numpy as np
|
11 |
+
import base64
|
12 |
+
import json
|
13 |
+
from typing import AsyncGenerator, Optional
|
14 |
+
import cv2
|
15 |
+
import time
|
16 |
+
import logging
|
17 |
+
from pathlib import Path
|
18 |
+
import traceback
|
19 |
+
|
20 |
+
from stream_pipeline_offline import StreamSDK
|
21 |
+
|
22 |
+
# ログ設定
|
23 |
+
logging.basicConfig(level=logging.INFO)
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
app = FastAPI(title="DittoTalkingHead Streaming API")
|
27 |
+
|
28 |
+
# CORS設定
|
29 |
+
app.add_middleware(
|
30 |
+
CORSMiddleware,
|
31 |
+
allow_origins=["*"],
|
32 |
+
allow_credentials=True,
|
33 |
+
allow_methods=["*"],
|
34 |
+
allow_headers=["*"],
|
35 |
+
)
|
36 |
+
|
37 |
+
# SDK設定
|
38 |
+
CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
|
39 |
+
DATA_ROOT = "checkpoints/ditto_pytorch"
|
40 |
+
|
41 |
+
# グローバル設定
|
42 |
+
class AppState:
|
43 |
+
def __init__(self):
|
44 |
+
self.sdk: Optional[StreamSDK] = None
|
45 |
+
self.active_connections: int = 0
|
46 |
+
self.max_connections: int = 5
|
47 |
+
|
48 |
+
state = AppState()
|
49 |
+
|
50 |
+
def init_sdk():
|
51 |
+
"""SDKの初期化"""
|
52 |
+
if state.sdk is None:
|
53 |
+
logger.info("Initializing StreamSDK...")
|
54 |
+
state.sdk = StreamSDK(CFG_PKL, DATA_ROOT)
|
55 |
+
logger.info("StreamSDK initialized successfully")
|
56 |
+
return state.sdk
|
57 |
+
|
58 |
+
@app.on_event("startup")
|
59 |
+
async def startup_event():
|
60 |
+
"""起動時にSDKを初期化"""
|
61 |
+
init_sdk()
|
62 |
+
|
63 |
+
@app.get("/")
|
64 |
+
async def root():
|
65 |
+
"""ヘルスチェック"""
|
66 |
+
return {
|
67 |
+
"status": "ok",
|
68 |
+
"service": "DittoTalkingHead Streaming API",
|
69 |
+
"active_connections": state.active_connections,
|
70 |
+
"max_connections": state.max_connections
|
71 |
+
}
|
72 |
+
|
73 |
+
@app.websocket("/ws/generate")
|
74 |
+
async def websocket_endpoint(websocket: WebSocket):
|
75 |
+
"""WebSocketエンドポイント - リアルタイムストリーミング"""
|
76 |
+
|
77 |
+
# 接続数チェック
|
78 |
+
if state.active_connections >= state.max_connections:
|
79 |
+
await websocket.close(code=1008, reason="Server busy")
|
80 |
+
return
|
81 |
+
|
82 |
+
await websocket.accept()
|
83 |
+
state.active_connections += 1
|
84 |
+
logger.info(f"New WebSocket connection. Active: {state.active_connections}")
|
85 |
+
|
86 |
+
sdk_instance = None
|
87 |
+
output_path = None
|
88 |
+
|
89 |
+
try:
|
90 |
+
# 初期設定を受信
|
91 |
+
config = await websocket.receive_json()
|
92 |
+
source_image_b64 = config.get("source_image")
|
93 |
+
sample_rate = config.get("sample_rate", 16000)
|
94 |
+
chunk_duration = config.get("chunk_duration", 0.2)
|
95 |
+
|
96 |
+
if not source_image_b64:
|
97 |
+
await websocket.send_json({"type": "error", "message": "source_image is required"})
|
98 |
+
return
|
99 |
+
|
100 |
+
# 画像をデコードして一時ファイルに保存
|
101 |
+
image_data = base64.b64decode(source_image_b64)
|
102 |
+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
|
103 |
+
tmp_img.write(image_data)
|
104 |
+
source_path = tmp_img.name
|
105 |
+
|
106 |
+
# 出力ファイルの準備
|
107 |
+
output_path = tempfile.mktemp(suffix=".mp4")
|
108 |
+
|
109 |
+
# SDK設定
|
110 |
+
sdk_instance = init_sdk()
|
111 |
+
sdk_instance.setup(source_path, output_path, online_mode=True, max_size=1024)
|
112 |
+
|
113 |
+
await websocket.send_json({
|
114 |
+
"type": "ready",
|
115 |
+
"message": "Ready to receive audio chunks",
|
116 |
+
"chunk_size": int(sample_rate * chunk_duration)
|
117 |
+
})
|
118 |
+
|
119 |
+
# フレーム送信タスク
|
120 |
+
async def send_frames():
|
121 |
+
frame_count = 0
|
122 |
+
last_frame_time = time.time()
|
123 |
+
|
124 |
+
while True:
|
125 |
+
try:
|
126 |
+
current_time = time.time()
|
127 |
+
|
128 |
+
if sdk_instance.writer_queue.qsize() > 0:
|
129 |
+
frame = sdk_instance.writer_queue.get_nowait()
|
130 |
+
if frame is not None:
|
131 |
+
# フレームをJPEGエンコード(品質調整可能)
|
132 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
|
133 |
+
_, jpeg = cv2.imencode('.jpg',
|
134 |
+
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
|
135 |
+
encode_param)
|
136 |
+
frame_b64 = base64.b64encode(jpeg).decode('utf-8')
|
137 |
+
|
138 |
+
# FPS計算
|
139 |
+
fps = 1.0 / (current_time - last_frame_time) if current_time > last_frame_time else 0
|
140 |
+
last_frame_time = current_time
|
141 |
+
|
142 |
+
await websocket.send_json({
|
143 |
+
"type": "frame",
|
144 |
+
"frame_id": frame_count,
|
145 |
+
"timestamp": current_time,
|
146 |
+
"fps": round(fps, 2),
|
147 |
+
"data": frame_b64
|
148 |
+
})
|
149 |
+
frame_count += 1
|
150 |
+
except asyncio.CancelledError:
|
151 |
+
break
|
152 |
+
except Exception as e:
|
153 |
+
logger.error(f"Error sending frame: {e}")
|
154 |
+
|
155 |
+
await asyncio.sleep(0.01) # 10ms間隔でチェック
|
156 |
+
|
157 |
+
# フレーム送信タスクを開始
|
158 |
+
frame_task = asyncio.create_task(send_frames())
|
159 |
+
|
160 |
+
# 音声チャンクを受信して処理
|
161 |
+
total_samples = 0
|
162 |
+
chunk_size = int(sample_rate * chunk_duration)
|
163 |
+
processing_start = time.time()
|
164 |
+
|
165 |
+
while True:
|
166 |
+
message = await websocket.receive()
|
167 |
+
|
168 |
+
if "bytes" in message:
|
169 |
+
# 音声データを受信
|
170 |
+
audio_bytes = message["bytes"]
|
171 |
+
audio_chunk = np.frombuffer(audio_bytes, dtype=np.float32)
|
172 |
+
|
173 |
+
# パディング
|
174 |
+
if len(audio_chunk) < chunk_size:
|
175 |
+
audio_chunk = np.pad(audio_chunk, (0, chunk_size - len(audio_chunk)))
|
176 |
+
|
177 |
+
# SDKに送信
|
178 |
+
sdk_instance.run_chunk(audio_chunk[:chunk_size])
|
179 |
+
total_samples += len(audio_chunk)
|
180 |
+
|
181 |
+
# 進捗情報を送信
|
182 |
+
elapsed = time.time() - processing_start
|
183 |
+
await websocket.send_json({
|
184 |
+
"type": "progress",
|
185 |
+
"samples_processed": total_samples,
|
186 |
+
"duration_seconds": total_samples / sample_rate,
|
187 |
+
"elapsed_seconds": elapsed
|
188 |
+
})
|
189 |
+
|
190 |
+
elif "text" in message:
|
191 |
+
# コマンドを受信
|
192 |
+
command = json.loads(message["text"])
|
193 |
+
if command.get("action") == "stop":
|
194 |
+
logger.info("Received stop command")
|
195 |
+
break
|
196 |
+
|
197 |
+
# 処理終了
|
198 |
+
frame_task.cancel()
|
199 |
+
try:
|
200 |
+
await frame_task
|
201 |
+
except asyncio.CancelledError:
|
202 |
+
pass
|
203 |
+
|
204 |
+
# フレーム数を推定してsetup_Nd
|
205 |
+
estimated_frames = int(total_samples / sample_rate * 20)
|
206 |
+
sdk_instance.setup_Nd(estimated_frames)
|
207 |
+
|
208 |
+
# 残りのフレームを処理
|
209 |
+
await websocket.send_json({"type": "processing", "message": "Finalizing video..."})
|
210 |
+
|
211 |
+
# SDKを閉じて最終MP4を生成
|
212 |
+
sdk_instance.close()
|
213 |
+
|
214 |
+
# 最終的なMP4を送信
|
215 |
+
if Path(output_path).exists():
|
216 |
+
with open(output_path, "rb") as f:
|
217 |
+
mp4_data = f.read()
|
218 |
+
mp4_b64 = base64.b64encode(mp4_data).decode('utf-8')
|
219 |
+
|
220 |
+
await websocket.send_json({
|
221 |
+
"type": "final_video",
|
222 |
+
"size_bytes": len(mp4_data),
|
223 |
+
"duration_seconds": total_samples / sample_rate,
|
224 |
+
"data": mp4_b64
|
225 |
+
})
|
226 |
+
else:
|
227 |
+
await websocket.send_json({
|
228 |
+
"type": "error",
|
229 |
+
"message": "Failed to generate final video"
|
230 |
+
})
|
231 |
+
|
232 |
+
except WebSocketDisconnect:
|
233 |
+
logger.info("Client disconnected")
|
234 |
+
except Exception as e:
|
235 |
+
logger.error(f"WebSocket error: {e}")
|
236 |
+
logger.error(traceback.format_exc())
|
237 |
+
try:
|
238 |
+
await websocket.send_json({
|
239 |
+
"type": "error",
|
240 |
+
"message": str(e)
|
241 |
+
})
|
242 |
+
except:
|
243 |
+
pass
|
244 |
+
finally:
|
245 |
+
state.active_connections -= 1
|
246 |
+
logger.info(f"Connection closed. Active: {state.active_connections}")
|
247 |
+
|
248 |
+
# クリーンアップ
|
249 |
+
if output_path and Path(output_path).exists():
|
250 |
+
try:
|
251 |
+
Path(output_path).unlink()
|
252 |
+
except:
|
253 |
+
pass
|
254 |
+
|
255 |
+
@app.post("/sse/generate")
|
256 |
+
async def sse_generate(
|
257 |
+
source_image: UploadFile = File(...),
|
258 |
+
sample_rate: int = 16000,
|
259 |
+
max_duration: float = 10.0
|
260 |
+
):
|
261 |
+
"""SSEエンドポイント - Server-Sent Eventsによるストリーミング"""
|
262 |
+
|
263 |
+
if state.active_connections >= state.max_connections:
|
264 |
+
raise HTTPException(status_code=503, detail="Server busy")
|
265 |
+
|
266 |
+
state.active_connections += 1
|
267 |
+
|
268 |
+
async def generate() -> AsyncGenerator[str, None]:
|
269 |
+
sdk_instance = None
|
270 |
+
output_path = None
|
271 |
+
|
272 |
+
try:
|
273 |
+
# 画像を保存
|
274 |
+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
|
275 |
+
content = await source_image.read()
|
276 |
+
tmp_img.write(content)
|
277 |
+
source_path = tmp_img.name
|
278 |
+
|
279 |
+
output_path = tempfile.mktemp(suffix=".mp4")
|
280 |
+
|
281 |
+
# SDK設定
|
282 |
+
sdk_instance = init_sdk()
|
283 |
+
sdk_instance.setup(source_path, output_path, online_mode=True, max_size=1024)
|
284 |
+
|
285 |
+
# イベント送信
|
286 |
+
yield f"data: {json.dumps({'type': 'start', 'message': 'Processing started'})}\n\n"
|
287 |
+
|
288 |
+
# デモ用:ダミー音声を生成して処理
|
289 |
+
chunk_duration = 0.2
|
290 |
+
chunk_size = int(sample_rate * chunk_duration)
|
291 |
+
num_chunks = int(max_duration / chunk_duration)
|
292 |
+
|
293 |
+
for i in range(num_chunks):
|
294 |
+
# ダミー音声チャンク(実際の実装では音声ストリームから取得)
|
295 |
+
audio_chunk = np.random.randn(chunk_size).astype(np.float32) * 0.1
|
296 |
+
sdk_instance.run_chunk(audio_chunk)
|
297 |
+
|
298 |
+
# フレームチェック
|
299 |
+
if sdk_instance.writer_queue.qsize() > 0:
|
300 |
+
try:
|
301 |
+
frame = sdk_instance.writer_queue.get_nowait()
|
302 |
+
if frame is not None:
|
303 |
+
# サムネイル生成(低解像度)
|
304 |
+
thumbnail = cv2.resize(frame, (160, 160))
|
305 |
+
_, jpeg = cv2.imencode('.jpg', cv2.cvtColor(thumbnail, cv2.COLOR_RGB2BGR))
|
306 |
+
frame_b64 = base64.b64encode(jpeg).decode('utf-8')
|
307 |
+
|
308 |
+
yield f"data: {json.dumps({'type': 'thumbnail', 'frame_id': i, 'data': frame_b64})}\n\n"
|
309 |
+
except:
|
310 |
+
pass
|
311 |
+
|
312 |
+
await asyncio.sleep(chunk_duration)
|
313 |
+
|
314 |
+
# 完了
|
315 |
+
estimated_frames = num_chunks * 5 # 概算
|
316 |
+
sdk_instance.setup_Nd(estimated_frames)
|
317 |
+
sdk_instance.close()
|
318 |
+
|
319 |
+
yield f"data: {json.dumps({'type': 'complete', 'frames': estimated_frames})}\n\n"
|
320 |
+
|
321 |
+
except Exception as e:
|
322 |
+
logger.error(f"SSE error: {e}")
|
323 |
+
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
|
324 |
+
finally:
|
325 |
+
state.active_connections -= 1
|
326 |
+
if output_path and Path(output_path).exists():
|
327 |
+
try:
|
328 |
+
Path(output_path).unlink()
|
329 |
+
except:
|
330 |
+
pass
|
331 |
+
|
332 |
+
return StreamingResponse(
|
333 |
+
generate(),
|
334 |
+
media_type="text/event-stream",
|
335 |
+
headers={
|
336 |
+
"Cache-Control": "no-cache",
|
337 |
+
"Connection": "keep-alive",
|
338 |
+
}
|
339 |
+
)
|
340 |
+
|
341 |
+
@app.get("/test")
|
342 |
+
async def test_page():
|
343 |
+
"""テスト用HTMLページ"""
|
344 |
+
html_content = """
|
345 |
+
<!DOCTYPE html>
|
346 |
+
<html>
|
347 |
+
<head>
|
348 |
+
<title>DittoTalkingHead Streaming Test</title>
|
349 |
+
<style>
|
350 |
+
body { font-family: Arial, sans-serif; margin: 20px; }
|
351 |
+
.container { max-width: 800px; margin: 0 auto; }
|
352 |
+
#live-frame { max-width: 100%; border: 1px solid #ccc; }
|
353 |
+
#status { margin: 10px 0; padding: 10px; background: #f0f0f0; }
|
354 |
+
.controls { margin: 20px 0; }
|
355 |
+
button { padding: 10px 20px; margin: 5px; }
|
356 |
+
</style>
|
357 |
+
</head>
|
358 |
+
<body>
|
359 |
+
<div class="container">
|
360 |
+
<h1>DittoTalkingHead Streaming Test</h1>
|
361 |
+
|
362 |
+
<div class="controls">
|
363 |
+
<input type="file" id="source-image" accept="image/*">
|
364 |
+
<button id="start-btn">Start Streaming</button>
|
365 |
+
<button id="stop-btn" disabled>Stop</button>
|
366 |
+
</div>
|
367 |
+
|
368 |
+
<div id="status">Ready</div>
|
369 |
+
|
370 |
+
<img id="live-frame" style="display: none;">
|
371 |
+
<video id="final-video" controls style="display: none; width: 100%;"></video>
|
372 |
+
</div>
|
373 |
+
|
374 |
+
<script>
|
375 |
+
// WebSocket実装はstreaming_api_guide.mdを参照
|
376 |
+
console.log('WebSocket endpoint: ws://localhost:8000/ws/generate');
|
377 |
+
</script>
|
378 |
+
</body>
|
379 |
+
</html>
|
380 |
+
"""
|
381 |
+
from fastapi.responses import HTMLResponse
|
382 |
+
return HTMLResponse(content=html_content)
|
383 |
+
|
384 |
+
if __name__ == "__main__":
|
385 |
+
import uvicorn
|
386 |
+
import torch
|
387 |
+
|
388 |
+
# GPU設定
|
389 |
+
if torch.cuda.is_available():
|
390 |
+
torch.cuda.empty_cache()
|
391 |
+
torch.backends.cudnn.benchmark = True
|
392 |
+
|
393 |
+
logger.info("Starting DittoTalkingHead Streaming API Server...")
|
394 |
+
logger.info(f"GPU available: {torch.cuda.is_available()}")
|
395 |
+
|
396 |
+
uvicorn.run(
|
397 |
+
app,
|
398 |
+
host="0.0.0.0",
|
399 |
+
port=8000,
|
400 |
+
log_level="info",
|
401 |
+
access_log=True
|
402 |
+
)
|
app_streaming.py
CHANGED
@@ -46,7 +46,7 @@ def generator(mic, src_img):
|
|
46 |
# setup: online_mode=True でストリーミング
|
47 |
tmp_out = tempfile.mktemp(suffix=".mp4")
|
48 |
sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
|
49 |
-
N_total = int(np.ceil(len(wav_full) / sr *
|
50 |
sdk.setup_Nd(N_total)
|
51 |
|
52 |
# 処理開始時刻
|
|
|
46 |
# setup: online_mode=True でストリーミング
|
47 |
tmp_out = tempfile.mktemp(suffix=".mp4")
|
48 |
sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
|
49 |
+
N_total = int(np.ceil(len(wav_full) / sr * 20)) # 概算フレーム数
|
50 |
sdk.setup_Nd(N_total)
|
51 |
|
52 |
# 処理開始時刻
|
core/atomic_components/writer.py
CHANGED
@@ -3,7 +3,7 @@ import os
|
|
3 |
|
4 |
|
5 |
class VideoWriterByImageIO:
|
6 |
-
def __init__(self, video_path, fps=
|
7 |
video_format = kwargs.get("format", "mp4") # default is mp4 format
|
8 |
codec = kwargs.get("vcodec", "libx264") # default is libx264 encoding
|
9 |
quality = kwargs.get("quality") # video quality
|
|
|
3 |
|
4 |
|
5 |
class VideoWriterByImageIO:
|
6 |
+
def __init__(self, video_path, fps=20, **kwargs):
|
7 |
video_format = kwargs.get("format", "mp4") # default is mp4 format
|
8 |
codec = kwargs.get("vcodec", "libx264") # default is libx264 encoding
|
9 |
quality = kwargs.get("quality") # video quality
|
core/models/modules/lmdm_modules/model.py
CHANGED
@@ -237,7 +237,7 @@ class MotionDecoder(nn.Module):
|
|
237 |
def __init__(
|
238 |
self,
|
239 |
nfeats: int,
|
240 |
-
seq_len: int =
|
241 |
latent_dim: int = 256,
|
242 |
ff_size: int = 1024,
|
243 |
num_layers: int = 4,
|
|
|
237 |
def __init__(
|
238 |
self,
|
239 |
nfeats: int,
|
240 |
+
seq_len: int = 80, # 4 seconds, 20 fps
|
241 |
latent_dim: int = 256,
|
242 |
ff_size: int = 1024,
|
243 |
num_layers: int = 4,
|
streaming_client.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
DittoTalkingHead Streaming Client
|
3 |
+
WebSocketを使用したストリーミングクライアントの実装例
|
4 |
+
"""
|
5 |
+
import asyncio
|
6 |
+
import websockets
|
7 |
+
import numpy as np
|
8 |
+
import soundfile as sf
|
9 |
+
import base64
|
10 |
+
import json
|
11 |
+
import cv2
|
12 |
+
from typing import Optional, Callable
|
13 |
+
import pyaudio
|
14 |
+
import threading
|
15 |
+
import queue
|
16 |
+
from pathlib import Path
|
17 |
+
import logging
|
18 |
+
|
19 |
+
logging.basicConfig(level=logging.INFO)
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
class DittoStreamingClient:
|
23 |
+
"""DittoTalkingHeadストリーミングクライアント"""
|
24 |
+
|
25 |
+
def __init__(self, server_url="ws://localhost:8000/ws/generate"):
|
26 |
+
self.server_url = server_url
|
27 |
+
self.sample_rate = 16000
|
28 |
+
self.chunk_duration = 0.2 # 200ms
|
29 |
+
self.chunk_size = int(self.sample_rate * self.chunk_duration)
|
30 |
+
self.websocket = None
|
31 |
+
self.is_connected = False
|
32 |
+
self.frame_callback: Optional[Callable] = None
|
33 |
+
self.final_video_callback: Optional[Callable] = None
|
34 |
+
|
35 |
+
async def connect(self, source_image_path: str):
|
36 |
+
"""サーバーに接続してセッションを開始"""
|
37 |
+
try:
|
38 |
+
# 画像をBase64エンコード
|
39 |
+
with open(source_image_path, "rb") as f:
|
40 |
+
image_b64 = base64.b64encode(f.read()).decode('utf-8')
|
41 |
+
|
42 |
+
# WebSocket接続
|
43 |
+
self.websocket = await websockets.connect(self.server_url)
|
44 |
+
self.is_connected = True
|
45 |
+
|
46 |
+
# 初期設定を送信
|
47 |
+
await self.websocket.send(json.dumps({
|
48 |
+
"source_image": image_b64,
|
49 |
+
"sample_rate": self.sample_rate,
|
50 |
+
"chunk_duration": self.chunk_duration
|
51 |
+
}))
|
52 |
+
|
53 |
+
# 応答を待つ
|
54 |
+
response = await self.websocket.recv()
|
55 |
+
data = json.loads(response)
|
56 |
+
|
57 |
+
if data["type"] == "ready":
|
58 |
+
logger.info(f"Connected to server: {data['message']}")
|
59 |
+
return True
|
60 |
+
else:
|
61 |
+
logger.error(f"Connection failed: {data}")
|
62 |
+
return False
|
63 |
+
|
64 |
+
except Exception as e:
|
65 |
+
logger.error(f"Connection error: {e}")
|
66 |
+
self.is_connected = False
|
67 |
+
raise
|
68 |
+
|
69 |
+
async def disconnect(self):
|
70 |
+
"""接続を切断"""
|
71 |
+
if self.websocket:
|
72 |
+
await self.websocket.close()
|
73 |
+
self.is_connected = False
|
74 |
+
logger.info("Disconnected from server")
|
75 |
+
|
76 |
+
async def stream_audio_file(self, audio_path: str, source_image_path: str):
|
77 |
+
"""音声ファイルをストリーミング"""
|
78 |
+
try:
|
79 |
+
# 接続
|
80 |
+
await self.connect(source_image_path)
|
81 |
+
|
82 |
+
# 音声を読み込み
|
83 |
+
audio_data, sr = sf.read(audio_path)
|
84 |
+
if sr != self.sample_rate:
|
85 |
+
import librosa
|
86 |
+
audio_data = librosa.resample(
|
87 |
+
audio_data,
|
88 |
+
orig_sr=sr,
|
89 |
+
target_sr=self.sample_rate
|
90 |
+
)
|
91 |
+
|
92 |
+
# フレーム受信タスク
|
93 |
+
receive_task = asyncio.create_task(self._receive_frames())
|
94 |
+
|
95 |
+
# 音声をチャンク単位で送信
|
96 |
+
total_chunks = 0
|
97 |
+
for i in range(0, len(audio_data), self.chunk_size):
|
98 |
+
chunk = audio_data[i:i+self.chunk_size]
|
99 |
+
if len(chunk) < self.chunk_size:
|
100 |
+
chunk = np.pad(chunk, (0, self.chunk_size - len(chunk)))
|
101 |
+
|
102 |
+
# Float32として送信
|
103 |
+
await self.websocket.send(chunk.astype(np.float32).tobytes())
|
104 |
+
total_chunks += 1
|
105 |
+
|
106 |
+
# リアルタイムシミュレーション
|
107 |
+
await asyncio.sleep(self.chunk_duration)
|
108 |
+
|
109 |
+
# 進捗表示
|
110 |
+
progress = (i + self.chunk_size) / len(audio_data) * 100
|
111 |
+
logger.info(f"Streaming progress: {progress:.1f}%")
|
112 |
+
|
113 |
+
# 停止コマンドを送信
|
114 |
+
await self.websocket.send(json.dumps({"action": "stop"}))
|
115 |
+
logger.info(f"Sent {total_chunks} audio chunks")
|
116 |
+
|
117 |
+
# フレーム受信を待つ
|
118 |
+
await receive_task
|
119 |
+
|
120 |
+
finally:
|
121 |
+
await self.disconnect()
|
122 |
+
|
123 |
+
async def stream_microphone(self, source_image_path: str, duration: Optional[float] = None):
|
124 |
+
"""マイクからリアルタイムストリーミング"""
|
125 |
+
try:
|
126 |
+
# 接続
|
127 |
+
await self.connect(source_image_path)
|
128 |
+
|
129 |
+
# フレーム受信タスク
|
130 |
+
receive_task = asyncio.create_task(self._receive_frames())
|
131 |
+
|
132 |
+
# マイク録音用のキュー
|
133 |
+
audio_queue = queue.Queue()
|
134 |
+
stop_event = threading.Event()
|
135 |
+
|
136 |
+
# マイク録音スレッド
|
137 |
+
def record_audio():
|
138 |
+
p = pyaudio.PyAudio()
|
139 |
+
stream = p.open(
|
140 |
+
format=pyaudio.paFloat32,
|
141 |
+
channels=1,
|
142 |
+
rate=self.sample_rate,
|
143 |
+
input=True,
|
144 |
+
frames_per_buffer=self.chunk_size
|
145 |
+
)
|
146 |
+
|
147 |
+
logger.info("Recording started... Press Ctrl+C to stop")
|
148 |
+
|
149 |
+
try:
|
150 |
+
start_time = asyncio.get_event_loop().time()
|
151 |
+
while not stop_event.is_set():
|
152 |
+
if duration and (asyncio.get_event_loop().time() - start_time) > duration:
|
153 |
+
break
|
154 |
+
|
155 |
+
audio_chunk = stream.read(self.chunk_size, exception_on_overflow=False)
|
156 |
+
audio_queue.put(audio_chunk)
|
157 |
+
|
158 |
+
except Exception as e:
|
159 |
+
logger.error(f"Recording error: {e}")
|
160 |
+
finally:
|
161 |
+
stream.stop_stream()
|
162 |
+
stream.close()
|
163 |
+
p.terminate()
|
164 |
+
logger.info("Recording stopped")
|
165 |
+
|
166 |
+
# 録音スレッドを開始
|
167 |
+
record_thread = threading.Thread(target=record_audio)
|
168 |
+
record_thread.start()
|
169 |
+
|
170 |
+
try:
|
171 |
+
# 音声データを送信
|
172 |
+
while record_thread.is_alive() or not audio_queue.empty():
|
173 |
+
try:
|
174 |
+
audio_chunk = audio_queue.get(timeout=0.1)
|
175 |
+
audio_array = np.frombuffer(audio_chunk, dtype=np.float32)
|
176 |
+
await self.websocket.send(audio_array.tobytes())
|
177 |
+
except queue.Empty:
|
178 |
+
continue
|
179 |
+
except KeyboardInterrupt:
|
180 |
+
logger.info("Stopping recording...")
|
181 |
+
break
|
182 |
+
|
183 |
+
finally:
|
184 |
+
stop_event.set()
|
185 |
+
record_thread.join()
|
186 |
+
|
187 |
+
# 停止コマンドを送信
|
188 |
+
await self.websocket.send(json.dumps({"action": "stop"}))
|
189 |
+
|
190 |
+
# フレーム受信を待つ
|
191 |
+
await receive_task
|
192 |
+
|
193 |
+
finally:
|
194 |
+
await self.disconnect()
|
195 |
+
|
196 |
+
async def _receive_frames(self):
|
197 |
+
"""フレームとメッセージを受信"""
|
198 |
+
frame_count = 0
|
199 |
+
|
200 |
+
try:
|
201 |
+
while True:
|
202 |
+
message = await self.websocket.recv()
|
203 |
+
data = json.loads(message)
|
204 |
+
|
205 |
+
if data["type"] == "frame":
|
206 |
+
frame_count += 1
|
207 |
+
logger.info(f"Received frame {data['frame_id']} (FPS: {data.get('fps', 0)})")
|
208 |
+
|
209 |
+
if self.frame_callback:
|
210 |
+
# フレームをデコード
|
211 |
+
frame_data = base64.b64decode(data["data"])
|
212 |
+
nparr = np.frombuffer(frame_data, np.uint8)
|
213 |
+
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
214 |
+
self.frame_callback(frame, data)
|
215 |
+
|
216 |
+
elif data["type"] == "progress":
|
217 |
+
logger.info(f"Progress: {data['duration_seconds']:.1f}s processed")
|
218 |
+
|
219 |
+
elif data["type"] == "processing":
|
220 |
+
logger.info(f"Server: {data['message']}")
|
221 |
+
|
222 |
+
elif data["type"] == "final_video":
|
223 |
+
logger.info(f"Received final video ({data['size_bytes']} bytes, {data['duration_seconds']:.1f}s)")
|
224 |
+
|
225 |
+
if self.final_video_callback:
|
226 |
+
video_data = base64.b64decode(data["data"])
|
227 |
+
self.final_video_callback(video_data, data)
|
228 |
+
break
|
229 |
+
|
230 |
+
elif data["type"] == "error":
|
231 |
+
logger.error(f"Server error: {data['message']}")
|
232 |
+
break
|
233 |
+
|
234 |
+
except websockets.exceptions.ConnectionClosed:
|
235 |
+
logger.info("Connection closed by server")
|
236 |
+
except Exception as e:
|
237 |
+
logger.error(f"Receive error: {e}")
|
238 |
+
|
239 |
+
logger.info(f"Total frames received: {frame_count}")
|
240 |
+
|
241 |
+
def set_frame_callback(self, callback: Callable):
|
242 |
+
"""フレーム受信時のコールバックを設定"""
|
243 |
+
self.frame_callback = callback
|
244 |
+
|
245 |
+
def set_final_video_callback(self, callback: Callable):
|
246 |
+
"""最終動画受信時のコールバックを設定"""
|
247 |
+
self.final_video_callback = callback
|
248 |
+
|
249 |
+
|
250 |
+
# 使用例とテスト
|
251 |
+
async def main():
|
252 |
+
"""使用例"""
|
253 |
+
client = DittoStreamingClient()
|
254 |
+
|
255 |
+
# フレーム表示用のコールバック
|
256 |
+
def display_frame(frame, metadata):
|
257 |
+
cv2.imshow("Live Frame", frame)
|
258 |
+
cv2.waitKey(1)
|
259 |
+
|
260 |
+
# 最終動画保存用のコールバック
|
261 |
+
def save_video(video_data, metadata):
|
262 |
+
output_path = "output_streaming.mp4"
|
263 |
+
with open(output_path, "wb") as f:
|
264 |
+
f.write(video_data)
|
265 |
+
logger.info(f"Video saved to {output_path}")
|
266 |
+
|
267 |
+
client.set_frame_callback(display_frame)
|
268 |
+
client.set_final_video_callback(save_video)
|
269 |
+
|
270 |
+
# テスト画像とサンプル音声のパス
|
271 |
+
source_image = "example/reference.png"
|
272 |
+
audio_file = "example/audio.wav"
|
273 |
+
|
274 |
+
# ファイルが存在するか確認
|
275 |
+
if not Path(source_image).exists():
|
276 |
+
logger.error(f"Source image not found: {source_image}")
|
277 |
+
return
|
278 |
+
|
279 |
+
# 音声ファイルからストリーミング
|
280 |
+
if Path(audio_file).exists():
|
281 |
+
logger.info("=== Testing audio file streaming ===")
|
282 |
+
await client.stream_audio_file(audio_file, source_image)
|
283 |
+
else:
|
284 |
+
logger.warning(f"Audio file not found: {audio_file}")
|
285 |
+
|
286 |
+
# マイクからストリーミング(5秒間)
|
287 |
+
# logger.info("\n=== Testing microphone streaming (5 seconds) ===")
|
288 |
+
# await client.stream_microphone(source_image, duration=5.0)
|
289 |
+
|
290 |
+
cv2.destroyAllWindows()
|
291 |
+
|
292 |
+
|
293 |
+
# バッチ処理クライアント
|
294 |
+
class BatchStreamingClient:
|
295 |
+
"""複数のリクエストを並列処理するクライアント"""
|
296 |
+
|
297 |
+
def __init__(self, server_url="ws://localhost:8000/ws/generate", max_parallel=3):
|
298 |
+
self.server_url = server_url
|
299 |
+
self.max_parallel = max_parallel
|
300 |
+
|
301 |
+
async def process_batch(self, tasks: list):
|
302 |
+
"""バッチ処理"""
|
303 |
+
semaphore = asyncio.Semaphore(self.max_parallel)
|
304 |
+
|
305 |
+
async def process_with_limit(task):
|
306 |
+
async with semaphore:
|
307 |
+
client = DittoStreamingClient(self.server_url)
|
308 |
+
await client.stream_audio_file(
|
309 |
+
task["audio_path"],
|
310 |
+
task["image_path"]
|
311 |
+
)
|
312 |
+
return task["id"]
|
313 |
+
|
314 |
+
results = await asyncio.gather(
|
315 |
+
*[process_with_limit(task) for task in tasks],
|
316 |
+
return_exceptions=True
|
317 |
+
)
|
318 |
+
|
319 |
+
return results
|
320 |
+
|
321 |
+
|
322 |
+
if __name__ == "__main__":
|
323 |
+
# 単一クライアントのテスト
|
324 |
+
asyncio.run(main())
|
325 |
+
|
326 |
+
# バッチ処理の例
|
327 |
+
# batch_client = BatchStreamingClient()
|
328 |
+
# tasks = [
|
329 |
+
# {"id": 1, "audio_path": "audio1.wav", "image_path": "image1.png"},
|
330 |
+
# {"id": 2, "audio_path": "audio2.wav", "image_path": "image2.png"},
|
331 |
+
# ]
|
332 |
+
# asyncio.run(batch_client.process_batch(tasks))
|
test_streaming.py
CHANGED
@@ -34,7 +34,7 @@ def test_streaming():
|
|
34 |
tmp_out = tempfile.mktemp(suffix=".mp4")
|
35 |
|
36 |
sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
|
37 |
-
N_total = int(np.ceil(duration *
|
38 |
sdk.setup_Nd(N_total)
|
39 |
print("✅ セットアップ完了")
|
40 |
|
@@ -98,7 +98,7 @@ def test_streaming():
|
|
98 |
print(f"✅ 出力ファイル: {tmp_out}")
|
99 |
|
100 |
# 期待される結果の確認
|
101 |
-
expected_frames = int(duration *
|
102 |
if frames_received >= expected_frames * 0.8: # 80%以上
|
103 |
print("✅ テスト成功!")
|
104 |
else:
|
|
|
34 |
tmp_out = tempfile.mktemp(suffix=".mp4")
|
35 |
|
36 |
sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
|
37 |
+
N_total = int(np.ceil(duration * 20)) # 20fps
|
38 |
sdk.setup_Nd(N_total)
|
39 |
print("✅ セットアップ完了")
|
40 |
|
|
|
98 |
print(f"✅ 出力ファイル: {tmp_out}")
|
99 |
|
100 |
# 期待される結果の確認
|
101 |
+
expected_frames = int(duration * 20) # 20fps
|
102 |
if frames_received >= expected_frames * 0.8: # 80%以上
|
103 |
print("✅ テスト成功!")
|
104 |
else:
|