Spaces:
Runtime error
Runtime error
""" | |
DittoTalkingHead Streaming API Server | |
WebSocket/SSEによるリアルタイムストリーミング実装 | |
""" | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, UploadFile, HTTPException | |
from fastapi.responses import StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
import asyncio | |
import tempfile | |
import numpy as np | |
import base64 | |
import json | |
from typing import AsyncGenerator, Optional | |
import cv2 | |
import time | |
import logging | |
from pathlib import Path | |
import traceback | |
from stream_pipeline_offline import StreamSDK | |
# --- Logging -----------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Application ---------------------------------------------------------------
app = FastAPI(title="DittoTalkingHead Streaming API")

# CORS: accept requests from any origin, with any method and header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True lets any
# site issue credentialed requests; nothing visible in this module uses cookies
# or auth, so credentials can probably be disabled -- confirm before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# --- StreamSDK assets (consumed by StreamSDK(CFG_PKL, DATA_ROOT)) --------------
DATA_ROOT = "checkpoints/ditto_pytorch"
CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
# Global, process-wide server state.
class AppState:
    """Mutable state shared by every request handler in this process.

    Attributes:
        sdk: The StreamSDK instance created lazily by ``init_sdk()``;
            ``None`` until first use.
        active_connections: Number of streaming sessions currently counted
            against the capacity limit.
        max_connections: Capacity limit checked by the streaming endpoints.
    """

    def __init__(self, max_connections: int = 5):
        """Create empty state.

        Args:
            max_connections: Maximum number of concurrent streaming sessions.
                Defaults to 5, the previously hard-coded value. Must be >= 1.

        Raises:
            ValueError: If ``max_connections`` is less than 1.
        """
        if max_connections < 1:
            raise ValueError(f"max_connections must be >= 1, got {max_connections}")
        self.sdk: Optional[StreamSDK] = None
        self.active_connections: int = 0
        # NOTE(review): all sessions share the single ``sdk`` above, so a limit
        # above 1 lets sessions interleave setup()/run_chunk() calls on the same
        # StreamSDK object -- confirm the SDK tolerates that.
        self.max_connections: int = max_connections


state = AppState()
def init_sdk():
    """Return the shared StreamSDK, constructing it on the first call.

    The instance is cached on ``state.sdk``; later calls return that same
    object without reloading anything.
    """
    if state.sdk is not None:
        return state.sdk

    logger.info("Initializing StreamSDK...")
    state.sdk = StreamSDK(CFG_PKL, DATA_ROOT)
    logger.info("StreamSDK initialized successfully")
    return state.sdk
@app.on_event("startup")
async def startup_event():
    """Create the shared StreamSDK when the application starts.

    Without this registration the SDK was only built lazily by the first
    streaming request, despite the intent to initialize it at startup.
    """
    init_sdk()
@app.get("/")
async def root():
    """Health check: report service status and current connection usage."""
    return {
        "status": "ok",
        "service": "DittoTalkingHead Streaming API",
        "active_connections": state.active_connections,
        "max_connections": state.max_connections,
    }
