File size: 6,091 Bytes
f92444a
 
4189fe1
9bf14d0
f92444a
10540d6
d9ea17d
0316ec3
f92444a
 
 
2008a3f
1ab029d
f92444a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bca75ea
f92444a
bca75ea
 
 
f92444a
 
bca75ea
f92444a
bca75ea
f92444a
9bf14d0
0dfc310
9bf14d0
f92444a
 
 
 
 
9bf14d0
 
d9ea17d
bca75ea
 
 
9bf14d0
f63f843
bca75ea
 
 
 
f63f843
bca75ea
 
f92444a
 
 
 
 
bca75ea
 
 
 
 
 
 
 
 
f92444a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8606ac
bca75ea
a09ea48
4189fe1
bca75ea
9ef5e61
4c833ce
 
f63f843
 
9ef5e61
bca75ea
 
 
 
 
4c833ce
bca75ea
4c833ce
f63f843
9ef5e61
f92444a
4c833ce
 
 
f92444a
4c833ce
 
9ef5e61
 
 
4c833ce
 
 
9ef5e61
4c833ce
9ef5e61
 
 
bca75ea
 
4c833ce
bca75ea
4c833ce
 
a09ea48
bca75ea
f92444a
4c833ce
f92444a
 
a4cfefc
f92444a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# app.py  -------------------------------------------------------------
import os, json, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.generation.utils import Cache
from snac import SNAC

# ── 0. Auth & Device ────────────────────────────────────────────────
# Log in to the Hugging Face Hub only when a token is supplied via the
# environment. A dedicated, deleted-after-use name keeps the secret out of the
# module-global ``tok``, which is later used for the tokenizer.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(_hf_token)
del _hf_token

device = "cuda" if torch.cuda.is_available() else "cpu"
# Disable PyTorch's flash scaled-dot-product-attention backend.
# NOTE(review): the original note called this a "PyTorch-2.2 fix" — confirm
# it is still needed with the installed PyTorch version.
torch.backends.cuda.enable_flash_sdp(False)

# ── 1. Constants ────────────────────────────────────────────────────
REPO             = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS     = 50           # new tokens per generate() call; ≤ 50 aims at < 1 s latency per chunk
START_TOKEN      = 128259
NEW_BLOCK_TOKEN  = 128257
EOS_TOKEN        = 128258
AUDIO_BASE       = 128266
AUDIO_CODEBOOK   = 4096         # offset step between frame positions (see decode_block)
FRAME_TOKENS     = 7            # audio tokens per frame (decode_block consumes 7 at a time)
# Token k of a frame carries its code offset by k * AUDIO_CODEBOOK, so the
# audio ids span FRAME_TOKENS * AUDIO_CODEBOOK values starting at AUDIO_BASE.
# (Allowing only the first 4096 ids would make every frame position after the
# first decode to a negative code.)
VALID_AUDIO_IDS  = torch.arange(AUDIO_BASE, AUDIO_BASE + FRAME_TOKENS * AUDIO_CODEBOOK)

# ── 2. Logit mask (audio and control tokens only) ───────────────────
class AudioMask(LogitsProcessor):
    """Logits processor that keeps only a fixed set of token ids samplable.

    Every vocabulary column not listed in ``allowed`` is forced to ``-inf``;
    listed columns keep their original score.
    """

    def __init__(self, allowed: torch.Tensor):
        # 1-D tensor of permitted token ids. It is moved to the scores'
        # device on first use in ``__call__`` if the devices differ.
        self.allowed = allowed

    def __call__(self, _ids, scores):
        """Return ``scores`` with all non-allowed columns set to ``-inf``.

        ``_ids`` (tokens generated so far) is unused. ``scores`` is indexed as
        ``scores[:, ids]``, i.e. presumably (batch, vocab) as passed by
        ``generate`` — TODO confirm for other callers.
        """
        allowed = self.allowed
        if allowed.device != scores.device:
            # Move once and keep the moved copy; indexing with ids on another
            # device would otherwise raise.
            allowed = self.allowed = allowed.to(scores.device)
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allowed] = 0.0
        return scores + mask

