File size: 6,052 Bytes
0b5b901
87012a8
4189fe1
9bf14d0
87012a8
5031731
d9ea17d
0316ec3
d44e840
479f253
 
 
2008a3f
1ab029d
d44e840
 
 
 
 
 
 
 
 
 
 
 
 
 
0b5b901
479f253
d44e840
 
 
 
 
0b5b901
d44e840
0b5b901
 
bca75ea
0b5b901
bca75ea
 
d44e840
9bf14d0
0dfc310
9bf14d0
0b5b901
d44e840
9bf14d0
 
0b5b901
5031731
 
0b5b901
 
9bf14d0
5031731
bca75ea
0b5b901
d44e840
 
f63f843
5031731
d44e840
5031731
bca75ea
d44e840
0b5b901
5031731
f92444a
0b5b901
 
 
87012a8
479f253
d44e840
0b5b901
d44e840
 
 
 
 
0b5b901
 
 
 
d44e840
a8606ac
d44e840
a09ea48
4189fe1
d44e840
 
 
 
f63f843
 
d44e840
479f253
d44e840
 
 
 
 
5031731
0b5b901
 
d44e840
5031731
d44e840
0b5b901
5031731
d44e840
f92444a
d44e840
 
 
5d73119
d44e840
0b5b901
87012a8
0b5b901
d44e840
 
 
9ef5e61
bca75ea
0b5b901
d44e840
 
bca75ea
5031731
479f253
a09ea48
5031731
 
479f253
5031731
 
0b5b901
 
5031731
d44e840
a4cfefc
5031731
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# app.py ──────────────────────────────────────────────────────────────
import os, json, torch, asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.generation.utils import Cache
from snac import SNAC

# 0 · Auth & Device ---------------------------------------------------
# Optional Hugging Face login so the (possibly gated) model repo can be pulled.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Disable flash scaled-dot-product attention kernels (works around an SDP assert).
torch.backends.cuda.enable_flash_sdp(False)            # SDP‑Assert fix

# 1 · Constants -------------------------------------------------------
REPO         = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50                                      # tokens generated per streaming step
START_TOKEN  = 128259                                  # sequence-start marker (Orpheus vocab) — TODO confirm ids against repo
NEW_BLOCK    = 128257                                  # control token: start a fresh 7-code audio block
EOS_TOKEN    = 128258                                  # control token: end of generation
AUDIO_BASE   = 128266                                  # first audio-code token id
AUDIO_SPAN   = 4096 * 7                                # 28 672 codes (7 positions × 4096 each)
VALID_AUDIO  = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
# 2 · Logit masker ----------------------------------------------------
class DynamicMask(LogitsProcessor):
    """Restrict sampling to Orpheus audio codes plus control tokens.

    While fewer than ``min_blocks`` audio blocks have been emitted, only
    audio codes and ``NEW_BLOCK`` are allowed; afterwards ``EOS_TOKEN``
    becomes legal as well.  The caller is expected to increment
    ``self.blocks`` after each completed audio block.
    """

    def __init__(self, audio_ids: torch.Tensor, min_blocks: int = 1):
        super().__init__()
        self.audio_ids = audio_ids
        self.ctrl_ids  = torch.tensor([NEW_BLOCK], device=audio_ids.device)
        self.blocks    = 0
        self.min_blk   = min_blocks
        # Pre-compute both possible "allowed id" sets once, instead of
        # re-concatenating (and re-allocating) tensors on every decode step.
        self._allow_base = torch.cat([self.audio_ids, self.ctrl_ids])
        self._allow_eos  = torch.cat(
            [self._allow_base,
             torch.tensor([EOS_TOKEN], device=audio_ids.device)])

    def __call__(self, inp_ids, scores):
        # scores and audio_ids are assumed to live on the same device — the
        # original also mixed them, so this preserves existing behavior.
        allow = self._allow_eos if self.blocks >= self.min_blk else self._allow_base
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allow] = 0
        return scores + mask

# 3 · FastAPI app -----------------------------------------------------
app = FastAPI()

@app.get("/")
async def root():
    """Health-check endpoint: confirms the service is up."""
    status = {"msg": "Orpheus‑TTS online"}
    return status