@app.websocket("/ws/generate")
async def websocket_endpoint(websocket: WebSocket):
    """Run one real-time talking-head session over a WebSocket.

    Protocol implemented below:

    1. The client sends a JSON config with ``source_image`` (base64-encoded
       image, required), ``sample_rate`` (default 16000) and
       ``chunk_duration`` in seconds (default 0.2).
    2. The server answers ``{"type": "ready", ...}``, including
       ``chunk_size`` (samples per SDK call).
    3. The client streams binary messages of raw float32 samples. Each message
       is fed to the SDK in ``chunk_size`` pieces, with the last piece
       zero-padded, and is acknowledged with a ``progress`` message. In
       parallel, frames polled from ``writer_queue`` are sent as base64 JPEG
       ``frame`` messages.
    4. The text message ``{"action": "stop"}`` ends input. The server then
       finalizes the SDK and sends the MP4 as a base64 ``final_video`` message.

    Failures are reported as ``{"type": "error", "message": ...}``.

    NOTE(review): every session uses the shared SDK returned by
    ``init_sdk()``, so concurrent sessions interleave setup()/run_chunk() on
    one object. run_chunk()/close() are synchronous and block the event loop
    while they run. Confirm both are acceptable.
    """
    # Reserve the slot before the first ``await`` so concurrent handshakes
    # cannot all pass the capacity check; ``finally`` releases it.
    if state.active_connections >= state.max_connections:
        await websocket.close(code=1008, reason="Server busy")
        return
    state.active_connections += 1

    source_path = None
    output_path = None
    frame_task = None
    try:
        await websocket.accept()
        logger.info(f"New WebSocket connection. Active: {state.active_connections}")

        # Initial configuration message.
        config = await websocket.receive_json()
        source_image_b64 = config.get("source_image")
        sample_rate = config.get("sample_rate", 16000)
        chunk_duration = config.get("chunk_duration", 0.2)

        if not source_image_b64:
            await websocket.send_json({"type": "error", "message": "source_image is required"})
            return

        chunk_size = int(sample_rate * chunk_duration)
        if sample_rate <= 0 or chunk_size <= 0:
            # Guards the divisions by sample_rate and the chunk slicing below.
            await websocket.send_json({
                "type": "error",
                "message": "sample_rate and chunk_duration must be positive",
            })
            return

        # Decode the image into a temp file. NamedTemporaryFile replaces the
        # race-prone tempfile.mktemp; both files are removed in ``finally``.
        image_data = base64.b64decode(source_image_b64)
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
            source_path = tmp_img.name
            tmp_img.write(image_data)
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_out:
            output_path = tmp_out.name

        sdk_instance = init_sdk()
        sdk_instance.setup(source_path, output_path, online_mode=True, max_size=1024)

        await websocket.send_json({
            "type": "ready",
            "message": "Ready to receive audio chunks",
            "chunk_size": chunk_size,
        })

        async def send_frames():
            """Forward frames from ``sdk_instance.writer_queue`` as JPEG messages.

            Polls the queue every 10 ms until cancelled.
            NOTE(review): if the SDK's own writer also consumes writer_queue,
            frames taken here may be missing from the final MP4 -- confirm.
            """
            encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
            frame_count = 0
            last_frame_time = time.time()
            while True:
                try:
                    if sdk_instance.writer_queue.qsize() > 0:
                        frame = sdk_instance.writer_queue.get_nowait()
                        if frame is not None:
                            current_time = time.time()
                            # Assumes RGB frames (converted to BGR for OpenCV) -- TODO confirm.
                            ok, jpeg = cv2.imencode(
                                '.jpg', cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), encode_param
                            )
                            if ok:
                                elapsed = current_time - last_frame_time
                                fps = 1.0 / elapsed if elapsed > 0 else 0
                                last_frame_time = current_time
                                await websocket.send_json({
                                    "type": "frame",
                                    "frame_id": frame_count,
                                    "timestamp": current_time,
                                    "fps": round(fps, 2),
                                    "data": base64.b64encode(jpeg).decode('utf-8'),
                                })
                                frame_count += 1
                except asyncio.CancelledError:
                    # Explicit so cancellation is never swallowed by the
                    # handler below (CancelledError is an Exception on 3.7).
                    raise
                except Exception as e:
                    logger.error(f"Error sending frame: {e}")
                await asyncio.sleep(0.01)  # poll every 10 ms

        frame_task = asyncio.create_task(send_frames())

        # Receive audio chunks and commands.
        total_samples = 0  # samples actually fed to the SDK (padding included)
        sample_bytes = np.dtype(np.float32).itemsize
        processing_start = time.time()
        while True:
            message = await websocket.receive()
            if message.get("type") == "websocket.disconnect":
                # Calling receive() again after this would raise RuntimeError.
                raise WebSocketDisconnect(message.get("code", 1000))

            audio_bytes = message.get("bytes")
            text = message.get("text")
            if audio_bytes is not None:
                if len(audio_bytes) % sample_bytes:
                    await websocket.send_json({
                        "type": "error",
                        "message": "audio chunk must contain raw float32 samples",
                    })
                    continue
                audio = np.frombuffer(audio_bytes, dtype=np.float32)
                # Feed every received sample. Over-long messages used to be
                # truncated, silently dropping audio.
                for start in range(0, len(audio), chunk_size):
                    piece = audio[start:start + chunk_size]
                    if len(piece) < chunk_size:
                        piece = np.pad(piece, (0, chunk_size - len(piece)))
                    sdk_instance.run_chunk(piece)
                    total_samples += chunk_size

                await websocket.send_json({
                    "type": "progress",
                    "samples_processed": total_samples,
                    "duration_seconds": total_samples / sample_rate,
                    "elapsed_seconds": time.time() - processing_start,
                })
            elif text is not None:
                try:
                    command = json.loads(text)
                except json.JSONDecodeError:
                    await websocket.send_json({"type": "error", "message": "invalid JSON command"})
                    continue
                if isinstance(command, dict) and command.get("action") == "stop":
                    logger.info("Received stop command")
                    break

        # Stop streaming live frames before finalizing.
        frame_task.cancel()
        await asyncio.wait({frame_task})

        # 25 fps, consistent with the SSE endpoint's 5 frames per 0.2 s chunk.
        # NOTE(review): setup_Nd() is called only after all audio was fed --
        # confirm the SDK supports this ordering.
        output_fps = 25
        estimated_frames = int(np.ceil(total_samples / sample_rate * output_fps))
        sdk_instance.setup_Nd(estimated_frames)

        await websocket.send_json({"type": "processing", "message": "Finalizing video..."})

        # Close the SDK so it produces the final MP4.
        # NOTE(review): assumes close() writes the finished video to output_path itself.
        sdk_instance.close()

        # The temp file exists from the start, so require non-empty output.
        out = Path(output_path)
        if out.is_file() and out.stat().st_size > 0:
            mp4_data = out.read_bytes()
            await websocket.send_json({
                "type": "final_video",
                "size_bytes": len(mp4_data),
                "duration_seconds": total_samples / sample_rate,
                "data": base64.b64encode(mp4_data).decode('utf-8'),
            })
        else:
            await websocket.send_json({
                "type": "error",
                "message": "Failed to generate final video",
            })

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        logger.error(traceback.format_exc())
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            pass  # best effort: the socket is most likely already closed
    finally:
        # Always stop the frame task. It previously kept polling forever after
        # a disconnect or error.
        if frame_task is not None and not frame_task.done():
            frame_task.cancel()
            await asyncio.wait({frame_task})
        state.active_connections -= 1
        logger.info(f"Connection closed. Active: {state.active_connections}")
        # Remove both temp files; the source image used to be leaked.
        for path in (source_path, output_path):
            if path:
                try:
                    Path(path).unlink()
                except OSError:
                    pass  # already gone or not removable; cleanup is best effort