# Token ids the sampler may emit: the audio-code range plus the two control
# tokens NEW_BLOCK_TOKEN and EOS_TOKEN, placed on `device`.
ALLOWED_IDS = torch.cat(
    (
        VALID_AUDIO_IDS,
        torch.tensor((NEW_BLOCK_TOKEN, EOS_TOKEN)),
    )
).to(device)
# Single processor instance shared by every generate() call.
MASKER = AudioMask(ALLOWED_IDS)

# ── 3. FastAPI skeleton ─────────────────────────────────────────────
app = FastAPI()


@app.get("/")
async def root():
    """Return a static readiness message."""
    payload = {"msg": "Orpheus‑TTS ready"}
    return payload


# Module-level handles, populated by the startup hook `load_models`.
tok = None
model = None
snac = None

@app.on_event("startup")
async def load_models():
    """Load the tokenizer, the SNAC decoder and the LM into module globals.

    Runs as a FastAPI startup hook. With CUDA the LM is mapped onto GPU 0 in
    bfloat16; otherwise ``None`` is passed for both device map and dtype.
    """
    global tok, model, snac

    on_gpu = device == "cuda"
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    lm_kwargs = {
        "low_cpu_mem_usage": True,
        "device_map": {"": 0} if on_gpu else None,
        "torch_dtype": torch.bfloat16 if on_gpu else None,
    }
    model = AutoModelForCausalLM.from_pretrained(REPO, **lm_kwargs)

    # Reuse the EOS id as the pad id and keep the KV cache enabled.
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True

