Tomtom84 committed on
Commit
d4e3b98
·
verified ·
1 Parent(s): 91da710

Create kartoffel_decoder.py

Browse files
Files changed (1) hide show
  1. orpheus-tts/kartoffel_decoder.py +196 -0
orpheus-tts/kartoffel_decoder.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from snac import SNAC
2
+ import numpy as np
3
+ import torch
4
+ import asyncio
5
+ import threading
6
+ import queue
7
+ import os
8
+
9
# Kartoffel-specific token-id constants.
CODE_TOKEN_OFFSET = 128266     # subtracted from a token id to obtain the audio code
CODE_START_TOKEN_ID = 128257   # audio codes are collected only after the first occurrence of this id
CODE_REMOVE_TOKEN_ID = 128258  # explicitly excluded when collecting audio codes

print("DEBUG KARTOFFEL: Loading SNAC model...")

# The SNAC_DEVICE environment variable overrides the automatic choice
# (CUDA when torch reports it as available, CPU otherwise).
snac_device = os.environ.get(
    "SNAC_DEVICE",
    "cuda" if torch.cuda.is_available() else "cpu",
)

# Load the pretrained SNAC model in eval mode and move it to the chosen device.
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(snac_device)

# NOTE(review): only the exact string "cuda" switches the model to half
# precision; a value such as "cuda:0" keeps the model in its loaded dtype.
model = model.half() if snac_device == "cuda" else model
model.eval()
print(f"DEBUG KARTOFFEL: SNAC model loaded successfully on device: {snac_device}")
23
+
24
def redistribute_codes_kartoffel(code_list):
    """Decode a flat list of Kartoffel audio codes into audio with SNAC.

    The codes are consumed in groups of seven. Position ``p`` of a group has
    ``p * 4096`` subtracted and is then routed to one of three layers:
    position 0 to layer 1, positions 1 and 4 to layer 2, and positions
    2, 3, 5 and 6 to layer 3. Trailing codes that do not fill a complete
    group are ignored.

    A group is skipped (with a debug message) when any of its codes falls
    outside ``[0, 4096)`` after the offset is removed, so that an invalid
    index is never handed to ``model.decode``.

    Args:
        code_list: Sequence of integer audio codes (in this file, token ids
            with ``CODE_TOKEN_OFFSET`` already subtracted).

    Returns:
        Whatever ``model.decode`` returns for the three layer tensors, or an
        empty float32 tensor of shape ``(1, 0)`` on ``snac_device`` when no
        usable group is available.
    """
    group_size = 7
    # Assumes each SNAC codebook holds 4096 entries, matching the
    # per-position offset used below — TODO confirm against the model config.
    codebook_size = 4096
    # Target layer (0-based) for each of the seven positions in a group.
    layer_for_position = (0, 1, 2, 2, 1, 2, 2)

    def _empty():
        return torch.tensor([[]], device=snac_device, dtype=torch.float32)

    if not code_list:
        return _empty()

    # Only complete groups of seven are used.
    usable = (len(code_list) // group_size) * group_size

    layers = ([], [], [])
    for base in range(0, usable, group_size):
        group = [
            code_list[base + pos] - pos * codebook_size
            for pos in range(group_size)
        ]
        if any(code < 0 or code >= codebook_size for code in group):
            # An out-of-range index would fail inside the codebook lookup
            # (on CUDA as a device-side assert), so drop the whole group.
            print(
                f"DEBUG KARTOFFEL: Out-of-range code during code redistribution "
                f"at group {base // group_size}. Skipping group."
            )
            continue
        for pos, code in enumerate(group):
            layers[layer_for_position[pos]].append(code)

    if not layers[0]:
        return _empty()

    codes = [
        torch.tensor(layer, device=snac_device).unsqueeze(0)
        for layer in layers
    ]

    with torch.no_grad():
        audio_hat = model.decode(codes)
    return audio_hat
64
+
65
def convert_to_audio_kartoffel(audio_tensor):
    """Convert an audio tensor to raw PCM16 bytes.

    Args:
        audio_tensor: Tensor of audio samples, or ``None``.

    Returns:
        The samples scaled by 32767, clipped to the int16 range and
        serialised with ``ndarray.tobytes()``; ``b''`` when the input is
        ``None`` or contains no elements.
    """
    has_samples = audio_tensor is not None and audio_tensor.numel() > 0
    if not has_samples:
        return b''

    # Drop singleton dimensions and bring the data to host memory as float32
    # before handing it to NumPy.
    samples = audio_tensor.squeeze().cpu().to(torch.float32).numpy()
    scaled = samples * 32767
    pcm16 = np.clip(scaled, -32768, 32767).astype(np.int16)
    return pcm16.tobytes()
74
+
75
def extract_kartoffel_tokens(token_text, tokenizer):
    """Extract audio codes from generated text.

    The text is encoded with ``tokenizer.encode``. Everything up to and
    including the first ``CODE_START_TOKEN_ID`` is discarded; from the rest,
    ids equal to ``CODE_REMOVE_TOKEN_ID`` or below ``CODE_TOKEN_OFFSET`` are
    dropped and ``CODE_TOKEN_OFFSET`` is subtracted from the remaining ids.

    Args:
        token_text: Generated text to re-encode.
        tokenizer: Object providing ``encode(text)`` that returns token ids.

    Returns:
        List of audio codes; an empty list when no start token is present or
        when any error occurs (the error is printed, not raised).
    """
    try:
        ids = tokenizer.encode(token_text)

        # Position of the first start token, or -1 when there is none.
        start_pos = next(
            (pos for pos, tid in enumerate(ids) if tid == CODE_START_TOKEN_ID),
            -1,
        )
        if start_pos < 0:
            return []

        # Keep only valid audio token ids after the start token and remove
        # the offset in the same pass.
        return [
            tid - CODE_TOKEN_OFFSET
            for tid in ids[start_pos + 1:]
            if tid != CODE_REMOVE_TOKEN_ID and tid >= CODE_TOKEN_OFFSET
        ]

    except Exception as e:
        print(f"DEBUG KARTOFFEL: Error extracting tokens: {e}")
        return []
108
+
109
async def tokens_decoder_kartoffel(token_gen, tokenizer):
    """Stream PCM16 audio chunks decoded from generated Kartoffel token text.

    Each text piece from ``token_gen`` is appended to the accumulated text,
    which is re-encoded with ``extract_kartoffel_tokens``. Codes that have
    not been buffered yet are appended to an internal buffer; whenever the
    buffer holds at least 28 codes (4 groups of 7) a chunk is decoded with
    ``redistribute_codes_kartoffel`` and converted with
    ``convert_to_audio_kartoffel``. After the input ends, the remaining
    complete groups of 7 are decoded once more.

    Args:
        token_gen: Async iterable yielding pieces of generated text.
        tokenizer: Tokenizer forwarded to ``extract_kartoffel_tokens``.

    Yields:
        Non-empty ``bytes`` objects containing PCM16 audio.
    """
    group_size = 7
    chunk_size = 28  # 4 groups of 7 codes
    buffer = []
    accumulated_text = ""
    # Number of extracted codes that have already been copied into the
    # buffer. This must advance on every append, not only when a chunk is
    # decoded; otherwise codes still waiting in the buffer would be appended
    # again on the next incoming token.
    buffered_count = 0

    print("DEBUG KARTOFFEL: Starting token decoding")

    async for token_text in token_gen:
        accumulated_text += token_text
        print(f"DEBUG KARTOFFEL: Accumulated text length: {len(accumulated_text)}")

        # NOTE(review): the whole accumulated text is re-encoded for every
        # incoming piece, so total work grows quadratically with its length.
        valid_codes = extract_kartoffel_tokens(accumulated_text, tokenizer)

        if len(valid_codes) <= buffered_count:
            continue

        new_codes = valid_codes[buffered_count:]
        buffered_count = len(valid_codes)
        buffer.extend(new_codes)
        print(f"DEBUG KARTOFFEL: Added {len(new_codes)} new codes. Buffer size: {len(buffer)}")

        # Decode as soon as enough codes for a chunk are available.
        while len(buffer) >= chunk_size:
            codes_to_process = buffer[:chunk_size]
            buffer = buffer[chunk_size:]

            print(f"DEBUG KARTOFFEL: Processing {len(codes_to_process)} codes")

            audio_tensor = redistribute_codes_kartoffel(codes_to_process)
            audio_bytes = convert_to_audio_kartoffel(audio_tensor)

            if audio_bytes:
                print(f"DEBUG KARTOFFEL: Generated {len(audio_bytes)} bytes of audio")
                yield audio_bytes
            else:
                print("DEBUG KARTOFFEL: No audio bytes generated")

    # Flush what is left: at least one complete group is required.
    if len(buffer) >= group_size:
        final_count = (len(buffer) // group_size) * group_size
        final_codes = buffer[:final_count]

        print(f"DEBUG KARTOFFEL: Processing final {len(final_codes)} codes")

        audio_tensor = redistribute_codes_kartoffel(final_codes)
        audio_bytes = convert_to_audio_kartoffel(audio_tensor)

        if audio_bytes:
            print(f"DEBUG KARTOFFEL: Generated final {len(audio_bytes)} bytes of audio")
            yield audio_bytes

    print("DEBUG KARTOFFEL: Token decoding completed")
163
+
164
def tokens_decoder_kartoffel_sync(syn_token_gen, tokenizer):
    """Synchronous wrapper around ``tokens_decoder_kartoffel``.

    The async decoder runs inside ``asyncio.run`` on a separate thread and
    pushes every audio chunk onto a queue; this generator pulls from that
    queue until the ``None`` sentinel arrives and then joins the thread.
    Exceptions raised in the producer are printed with a traceback and are
    not propagated to the caller.

    Args:
        syn_token_gen: Synchronous iterable yielding pieces of generated text.
        tokenizer: Tokenizer forwarded to ``tokens_decoder_kartoffel``.

    Yields:
        Audio chunks in the order the decoder produced them.
    """
    chunks = queue.Queue()

    async def _as_async():
        # Expose the synchronous generator as an async generator.
        for piece in syn_token_gen:
            yield piece

    async def _produce():
        try:
            async for pcm in tokens_decoder_kartoffel(_as_async(), tokenizer):
                chunks.put(pcm)
        except Exception as e:
            print(f"DEBUG KARTOFFEL: Error in async producer: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Sentinel: always tell the consumer that no more chunks follow.
            chunks.put(None)

    worker = threading.Thread(target=lambda: asyncio.run(_produce()))
    worker.start()

    # Block on the queue until the sentinel is received.
    for pcm in iter(chunks.get, None):
        yield pcm

    worker.join()