File size: 4,416 Bytes
d9a2a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2089ecf
d9a2a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2089ecf
d9a2a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
ストリーミング実装のテストスクリプト
"""
import numpy as np
import soundfile as sf
import tempfile
import time
from pathlib import Path
from stream_pipeline_offline import StreamSDK

# テスト設定
CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
DATA_ROOT = "checkpoints/ditto_pytorch"
EXAMPLES_DIR = Path("example")

def test_streaming():
    """ストリーミング機能の基本テスト"""
    print("=== ストリーミング機能テスト開始 ===")
    
    # テスト用の音声を生成(3秒のサイン波)
    duration = 3.0  # seconds
    sample_rate = 16000
    t = np.linspace(0, duration, int(sample_rate * duration))
    audio_data = np.sin(2 * np.pi * 440 * t) * 0.5  # 440Hz
    
    # SDKの初期化
    print("1. SDK初期化...")
    sdk = StreamSDK(CFG_PKL, DATA_ROOT)
    print("✅ SDK初期化完了")
    
    # セットアップ
    print("\n2. ストリーミングモードでセットアップ...")
    src_img = str(EXAMPLES_DIR / "reference.png")
    tmp_out = tempfile.mktemp(suffix=".mp4")
    
    sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024)
    N_total = int(np.ceil(duration * 20))  # 20fps
    sdk.setup_Nd(N_total)
    print("✅ セットアップ完了")
    
    # チャンク単位で音声を送信
    print("\n3. チャンク単位で音声送信...")
    chunk_sec = 0.2  # 200ms
    chunk_samples = int(sample_rate * chunk_sec)
    chunks_sent = 0
    frames_received = 0
    
    start_time = time.time()
    
    for i in range(0, len(audio_data), chunk_samples):
        chunk = audio_data[i:i + chunk_samples]
        if len(chunk) < chunk_samples:
            chunk = np.pad(chunk, (0, chunk_samples - len(chunk)))
        
        sdk.run_chunk(chunk)
        chunks_sent += 1
        
        # キューからフレームを確認
        while sdk.writer_queue.qsize() > 0:
            try:
                frame = sdk.writer_queue.get_nowait()
                if frame is not None:
                    frames_received += 1
                    print(f"  フレーム {frames_received} 受信 (チャンク {chunks_sent})")
            except:
                break
        
        time.sleep(0.05)  # 少し待機
    
    # 残りのフレームを待つ
    print("\n4. 残りのフレームを処理...")
    timeout = 5.0  # 5秒タイムアウト
    timeout_start = time.time()
    
    while time.time() - timeout_start < timeout:
        if sdk.writer_queue.qsize() > 0:
            try:
                frame = sdk.writer_queue.get_nowait()
                if frame is not None:
                    frames_received += 1
                    print(f"  フレーム {frames_received} 受信")
            except:
                pass
        else:
            time.sleep(0.1)
    
    # クローズ
    print("\n5. SDKクローズ...")
    sdk.close()
    
    elapsed = time.time() - start_time
    
    # 結果
    print("\n=== テスト結果 ===")
    print(f"✅ 送信チャンク数: {chunks_sent}")
    print(f"✅ 受信フレーム数: {frames_received}")
    print(f"✅ 処理時間: {elapsed:.2f}秒")
    print(f"✅ 出力ファイル: {tmp_out}")
    
    # 期待される結果の確認
    expected_frames = int(duration * 20)  # 20fps
    if frames_received >= expected_frames * 0.8:  # 80%以上
        print("✅ テスト成功!")
    else:
        print(f"⚠️ 期待フレーム数 ({expected_frames}) に対して受信数が少ない")
    
    return True


def test_writer_queue():
    """writer_queueの動作確認"""
    print("\n=== writer_queue 動作確認 ===")
    
    sdk = StreamSDK(CFG_PKL, DATA_ROOT)
    
    # キューの存在確認
    if hasattr(sdk, 'writer_queue'):
        print("✅ writer_queue が存在します")
        print(f"  キューサイズ: {sdk.writer_queue.qsize()}")
        print(f"  最大サイズ: {sdk.writer_queue.maxsize}")
    else:
        print("❌ writer_queue が見つかりません")
        return False
    
    return True


if __name__ == "__main__":
    # writer_queueの確認
    if not test_writer_queue():
        print("基本的な要件が満たされていません")
        exit(1)
    
    # ストリーミングテスト
    try:
        test_streaming()
    except Exception as e:
        print(f"❌ エラー: {e}")
        import traceback
        traceback.print_exc()