# SNAC token-to-audio decoder: converts streamed custom-token IDs into PCM audio.
from snac import SNAC
import numpy as np
import torch
import asyncio
import threading
import queue
import os
# Load the pretrained 24 kHz SNAC codec once at import time; the decoder
# functions below reuse this module-level instance.
print("DEBUG: Loading SNAC model...")
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
# Device can be forced via the SNAC_DEVICE env var; otherwise prefer CUDA
# when available, falling back to CPU.
snac_device = os.environ.get("SNAC_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
model = model.to(snac_device)
print(f"DEBUG: SNAC model loaded successfully on device: {snac_device}")
def convert_to_audio(multiframe, count):
    """Decode a window of SNAC codes into raw 16-bit PCM audio bytes.

    ``multiframe`` is a flat list of token IDs laid out 7 per frame as
    [c0, c1a, c2a, c2b, c1b, c2c, c2d]: one code for the coarse codebook,
    two for the middle, four for the fine level of the SNAC hierarchy.
    ``count`` is unused here but kept for interface compatibility with
    callers.

    Returns little-endian int16 PCM bytes, or None when the input is too
    short or any code falls outside the valid codebook range.
    """
    if len(multiframe) < 7:
        return None

    num_frames = len(multiframe) // 7
    frame = multiframe[:num_frames * 7]  # drop any trailing partial frame

    # De-interleave the flat token stream into the three SNAC codebook
    # levels.  Building plain Python lists and converting to tensors once
    # avoids the quadratic cost of repeated torch.cat on 1-element tensors.
    codes_0, codes_1, codes_2 = [], [], []
    for j in range(num_frames):
        i = 7 * j
        codes_0.append(frame[i])
        codes_1.extend((frame[i + 1], frame[i + 4]))
        codes_2.extend((frame[i + 2], frame[i + 3], frame[i + 5], frame[i + 6]))

    codes = [
        torch.tensor(codes_0, device=snac_device, dtype=torch.int32).unsqueeze(0),
        torch.tensor(codes_1, device=snac_device, dtype=torch.int32).unsqueeze(0),
        torch.tensor(codes_2, device=snac_device, dtype=torch.int32).unsqueeze(0),
    ]

    # Valid codebook indices are 0..4095; reject anything outside that range
    # to avoid an out-of-range embedding lookup inside the model.
    # (The original check used `> 4096`, off by one: it let index 4096 through.)
    if any(torch.any(c < 0) or torch.any(c >= 4096) for c in codes):
        return None

    with torch.inference_mode():
        audio_hat = model.decode(codes)

    # Keep only the middle slice of the decoded window; the edges overlap
    # with neighbouring 28-token windows and would otherwise double up.
    audio_slice = audio_hat[:, :, 2048:4096]
    audio_np = audio_slice.detach().cpu().numpy()
    audio_int16 = (audio_np * 32767).astype(np.int16)
    return audio_int16.tobytes()
def turn_token_into_id(token_string, index):
    """Map the last ``<custom_token_N>`` marker in ``token_string`` to a SNAC code.

    Raw token IDs carry an offset of 10 plus a per-slot shift of 4096 for
    each of the 7 interleaved codebook positions; ``index % 7`` selects
    the slot this token belongs to.

    Returns the adjusted integer code, or None when no well-formed token
    is present.
    """
    prefix = "<custom_token_"
    token_string = token_string.strip()

    # Only the last marker in the chunk matters; anything before it is noise.
    start = token_string.rfind(prefix)
    if start == -1:
        print("No token found in the string")
        return None

    # rfind already guarantees the prefix, so only the closing '>' needs checking.
    last_token = token_string[start:]
    if not last_token.endswith(">"):
        return None

    try:
        number = int(last_token[len(prefix):-1])
    except ValueError:
        return None
    return number - 10 - ((index % 7) * 4096)
async def tokens_decoder(token_gen):
    """Consume a stream of token strings and yield decoded audio chunks.

    Accepted SNAC code IDs accumulate in a buffer; after a 28-token
    warm-up, every 7th accepted token triggers decoding of the newest
    28 codes into one audio chunk.
    """
    buffer = []
    count = 0
    token_count = 0
    async for token_sim in token_gen:
        token_count += 1
        print(f"DEBUG DECODER: Processing token {token_count}: {repr(token_sim)}")
        token = turn_token_into_id(token_sim, count)
        print(f"DEBUG DECODER: Converted to ID: {token}")

        # Guard clause: skip unparseable or non-positive codes.
        if token is None or token <= 0:
            continue

        buffer.append(token)
        count += 1
        print(f"DEBUG DECODER: Added to buffer. Count: {count}, Buffer size: {len(buffer)}")

        # Decode once per full frame (7 tokens) after at least 28 have arrived.
        if count % 7 == 0 and count > 27:
            buffer_to_proc = buffer[-28:]
            print(f"DEBUG DECODER: Converting buffer to audio. Buffer: {buffer_to_proc}")
            audio_samples = convert_to_audio(buffer_to_proc, count)
            if audio_samples is None:
                print("DEBUG DECODER: convert_to_audio returned None")
            else:
                print(f"DEBUG DECODER: Generated audio chunk of {len(audio_samples)} bytes")
                yield audio_samples
# ------------------ Synchronous Tokens Decoder Wrapper ------------------ #
def tokens_decoder_sync(syn_token_gen):
    """Bridge the async ``tokens_decoder`` to a synchronous generator.

    Runs the async decoder on a background thread, funnelling audio
    chunks through a queue so callers can simply iterate the result.
    If the decoder raises, the exception is re-raised in the consumer
    after the stream drains instead of being lost on the worker thread.
    """
    audio_queue = queue.Queue()
    producer_error = []  # holds an exception raised on the worker thread, if any

    async def async_token_gen():
        # Adapt the synchronous token generator to an async one.
        for token in syn_token_gen:
            yield token

    async def async_producer():
        try:
            async for audio_chunk in tokens_decoder(async_token_gen()):
                audio_queue.put(audio_chunk)
        except Exception as exc:
            producer_error.append(exc)
        finally:
            # Always enqueue the sentinel, even on failure; otherwise the
            # consumer loop below would block on audio_queue.get() forever.
            audio_queue.put(None)

    def run_async():
        asyncio.run(async_producer())

    # daemon=True so a stuck producer cannot prevent interpreter shutdown.
    thread = threading.Thread(target=run_async, daemon=True)
    thread.start()

    while True:
        audio = audio_queue.get()
        if audio is None:  # sentinel: producer finished (or failed)
            break
        yield audio

    thread.join()
    if producer_error:
        raise producer_error[0]