async def sse_generate(
    source_image: UploadFile = File(...),
    sample_rate: int = 16000,
    max_duration: float = 10.0,
):
    """Stream demo progress as Server-Sent Events.

    The SDK is driven with random noise in 0.2 s chunks for ``max_duration``
    seconds; this is a placeholder for real audio. Low-resolution thumbnails
    of frames polled from ``writer_queue`` are emitted as they appear. Each
    event is a JSON object whose ``type`` is ``start``, ``thumbnail``,
    ``complete`` or ``error``.

    Raises:
        HTTPException: 503 when the connection limit is reached; 400 when
            ``sample_rate`` is too small to give a non-empty chunk.

    NOTE(review): no route decorator is visible for this handler, so it is not
    registered on ``app`` unless that happens elsewhere.
    """
    if state.active_connections >= state.max_connections:
        raise HTTPException(status_code=503, detail="Server busy")

    chunk_duration = 0.2
    chunk_size = int(sample_rate * chunk_duration)
    if chunk_size <= 0:
        raise HTTPException(status_code=400, detail="sample_rate is too small")
    # The epsilon stops float error truncating e.g. 0.6 / 0.2 (= 2.999...) to
    # 2 chunks; clamp so a negative duration cannot yield a negative estimate.
    num_chunks = max(0, int(max_duration / chunk_duration + 1e-9))
    frames_per_chunk = 5  # rough estimate: 0.2 s at 25 fps

    # Read the upload now. Depending on the FastAPI version, request files
    # may already be closed when the streaming body is iterated.
    image_bytes = await source_image.read()

    async def generate() -> AsyncGenerator[str, None]:
        # Count the slot only once streaming starts. Starlette can discard a
        # generator that never started (e.g. early disconnect); its ``finally``
        # would then never run, permanently leaking the slot.
        state.active_connections += 1
        source_path = None
        output_path = None
        try:
            # Temp files are created safely (no tempfile.mktemp) and removed in ``finally``.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
                source_path = tmp_img.name
                tmp_img.write(image_bytes)
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_out:
                output_path = tmp_out.name

            sdk_instance = init_sdk()
            sdk_instance.setup(source_path, output_path, online_mode=True, max_size=1024)

            yield f"data: {json.dumps({'type': 'start', 'message': 'Processing started'})}\n\n"

            for i in range(num_chunks):
                # Demo input: low-amplitude noise stands in for a real audio stream.
                audio_chunk = np.random.randn(chunk_size).astype(np.float32) * 0.1
                sdk_instance.run_chunk(audio_chunk)

                event = None
                if sdk_instance.writer_queue.qsize() > 0:
                    try:
                        frame = sdk_instance.writer_queue.get_nowait()
                        if frame is not None:
                            thumbnail = cv2.resize(frame, (160, 160))
                            ok, jpeg = cv2.imencode('.jpg', cv2.cvtColor(thumbnail, cv2.COLOR_RGB2BGR))
                            if ok:
                                event = {
                                    'type': 'thumbnail',
                                    'frame_id': i,
                                    'data': base64.b64encode(jpeg).decode('utf-8'),
                                }
                    except Exception as e:
                        # Thumbnails are best-effort; a bad frame must not end the stream.
                        logger.warning(f"SSE thumbnail skipped: {e}")
                if event is not None:
                    yield f"data: {json.dumps(event)}\n\n"

                await asyncio.sleep(chunk_duration)

            # NOTE(review): setup_Nd() is called after all chunks were fed --
            # confirm the SDK supports this ordering.
            estimated_frames = num_chunks * frames_per_chunk
            sdk_instance.setup_Nd(estimated_frames)
            sdk_instance.close()
            yield f"data: {json.dumps({'type': 'complete', 'frames': estimated_frames})}\n\n"
        except Exception as e:
            logger.error(f"SSE error: {e}")
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
        finally:
            state.active_connections -= 1
            for path in (source_path, output_path):
                if path:
                    try:
                        Path(path).unlink()
                    except OSError:
                        pass  # already gone or not removable; cleanup is best effort

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
async def test_page():
    """Serve a static HTML page for manually trying out the streaming API.

    The page only lays out the controls and logs the WebSocket endpoint; it
    contains no client-side streaming logic.

    NOTE(review): no route decorator is visible for this handler, so it is not
    registered on ``app`` unless that happens elsewhere.
    """
    from fastapi.responses import HTMLResponse

    page = """
<!DOCTYPE html>
<html>
<head>
<title>DittoTalkingHead Streaming Test</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; }
.container { max-width: 800px; margin: 0 auto; }
#live-frame { max-width: 100%; border: 1px solid #ccc; }
#status { margin: 10px 0; padding: 10px; background: #f0f0f0; }
.controls { margin: 20px 0; }
button { padding: 10px 20px; margin: 5px; }
</style>
</head>
<body>
<div class="container">
<h1>DittoTalkingHead Streaming Test</h1>
<div class="controls">
<input type="file" id="source-image" accept="image/*">
<button id="start-btn">Start Streaming</button>
<button id="stop-btn" disabled>Stop</button>
</div>
<div id="status">Ready</div>
<img id="live-frame" style="display: none;">
<video id="final-video" controls style="display: none; width: 100%;"></video>
</div>
<script>
// WebSocket実装はstreaming_api_guide.mdを参照
console.log('WebSocket endpoint: ws://localhost:8000/ws/generate');
</script>
</body>
</html>
"""
    return HTMLResponse(page)
if __name__ == "__main__":
    import uvicorn
    import torch

    # GPU tuning: release cached allocator blocks and let cuDNN autotune kernels.
    gpu_available = torch.cuda.is_available()
    if gpu_available:
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True

    logger.info("Starting DittoTalkingHead Streaming API Server...")
    logger.info(f"GPU available: {gpu_available}")

    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info", access_log=True)