@app.on_event("startup")  # NOTE(review): on_event is deprecated in newer FastAPI — consider a lifespan handler
async def load():
    """Load tokenizer, Orpheus LM and SNAC vocoder once at app startup.

    Populates the module globals ``tok``, ``model``, ``snac`` and ``masker``
    used by the websocket endpoint.
    """
    global tok, model, snac, masker
    print("⏳ Lade Modelle …")
    tok   = AutoTokenizer.from_pretrained(REPO)
    snac  = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        # Put the whole model on GPU 0 when CUDA is available.
        device_map={"":0} if device=="cuda" else None,
        torch_dtype=torch.bfloat16 if device=="cuda" else None,
    )
    # Reuse EOS as pad token (presumably no dedicated pad id — TODO confirm).
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache    = True
    masker = DynamicMask(VALID_AUDIO.to(device))
    print("✅ Modelle geladen")

# 4 · Helpers ---------------------------------------------------------
def build_inputs(text:str, voice:str):
    """Tokenize ``"voice: text"`` and wrap it in the Orpheus control frame.

    Returns ``(input_ids, attention_mask)``, both of shape (1, seq_len),
    on the module-level ``device``.
    """
    token_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    prefix = torch.tensor([[START_TOKEN]], device=device)
    # 128009/128260: closing control ids — presumably end-of-text markers
    # of the Orpheus prompt format; TODO confirm against the model card.
    suffix = torch.tensor([[128009,128260]], device=device)
    framed = torch.cat([prefix, token_ids, suffix], 1)
    return framed, torch.ones_like(framed)

def decode_block(b):
    """Decode one 7-code SNAC block into raw 16-bit PCM bytes.

    Position ``i`` in the block carries an offset of ``i * 4096``; the
    positions map onto SNAC's three coarse-to-fine code layers as
    ``[0] -> L1``, ``[1, 4] -> L2``, ``[2, 3, 5, 6] -> L3``.
    """
    vals = [b[i] - i * 4096 for i in range(7)]
    layers = ([vals[0]],
              [vals[1], vals[4]],
              [vals[2], vals[3], vals[5], vals[6]])
    codes = [torch.tensor(layer, device=device).unsqueeze(0) for layer in layers]
    pcm = snac.decode(codes).squeeze().cpu().numpy()
    return (pcm * 32767).astype("int16").tobytes()

# 5 · WebSocket endpoint ----------------------------------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech over a websocket.

    Protocol: the client sends one JSON text message
    ``{"text": ..., "voice": ...}``; the server replies with a stream of
    binary frames, each one decoded SNAC audio block (16-bit PCM), and
    closes the socket when generation ends.
    """
    class _EndOfStream(Exception):
        """Internal sentinel used to leave the generation loop cleanly
        (avoids overloading StopIteration for control flow)."""

    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        ids, attn = build_inputs(req.get("text",""), req.get("voice","Jakob"))
        # Use a fresh masker per connection: the shared global one would
        # accumulate `blocks` across requests (and across concurrent
        # connections), so the EOS gate would be wrong after the first call.
        mask_proc = DynamicMask(VALID_AUDIO.to(device))
        past, last_tok, buf = None, None, []
        prompt_len = ids.shape[1]

        while True:
            out = model.generate(
                # After the first chunk we only feed the last token and
                # rely on the KV cache for the rest of the context.
                input_ids      = ids if past is None else torch.tensor([[last_tok]], device=device),
                attention_mask = attn if past is None else None,
                past_key_values= past,
                max_new_tokens = CHUNK_TOKENS,
                logits_processor=[mask_proc],
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=True, return_dict_in_generate=True,
                return_legacy_cache=True)

            pkv = out.past_key_values
            # Newer transformers may still hand back a Cache object;
            # convert via the public API (the original called a
            # non-existent `to_legacy()` here).
            if isinstance(pkv, Cache):
                pkv = pkv.to_legacy_cache()
            past = pkv

            seq = out.sequences[0].tolist()
            new = seq[prompt_len:]
            prompt_len = len(seq)

            if not new:
                raise _EndOfStream          # model made no progress
            for t in new:
                last_tok = t
                if t == EOS_TOKEN:
                    raise _EndOfStream
                if t == NEW_BLOCK:          # control token: restart block
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:           # complete SNAC block → PCM out
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    mask_proc.blocks += 1

            ids, attn = None, None          # single-token steps from now on

    except (_EndOfStream, WebSocketDisconnect):
        pass
    except Exception as e:
        print("❌ WS‑Error:", e)
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close()
            except RuntimeError:
                pass

# 6 · Local run -------------------------------------------------------
if __name__ == "__main__":
    # Launch the ASGI app directly when run as a script.
    import uvicorn

    # Bind on all interfaces; 7860 is the conventional HF Spaces port.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)