import os
import json
import asyncio

import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# — ENV & AUTH —
HF_TOKEN = os.getenv("HF_TOKEN", "")
if HF_TOKEN:
    login(HF_TOKEN)

# — DEVICE SETUP —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — FASTAPI INSTANCE —
app = FastAPI()

# — HEALTHCHECK / ROOT —
@app.get("/")
async def read_root():
    return {"message": "TTS WebSocket up and running!"}

# — LOAD MODELS ON STARTUP —
@app.on_event("startup")
async def startup_event():
    global tokenizer, model, snac
    # 1) SNAC vocoder
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # 2) TTS model & tokenizer
    model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True
    )
    # make pad == eos
    model.config.pad_token_id = model.config.eos_token_id

# — HELPERS —
START_TOKEN        = 128259             # prepended once before the tokenized prompt
END_TOKENS         = [128009, 128260]   # appended once after the tokenized prompt
RESET_MARKER       = 128257             # when generated, any partial audio frame is discarded
EOS_TOKEN          = 128258             # stops generation
AUDIO_TOKEN_OFFSET = 128266             # subtracted from a generated token id to get its SNAC code
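# Worked example of the offset arithmetic (illustrative numbers only): if the
# model emits token id 131383, then 131383 - AUDIO_TOKEN_OFFSET = 3117 is the
# raw code appended to the buffer in the WebSocket loop below; every 7 such
# codes form one SNAC frame, which decode_seven() turns into roughly 2048 PCM
# samples (~85 ms at 24 kHz).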

def prepare_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}"
    in_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_TOKEN]], dtype=torch.int64, device=device)
    end   = torch.tensor([END_TOKENS], dtype=torch.int64, device=device)
    ids   = torch.cat([start, in_ids, end], dim=1)
    mask  = torch.ones_like(ids)
    return ids, mask
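
# For example (sketch): prepare_inputs("Hallo Welt!", "Jakob") tokenizes the
# prompt "Jakob: Hallo Welt!" and wraps it as
#   [START_TOKEN, <prompt token ids...>, 128009, 128260]
# so ids has shape (1, prompt_len) and mask is a matching tensor of ones.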

def decode_seven(tokens: list[int]) -> bytes:
    """Take exactly 7 audio codes, build the SNAC input and decode to PCM16 bytes."""
    b = tokens
    # De-interleave the 7 codes into SNAC's three codebook layers; each position
    # carries an offset of index*4096 that has to be removed first.
    l1 = [b[0]]
    l2 = [b[1] - 1 * 4096, b[4] - 4 * 4096]
    l3 = [b[2] - 2 * 4096, b[3] - 3 * 4096, b[5] - 5 * 4096, b[6] - 6 * 4096]
    codes = [
        torch.tensor(l1, device=device).unsqueeze(0),
        torch.tensor(l2, device=device).unsqueeze(0),
        torch.tensor(l3, device=device).unsqueeze(0),
    ]
    with torch.inference_mode():
        audio = snac.decode(codes).squeeze().cpu().numpy()
    # Clip to [-1, 1] before scaling so values never wrap around in int16.
    pcm16 = (audio.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
    return pcm16

# — WEBSOCKET ENDPOINT —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        # 1) receive JSON request
        msg = await ws.receive_text()
        req = json.loads(msg)
        text  = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # 2) prepare prompt
        input_ids, attention_mask = prepare_inputs(text, voice)
        prompt_len = input_ids.size(1)

        # 3) chunked generation setup
        past_kvs = None
        buffer: list[int] = []
        generated_offset = 0

        while True:
            # 4) generate up to 50 new tokens at a time, reusing the KV cache.
            #    generate() always receives the full sequence so far; with
            #    past_key_values it only recomputes the uncached tail.
            out = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=EOS_TOKEN,
                pad_token_id=EOS_TOKEN,
                use_cache=True,
                return_dict_in_generate=True,
                past_key_values=past_kvs,
            )
            gen_ids = out.sequences
            past_kvs = out.past_key_values

            # feed prompt + all tokens generated so far back into the next round
            input_ids = gen_ids
            attention_mask = torch.ones_like(gen_ids)

            # 5) extract only newly generated tokens
            seq = gen_ids[0]
            new_seq = seq[prompt_len + generated_offset :]
            generated_offset += new_seq.size(0)

            # 6) process each new token
            stop = False
            for t in new_seq.tolist():
                if t == EOS_TOKEN:
                    stop = True
                    break
                if t == RESET_MARKER:
                    buffer.clear()
                    continue
                # convert to audio-code
                buffer.append(t - AUDIO_TOKEN_OFFSET)
                # once we have 7 codes, decode & stream
                if len(buffer) >= 7:
                    block = buffer[:7]
                    buffer = buffer[7:]
                    pcm_bytes = decode_seven(block)
                    await ws.send_bytes(pcm_bytes)
            if stop:
                break

        # 7) clean close
        await ws.close()

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
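
# Example client (sketch): the snippet below is not part of the server; it
# assumes the app is served locally, e.g. `uvicorn app:app --port 8000`
# (module name and port are assumptions). It sends one JSON request and
# appends the streamed PCM16 frames (24 kHz, mono, 16-bit signed) to a raw file.
#
#   import asyncio, json
#   import websockets  # pip install websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:8000/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               try:
#                   while True:
#                       chunk = await ws.recv()
#                       if isinstance(chunk, bytes):
#                           f.write(chunk)
#               except websockets.ConnectionClosed:
#                   pass
#
#   asyncio.run(main())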