import asyncio
import os
import queue
import threading
import traceback

import numpy as np
import torch
from snac import SNAC

# Kartoffel-specific constants (taken from the reference implementation).
CODE_TOKEN_OFFSET = 128266
CODE_START_TOKEN_ID = 128257  # token that marks the start of the audio codes
CODE_REMOVE_TOKEN_ID = 128258

# Layout of the flat code stream: codes come in groups of seven, and the code
# at position k of a group is stored shifted up by k * 4096.
_CODES_PER_GROUP = 7
_CODEBOOK_SIZE = 4096

print("DEBUG KARTOFFEL: Loading SNAC model...")
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
snac_device = os.environ.get("SNAC_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
model = model.to(snac_device)
if snac_device == "cuda":
    model = model.half()
model.eval()
print(f"DEBUG KARTOFFEL: SNAC model loaded successfully on device: {snac_device}")


def _empty_audio():
    """Return the empty (1, 0) float32 tensor used to signal "no audio"."""
    return torch.tensor([[]], device=snac_device, dtype=torch.float32)


def redistribute_codes_kartoffel(code_list):
    """Decode a flat list of Kartoffel audio codes with the SNAC model.

    The codes are consumed in groups of seven; trailing codes that do not
    fill a complete group are ignored. Within a group, the code at position
    ``k`` is shifted down by ``k * 4096`` and clamped to ``0..4095``. The
    seven values are then distributed over three layers:

    * position 0          -> layer 1
    * positions 1, 4      -> layer 2
    * positions 2, 3, 5, 6 -> layer 3

    Args:
        code_list: Sequence of integer codes (offset already removed).

    Returns:
        The result of ``model.decode`` for the three layer tensors, or an
        empty float32 tensor of shape ``(1, 0)`` when ``code_list`` holds no
        complete group.
    """
    if not code_list:
        return _empty_audio()

    num_groups = len(code_list) // _CODES_PER_GROUP
    if num_groups == 0:
        return _empty_audio()

    # Only complete groups of seven are used. After this truncation every
    # index accessed below is in range, so no IndexError handling is needed.
    code_list = code_list[:num_groups * _CODES_PER_GROUP]

    layer_1, layer_2, layer_3 = [], [], []
    for base_idx in range(0, len(code_list), _CODES_PER_GROUP):
        group = code_list[base_idx:base_idx + _CODES_PER_GROUP]
        if base_idx == 0:
            # Debug output for the first group only.
            print(f"DEBUG KARTOFFEL: First group codes: {group}")

        # SNAC expects codes in the range 0-4095, so clamping is required.
        clamped = [
            min(max(raw - position * _CODEBOOK_SIZE, 0), _CODEBOOK_SIZE - 1)
            for position, raw in enumerate(group)
        ]
        layer_1.append(clamped[0])
        layer_2.extend((clamped[1], clamped[4]))
        layer_3.extend((clamped[2], clamped[3], clamped[5], clamped[6]))

    codes = [
        torch.tensor(layer, device=snac_device, dtype=torch.int32).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]

    with torch.no_grad():
        audio_hat = model.decode(codes)

    return audio_hat


def convert_to_audio_kartoffel(audio_tensor):
    """Convert an audio tensor to raw PCM16 bytes.

    Args:
        audio_tensor: Tensor of float samples, or ``None``.

    Returns:
        The samples scaled by 32767, clipped to the int16 range and
        serialized with ``ndarray.tobytes()``; ``b''`` when the tensor is
        ``None`` or has no elements.
    """
    if audio_tensor is None or audio_tensor.numel() == 0:
        return b''

    # Convert to float32 on the CPU first (the model may run in half
    # precision), then scale to the int16 range.
    samples = audio_tensor.squeeze().cpu().to(torch.float32).numpy() * 32767
    samples = np.clip(samples, -32768, 32767).astype(np.int16)
    return samples.tobytes()


def extract_kartoffel_tokens(token_text, tokenizer):
    """Extract the audio codes contained in generated text.

    The text is re-encoded with ``tokenizer.encode``; everything after the
    first ``CODE_START_TOKEN_ID`` is scanned, tokens equal to
    ``CODE_REMOVE_TOKEN_ID`` or below ``CODE_TOKEN_OFFSET`` are dropped, and
    ``CODE_TOKEN_OFFSET`` is subtracted from the remaining ones.

    Args:
        token_text: Text produced by the generator so far.
        tokenizer: Object exposing ``encode(text)`` returning token ids.

    Returns:
        List of codes with the offset removed. An empty list when no start
        token is present or when any error occurs (best effort; the error is
        printed, not raised).
    """
    try:
        print(f"DEBUG KARTOFFEL: Received token_text: {token_text[:100]}...")

        # Convert the text to token ids (vLLM yields text, not numeric ids).
        token_ids = tokenizer.encode(token_text)
        print(f"DEBUG KARTOFFEL: Encoded token_ids count: {len(token_ids)}")
        print(f"DEBUG KARTOFFEL: First 20 token_ids: {token_ids[:20]}")

        # Locate the first audio start token.
        start_idx = next(
            (i for i, token_id in enumerate(token_ids) if token_id == CODE_START_TOKEN_ID),
            -1,
        )
        if start_idx == -1:
            print(f"DEBUG KARTOFFEL: No audio start token found ({CODE_START_TOKEN_ID})")
            print(f"DEBUG KARTOFFEL: Available unique tokens: {sorted(set(token_ids))}")
            return []

        print(f"DEBUG KARTOFFEL: Found audio start token at index {start_idx}")

        # Candidate audio tokens are everything after the start token.
        potential_code_tokens = token_ids[start_idx + 1:]
        print(f"DEBUG KARTOFFEL: Potential code tokens count: {len(potential_code_tokens)}")
        print(f"DEBUG KARTOFFEL: First 10 potential codes: {potential_code_tokens[:10]}")

        # Keep only valid audio tokens (>= CODE_TOKEN_OFFSET, not the remove token).
        valid_raw_codes = [
            token for token in potential_code_tokens
            if token != CODE_REMOVE_TOKEN_ID and token >= CODE_TOKEN_OFFSET
        ]
        print(f"DEBUG KARTOFFEL: Valid raw codes count: {len(valid_raw_codes)}")

        # Remove the offset.
        return [token - CODE_TOKEN_OFFSET for token in valid_raw_codes]

    except Exception as e:
        print(f"DEBUG KARTOFFEL: Error extracting tokens: {e}")
        return []


async def tokens_decoder_kartoffel(token_gen, tokenizer):
    """Turn a stream of generated text pieces into PCM16 audio chunks.

    Each incoming piece is appended to the accumulated text, which is then
    re-scanned with ``extract_kartoffel_tokens``. Codes not yet seen are
    buffered and decoded in chunks of 28 codes (four groups of seven). When
    the stream ends, any remaining complete groups of seven are decoded too.

    Args:
        token_gen: Async iterable yielding text pieces.
        tokenizer: Tokenizer passed through to ``extract_kartoffel_tokens``.

    Yields:
        Non-empty ``bytes`` objects containing PCM16 audio.
    """
    buffer = []
    accumulated_text = ""
    # Number of extracted codes already copied into ``buffer``. This must be
    # tracked separately from the number of codes already decoded: slicing
    # by the decoded count would re-append codes that are still waiting in
    # the buffer and produce duplicated audio codes.
    buffered_count = 0
    chunk_size = 4 * _CODES_PER_GROUP  # 4 groups of 7 codes

    print("DEBUG KARTOFFEL: Starting token decoding")

    async for token_text in token_gen:
        accumulated_text += token_text
        print(f"DEBUG KARTOFFEL: Accumulated text length: {len(accumulated_text)}")

        # Extract the audio codes from everything received so far.
        valid_codes = extract_kartoffel_tokens(accumulated_text, tokenizer)

        if len(valid_codes) > buffered_count:
            new_codes = valid_codes[buffered_count:]
            buffer.extend(new_codes)
            buffered_count = len(valid_codes)
            print(f"DEBUG KARTOFFEL: Added {len(new_codes)} new codes. Buffer size: {len(buffer)}")

            # Decode as long as a full chunk is available.
            while len(buffer) >= chunk_size:
                codes_to_process = buffer[:chunk_size]
                del buffer[:chunk_size]

                print(f"DEBUG KARTOFFEL: Processing {len(codes_to_process)} codes")

                audio_tensor = redistribute_codes_kartoffel(codes_to_process)
                audio_bytes = convert_to_audio_kartoffel(audio_tensor)

                if audio_bytes:
                    print(f"DEBUG KARTOFFEL: Generated {len(audio_bytes)} bytes of audio")
                    yield audio_bytes
                else:
                    print("DEBUG KARTOFFEL: No audio bytes generated")

    # Flush the remaining codes if they contain at least one complete group.
    if len(buffer) >= _CODES_PER_GROUP:
        final_count = (len(buffer) // _CODES_PER_GROUP) * _CODES_PER_GROUP
        final_codes = buffer[:final_count]
        print(f"DEBUG KARTOFFEL: Processing final {len(final_codes)} codes")

        audio_tensor = redistribute_codes_kartoffel(final_codes)
        audio_bytes = convert_to_audio_kartoffel(audio_tensor)

        if audio_bytes:
            print(f"DEBUG KARTOFFEL: Generated final {len(audio_bytes)} bytes of audio")
            yield audio_bytes

    print("DEBUG KARTOFFEL: Token decoding completed")


def tokens_decoder_kartoffel_sync(syn_token_gen, tokenizer):
    """Synchronous generator wrapper around ``tokens_decoder_kartoffel``.

    The async decoder runs in a separate thread with its own event loop
    (``asyncio.run``); produced audio chunks are handed over through a
    ``queue.Queue`` and yielded from the calling thread.

    Args:
        syn_token_gen: Synchronous iterable yielding text pieces.
        tokenizer: Tokenizer passed through to the decoder.

    Yields:
        PCM16 audio chunks as ``bytes``. Errors raised while decoding are
        printed with a traceback and end the stream; they are not re-raised.
    """
    audio_queue = queue.Queue()

    # Adapt the synchronous generator to an async one.
    async def async_token_gen():
        for token in syn_token_gen:
            yield token

    async def async_producer():
        try:
            async for audio_chunk in tokens_decoder_kartoffel(async_token_gen(), tokenizer):
                audio_queue.put(audio_chunk)
        except Exception as e:
            print(f"DEBUG KARTOFFEL: Error in async producer: {e}")
            traceback.print_exc()
        finally:
            audio_queue.put(None)  # sentinel: no more audio

    def run_async():
        asyncio.run(async_producer())

    thread = threading.Thread(target=run_async)
    thread.start()

    while True:
        audio = audio_queue.get()
        if audio is None:
            break
        yield audio

    thread.join()