# ── 4. Helper ───────────────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    """Build the prompt tensors for ``"<voice>: <text>"``.

    Layout: ``[START_TOKEN] + tokenizer ids + [128009, 128260]``.

    Returns:
        ``(input_ids, attention_mask)``, both shaped (1, seq_len) on
        ``device``; the attention mask is all ones.
    """
    body = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    head = torch.tensor([[START_TOKEN]], device=device)
    # NOTE(review): presumably the end-of-text / end-of-human markers of the
    # Orpheus prompt format — confirm against the model card.
    trailer = torch.tensor([[128009, 128260]], device=device)
    input_ids = torch.cat((head, body, trailer), dim=1)
    return input_ids, torch.ones_like(input_ids)

def redistribute_codes(b7: list[int]) -> tuple[list[int], list[int], list[int]]:
    """Split one 7-token audio frame into the three SNAC code layers.

    ``b7`` holds token ids already shifted by ``AUDIO_BASE``; position ``k``
    carries its code offset by ``k * 4096``. The frame yields 1, 2 and 4
    codes for layers 1, 2 and 3 respectively, in the order used below.

    Raises:
        ValueError: if ``b7`` does not have exactly 7 items, or a de-offset
            code falls outside ``[0, 4096)`` (presumably an invalid SNAC
            codebook index that would otherwise fail deep inside the decoder).
    """
    if len(b7) != 7:
        raise ValueError(f"audio frame must have 7 tokens, got {len(b7)}")
    codes = [tok_id - k * 4096 for k, tok_id in enumerate(b7)]
    bad = [c for c in codes if not 0 <= c < 4096]
    if bad:
        raise ValueError(f"audio codes out of range [0, 4096): {bad}")
    c0, c1, c2, c3, c4, c5, c6 = codes
    return [c0], [c1, c4], [c2, c3, c5, c6]


def decode_block(b7: list[int]) -> bytes:
    """Decode one 7-token audio frame to raw 16-bit PCM bytes.

    Args:
        b7: 7 audio token ids already shifted by ``AUDIO_BASE``.

    Returns:
        The decoded waveform as native-endian int16 samples.

    Raises:
        ValueError: from ``redistribute_codes`` for a malformed frame.
    """
    layers = redistribute_codes(b7)
    # inference_mode: SNAC is loaded without disabling gradients, so its
    # output would otherwise require grad and ``.numpy()`` would raise.
    with torch.inference_mode():
        codes = [torch.tensor(x, device=device).unsqueeze(0) for x in layers]
        audio = snac.decode(codes).squeeze()
        # Clamp before scaling so samples outside [-1, 1] saturate instead of
        # wrapping around in the int16 cast.
        pcm = (audio.clamp(-1.0, 1.0) * 32767).to(torch.int16)
    return pcm.cpu().numpy().tobytes()

def new_tokens_only(full_seq, prev_len):
    """Return, as a list, the tokens after the first ``prev_len`` positions.

    Accepts a tensor/array (anything whose slice has ``.tolist()``) or a
    plain sequence such as a list or tuple.
    """
    tail = full_seq[prev_len:]
    if hasattr(tail, "tolist"):
        return tail.tolist()
    return list(tail)

# ── 5. WebSocket‑Endpoint ───────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech for a single request over a WebSocket.

    Protocol as implemented here: the client sends one JSON text frame with
    optional keys ``"text"`` (default ``""``) and ``"voice"`` (default
    ``"Jakob"``). For every complete 7-token audio frame the server sends a
    binary frame of int16 PCM bytes from ``decode_block``. The socket is
    closed normally when ``EOS_TOKEN`` is generated (or the safety cap below
    is reached), and with code 1011 on an unexpected error.

    Generation runs in chunks of ``CHUNK_TOKENS`` new tokens. Each chunk is
    requested with the *full* sequence so far plus the cache returned by the
    previous chunk, so ``generate`` only runs the positions the cache does
    not yet cover.

    NOTE(review): ``model.generate`` and ``decode_block`` are called
    synchronously inside this coroutine, so they block the event loop (and
    other connections) while a chunk is produced.
    """
    # Upper bound on generated tokens per request, so a model that never
    # emits EOS_TOKEN cannot keep this loop busy forever. 8192 tokens is about
    # 1170 audio frames; value chosen here, adjust as needed.
    max_new_total = 8192

    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
        prompt_len = ids.size(1)
        done_len = prompt_len           # length of the sequence already processed
        past, buf = None, []

        while done_len - prompt_len < max_new_total:
            gen = model.generate(
                input_ids=ids,
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[MASKER],
                do_sample=True, temperature=0.7, top_p=0.95,
                return_dict_in_generate=True,
                # Keep the cache object as-is instead of converting it to the
                # legacy tuple format on every chunk.
                use_cache=True, return_legacy_cache=False,
            )

            # generate needs the full token sequence alongside the cache: with
            # input_ids=None it starts a new sequence, so the returned
            # `sequences` would no longer contain the prompt and slicing at
            # the old length would find no new tokens.
            past = gen.past_key_values
            ids = gen.sequences
            attn = torch.ones_like(ids)
            new_tok = ids[0, done_len:].tolist()
            done_len = ids.size(1)

            if not new_tok:
                break                        # nothing generated; stop instead of spinning

            for t in new_tok:
                if t == EOS_TOKEN:
                    await ws.close()         # the single explicit close on normal end
                    return
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()              # new audio block: drop any partial frame
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()

        # Safety cap reached (or generation stalled) without EOS_TOKEN.
        await ws.close()

    except WebSocketDisconnect:
        pass                                 # client disconnected on its own
    except Exception as e:
        print("WS‑Error:", e)
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)        # report an internal error to the client

# ── 6. Local run ────────────────────────────────────────────────────
if __name__ == "__main__":
    import sys

    import uvicorn

    port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860
    # Pass the app object rather than the "app:app" import string. The string
    # form makes uvicorn import this file a second time as module `app`
    # (re-running the top-level login/tensor setup) and breaks if the file is
    # renamed. An import string is only required for reload/workers.
    uvicorn.run(app, host="0.0.0.0", port=port)