File size: 6,615 Bytes
0b5b901
87012a8
4189fe1
9bf14d0
87012a8
5031731
d9ea17d
0316ec3
0b5b901
479f253
 
 
2008a3f
1ab029d
0b5b901
 
 
 
 
 
 
 
 
 
 
 
5031731
0b5b901
479f253
0b5b901
 
 
 
 
 
 
 
 
bca75ea
0b5b901
bca75ea
 
0b5b901
9bf14d0
0dfc310
9bf14d0
0b5b901
 
9bf14d0
 
0b5b901
5031731
 
0b5b901
 
9bf14d0
5031731
bca75ea
0b5b901
 
f63f843
5031731
0b5b901
5031731
bca75ea
0b5b901
 
5031731
f92444a
0b5b901
 
 
87012a8
479f253
0b5b901
 
 
 
 
 
 
 
 
 
 
 
a8606ac
0b5b901
a09ea48
4189fe1
0b5b901
 
 
5031731
0b5b901
 
 
 
 
f63f843
 
479f253
5d73119
5031731
 
 
5d73119
5031731
0b5b901
 
9ef5e61
5031731
0b5b901
5031731
f92444a
0b5b901
 
 
 
9ef5e61
0b5b901
5d73119
 
0b5b901
87012a8
0b5b901
 
 
 
 
9ef5e61
bca75ea
0b5b901
 
bca75ea
5031731
479f253
a09ea48
5031731
 
479f253
5031731
 
0b5b901
 
5031731
0b5b901
a4cfefc
5031731
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# app.py ──────────────────────────────────────────────────────────────
import os, json, torch, asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.generation.utils import Cache
from snac import SNAC

# ── 0 · Login & Device ───────────────────────────────────────────────
# Log in to the Hugging Face Hub only when a token is supplied via HF_TOKEN.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# Prefer the GPU whenever one is visible to PyTorch.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Disable PyTorch's flash scaled-dot-product-attention backend; the original
# author marked this as a workaround for a CUDA assert.
torch.backends.cuda.enable_flash_sdp(False)

# ── 1 · Constants ─────────────────────────────────────────────────────
REPO           = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS   = 50          # max_new_tokens requested per generate() call
START_TOKEN    = 128259      # prepended to every prompt (see build_inputs)
NEW_BLOCK      = 128257      # discards the partially collected audio frame
EOS_TOKEN      = 128258      # ends the audio stream
AUDIO_BASE     = 128266      # token id of audio code 0 in frame slot 0
CODEBOOK_SIZE  = 4096        # id offset between consecutive frame slots (see decode_block)
FRAME_TOKENS   = 7           # audio tokens per frame
# Each of the FRAME_TOKENS slots owns its own CODEBOOK_SIZE-wide id range, so
# the valid audio vocabulary spans FRAME_TOKENS * CODEBOOK_SIZE ids. Allowing
# only the first 4096 would force slots 1..6 onto negative (invalid) codes.
VALID_AUDIO    = torch.arange(AUDIO_BASE, AUDIO_BASE + FRAME_TOKENS * CODEBOOK_SIZE)

# ── 2 · Logit‑Masker ─────────────────────────────────────────────────
class DynamicAudioMask(LogitsProcessor):
    """Logits processor that only lets audio ids, NEW_BLOCK and (later) EOS through.

    Every vocabulary id outside the allow-list gets ``-inf`` added to its score,
    so it can never be sampled. ``EOS_TOKEN`` is added to the allow-list once
    ``blocks`` (a counter maintained by the caller) reaches ``min_blocks``.

    Args:
        audio_ids: 1-D tensor of allowed audio token ids.
        min_blocks: value ``blocks`` must reach before EOS may be sampled.
    """
    def __init__(self, audio_ids: torch.Tensor, min_blocks:int=1):
        super().__init__()
        self.audio_ids  = audio_ids
        self.ctrl_ids   = torch.tensor([NEW_BLOCK], device=audio_ids.device)
        self.min_blocks = min_blocks
        self.blocks     = 0
        # Both allow-lists are fixed, so build them once here instead of
        # allocating (and copying EOS host->device) on every decoding step.
        self._allow     = torch.cat([self.audio_ids, self.ctrl_ids])
        self._allow_eos = torch.cat([self._allow,
                                     torch.tensor([EOS_TOKEN], device=audio_ids.device)])

    def reset(self) -> None:
        """Zero the block counter so the mask can be reused for a new request."""
        self.blocks = 0

    def __call__(self, inp, scores):
        """Return ``scores`` with every non-allowed id pushed to ``-inf``."""
        allow = self._allow_eos if self.blocks >= self.min_blocks else self._allow
        allow = allow.to(scores.device)   # no-op when already on the scores' device
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allow] = 0
        return scores + mask

