# app.py  ─────────────────────────────────────────────────────────────
import os, json, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
from snac import SNAC

# ── 0. HF auth & device ──────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Work around a flash-attention bug in PyTorch 2.2.x by disabling flash SDP.
torch.backends.cuda.enable_flash_sdp(False)

# ── 1. Constants ─────────────────────────────────────────────────────
REPO              = "SebastianBodza/Kartoffel_Orpheus-3B_german_synthetic-v0.1"
CHUNK_TOKENS      = 50       # tokens generated per generate() call
START_TOKEN       = 128259   # start-of-human marker
NEW_BLOCK_TOKEN   = 128257   # start of a new audio block
EOS_TOKEN         = 128258   # end of speech
AUDIO_BASE        = 128266   # first audio-code token ID
VALID_AUDIO_IDS   = torch.arange(AUDIO_BASE, AUDIO_BASE + 7 * 4096)
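# NOTE (inferred from the offsets in decode_snac below): audio codes arrive
# in 7-token frames, and frame position i uses token IDs in
# [AUDIO_BASE + i*4096, AUDIO_BASE + (i+1)*4096), e.g. position 0 covers
# 128266..132361 and position 6 covers 152842..156937. The mask therefore
# has to allow the full 7 * 4096 range, not just the first 4096 IDs.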

# ── 2. Logits processor for masking ─────────────────────────────────
class AudioLogitMask(LogitsProcessor):
    def __init__(self, allowed_ids: torch.Tensor):
        super().__init__()
        self.allowed = allowed_ids

    def __call__(self, input_ids, scores):
        # scores shape: [batch, vocab]
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed] = 0
        return scores + mask

ALLOWED_IDS = torch.cat(
    [VALID_AUDIO_IDS, torch.tensor([NEW_BLOCK_TOKEN, EOS_TOKEN])]
).to(device)
MASKER = AudioLogitMask(ALLOWED_IDS)
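
# Quick sanity check of the mask (illustrative sketch, kept as a comment so
# it does not run at import time; the vocab size 160000 is an arbitrary
# stand-in that is merely larger than the highest allowed ID):
#
#   scores = torch.zeros(1, 160000, device=device)
#   masked = MASKER(None, scores)
#   assert masked[0, NEW_BLOCK_TOKEN].item() == 0          # allowed token kept
#   assert masked[0, 0].item() == float("-inf")            # everything else masked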

# ── 3. FastAPI skeleton ──────────────────────────────────────────────
app = FastAPI()

@app.get("/")
async def ping():
    return {"msg": "Orpheus‑TTS OK"}

@app.on_event("startup")
async def load_models():
    global tok, model, snac
    tok   = AutoTokenizer.from_pretrained(REPO)
    snac  = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache    = True

# ── 4. Helper functions ──────────────────────────────────────────────
def build_prompt(text: str, voice: str):
    base = f"{voice}: {text}"
    ids  = tok(base, return_tensors="pt").input_ids.to(device)
    ids  = torch.cat(
        [
            torch.tensor([[START_TOKEN]], device=device),  # start-of-human
            ids,
            # 128009 = Llama-3 <|eot_id|>, 128260 = end-of-human marker
            torch.tensor([[128009, 128260]], device=device),
        ],
        1,
    )
    return ids, torch.ones_like(ids)

def decode_snac(block7: list[int]) -> bytes:
    """Split one 7-token frame into SNAC's three codebook levels, decode to 16-bit PCM."""
    b  = block7
    l1 = [b[0]]                                      # level 1: 1 code per frame
    l2 = [b[1] - 4096, b[4] - 16384]                 # level 2: 2 codes per frame
    l3 = [b[2] - 8192, b[3] - 12288,
          b[5] - 20480, b[6] - 24576]                # level 3: 4 codes per frame

    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
    with torch.inference_mode():
        audio = snac.decode(codes).squeeze().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()

# ── 5. WebSocket endpoint ────────────────────────────────────────────
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    try:
        req   = json.loads(await ws.receive_text())
        text  = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_prompt(text, voice)
        past      = None
        buf       = []
        done      = False

        while not done:
            prev_len = ids.shape[1]
            out = model.generate(
                input_ids=ids,
                attention_mask=attn,
                past_key_values=past,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[MASKER],
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=True,
                return_dict_in_generate=True,
            )
            past = out.past_key_values
            # out.sequences contains prompt + everything generated so far;
            # slice off the part we have already processed
            newtok = out.sequences[0, prev_len:].tolist()

            for t in newtok:
                if t == EOS_TOKEN:
                    done = True
                    break
                if t == NEW_BLOCK_TOKEN:
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:
                    await ws.send_bytes(decode_snac(buf))
                    buf.clear()

            # feed the full sequence back in; thanks to the KV cache the next
            # generate() call only has to process the freshly appended tokens
            ids  = out.sequences
            attn = torch.ones_like(ids)

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("WS error:", e)
        await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            await ws.close()

# ── 6. Local test ────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
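
# ── 7. Example client (sketch) ───────────────────────────────────────
# A minimal client sketch for the /ws/tts endpoint above, kept as a comment
# since it is not part of the server. It assumes the third-party `websockets`
# package and writes the raw 16-bit mono PCM stream (24 kHz, SNAC's sample
# rate) to a WAV file; file name and text are illustrative placeholders.
#
#   import asyncio, json, wave
#   import websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo, Welt!", "voice": "Jakob"}))
#           pcm = b""
#           try:
#               while True:
#                   pcm += await ws.recv()   # server sends raw PCM chunks
#           except websockets.ConnectionClosed:
#               pass
#           with wave.open("out.wav", "wb") as w:
#               w.setnchannels(1)
#               w.setsampwidth(2)            # 16-bit samples
#               w.setframerate(24000)
#               w.writeframes(pcm)
#
#   asyncio.run(main())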