# ── 3 · FastAPI‑App ──────────────────────────────────────────────────
app = FastAPI()


@app.get("/")
async def root():
    """Return a static liveness message for health checks."""
    payload = {"msg": "Orpheus‑TTS alive"}
    return payload

@app.on_event("startup")
async def load():
    """Load tokenizer, SNAC decoder and language model into module globals.

    Runs once at application startup and binds ``tok``, ``snac``, ``model``
    and ``masker``, which the request handlers read.
    """
    global tok, model, snac, masker
    print("⏳ Lade Modelle …")
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    on_gpu = device == "cuda"
    lm_kwargs = {
        "low_cpu_mem_usage": True,
        # On GPU: place the whole model on device 0 in bfloat16.
        # On CPU: leave placement and dtype at their defaults.
        "device_map": {"": 0} if on_gpu else None,
        "torch_dtype": torch.bfloat16 if on_gpu else None,
    }
    model = AutoModelForCausalLM.from_pretrained(REPO, **lm_kwargs)

    cfg = model.config
    cfg.pad_token_id = cfg.eos_token_id   # reuse EOS as pad id (presumably to satisfy generate())
    cfg.use_cache = True
    masker = DynamicAudioMask(VALID_AUDIO.to(device))
    print("✅ Modelle geladen")

# ── 4 · Helper functions ──────────────────────────────────────────────
def build_inputs(text:str, voice:str):
    """Tokenize ``"<voice>: <text>"`` and wrap it in the prompt control ids.

    Returns:
        ``(input_ids, attention_mask)`` on ``device``: the ids are
        ``[START_TOKEN] + tokens + [128009, 128260]`` concatenated along dim 1,
        and the mask is all ones with the same shape.
    """
    text_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    prefix = torch.tensor([[START_TOKEN]], device=device)
    # Control ids closing the prompt (values as expected by this model's format).
    suffix = torch.tensor([[128009, 128260]], device=device)
    input_ids = torch.cat((prefix, text_ids, suffix), dim=1)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask

def decode_block(block):
    """Decode one 7-token audio frame to raw int16 PCM bytes with SNAC.

    ``block`` holds seven ints already shifted by ``AUDIO_BASE``; slot ``i``
    additionally carries an offset of ``i * 4096``. After removing that offset
    the codes ``c0..c6`` are spread over SNAC's three layers as
    ``[c0]``, ``[c1, c4]`` and ``[c2, c3, c5, c6]``.

    Returns:
        The decoded samples, clipped to [-1, 1] and scaled to int16, as bytes
        in native byte order.

    Raises:
        ValueError: if ``block`` does not hold exactly 7 codes, or any code is
            outside ``[0, 4096)`` after removing its slot offset. The check runs
            on the host, so a bad token fails with a clear error instead of an
            out-of-range codebook lookup on the device.
    """
    if len(block) != 7:
        raise ValueError(f"expected 7 audio codes, got {len(block)}")
    c = [tok_id - slot * 4096 for slot, tok_id in enumerate(block)]
    if not all(0 <= code < 4096 for code in c):
        raise ValueError(f"audio code out of range in frame {list(block)}")

    layers = ([c[0]], [c[1], c[4]], [c[2], c[3], c[5], c[6]])
    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in layers]
    # inference_mode: SNAC's parameters require grad (nn.Module default), so
    # without it the output carries autograd history and .numpy() would raise.
    with torch.inference_mode():
        audio = snac.decode(codes)
    audio = audio.squeeze().float().cpu().numpy()
    # Clip first: values outside [-1, 1] would otherwise wrap around in int16.
    return (audio.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()

# ── 5 · WebSocket‑TTS ────────────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws:WebSocket):
    """Stream synthesized speech for a single request over a WebSocket.

    The client sends one JSON text message with ``"text"`` and an optional
    ``"voice"`` (default ``"Jakob"``). The model is sampled in chunks of
    ``CHUNK_TOKENS`` tokens, reusing the KV cache between chunks. Every complete
    7-token audio frame is decoded with ``decode_block`` and sent as one binary
    message.

    The stream ends in any of these cases:
    - ``EOS_TOKEN`` is sampled;
    - a chunk yields no new tokens;
    - ``max_chunks`` chunks have been generated.

    Unexpected errors are printed and the socket is closed with code 1011;
    otherwise it is closed normally.
    """
    import functools

    max_chunks = 200   # safety cap so a model that never samples EOS_TOKEN cannot loop forever
    close_code = 1000
    await ws.accept()
    try:
        req   = json.loads(await ws.receive_text())
        text  = req.get("text","")
        voice = req.get("voice","Jakob")

        # Per-request mask: the module-level ``masker`` is shared by every
        # connection and its ``blocks`` counter is never reset, so EOS would be
        # allowed immediately for all requests after the first.
        mask = DynamicAudioMask(VALID_AUDIO.to(device))
        loop = asyncio.get_running_loop()

        seq, attn = build_inputs(text, voice)
        done_len  = seq.shape[1]   # tokens already handled (initially the prompt)
        past      = None
        buf       = []
        finished  = False

        for _ in range(max_chunks):
            # Always pass the FULL sequence together with the cache: generate()
            # skips the positions already held in ``past`` and only runs the
            # uncached tail. Passing just the last token would desynchronise the
            # cache and make ``out.sequences`` shorter than ``done_len``.
            run_chunk = functools.partial(
                model.generate,
                input_ids       = seq,
                attention_mask  = attn,
                past_key_values = past,
                max_new_tokens  = CHUNK_TOKENS,
                logits_processor= [mask],
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=True, return_dict_in_generate=True)
            # generate() is blocking GPU work; run it in a worker thread so the
            # event loop keeps serving other connections meanwhile.
            out = await loop.run_in_executor(None, run_chunk)

            # Keep the cache exactly as generate() returned it (Cache object or
            # legacy tuple) and hand it back unchanged on the next call.
            past = out.past_key_values
            seq  = out.sequences
            attn = torch.ones_like(seq)

            new      = seq[0, done_len:].tolist()   # only tokens produced by this chunk
            done_len = seq.shape[1]
            print("new tokens:", new[:32])

            if not new:                              # nothing generated
                break

            for t in new:
                if t == EOS_TOKEN:
                    finished = True
                    break
                if t == NEW_BLOCK:
                    buf.clear()
                    continue
                buf.append(t-AUDIO_BASE)
                if len(buf)==7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    mask.blocks += 1
            if finished:
                break

    except WebSocketDisconnect:
        pass                                         # client left; nothing more to send
    except Exception as e:
        print("❌ WS‑Error:", e)
        close_code = 1011
    finally:
        # Close exactly once: 1011 after an unexpected error, normal otherwise.
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close(code=close_code)
            except RuntimeError:
                pass                                 # already closed at the ASGI level

# ── 6 · local run ────────────────────────────────────────────────────
if __name__ == "__main__":
    # Development entry point: serve the ``app`` object of module ``app`` with uvicorn.
    import uvicorn

    uvicorn.run("app:app", port=7860, host="0.0.0.0")