File size: 5,121 Bytes
a09ea48
 
 
0316ec3
4189fe1
0316ec3
a09ea48
 
0dfc310
0316ec3
0dfc310
a09ea48
 
 
 
2008a3f
e97a876
1ab029d
0316ec3
0dfc310
e97a876
a09ea48
0316ec3
674acbf
e97a876
0dfc310
 
 
 
 
 
 
 
 
f001a32
e97a876
a09ea48
 
 
 
0316ec3
 
e97a876
 
 
 
0dfc310
 
a09ea48
 
0316ec3
e97a876
 
a09ea48
0dfc310
 
a09ea48
 
ad94d02
a09ea48
e97a876
 
 
 
 
 
 
 
 
a09ea48
0dfc310
 
0316ec3
a09ea48
0dfc310
e97a876
 
 
 
0dfc310
 
e97a876
 
 
a09ea48
e97a876
0dfc310
a09ea48
0dfc310
 
 
 
 
 
 
 
e97a876
 
 
 
 
97006e1
0dfc310
4189fe1
 
a8606ac
a09ea48
 
4189fe1
 
a09ea48
0dfc310
 
674acbf
0dfc310
e97a876
a09ea48
0dfc310
e97a876
0dfc310
 
 
e97a876
0dfc310
 
 
 
e97a876
0dfc310
 
e97a876
0dfc310
 
 
e97a876
0dfc310
e97a876
0dfc310
 
 
 
4189fe1
a09ea48
 
 
 
4189fe1
a09ea48
4189fe1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import asyncio
import json
import os
import threading

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login, snapshot_download
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# — Environment & Hugging Face authentication —
# Load environment configuration via python-dotenv before reading HF_TOKEN.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
# Only log in when a token is configured; otherwise continue unauthenticated.
if HF_TOKEN:
    login(token=HF_TOKEN)

# — Device —
# Use CUDA when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Load models —
print("Loading SNAC model…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# Pre-fetch only the config and the safetensors weights; everything else is
# ignored.
# NOTE(review): tokenizer files are excluded here although AutoTokenizer is
# loaded from the same repo id below — presumably it fetches them on demand;
# confirm.
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.json", "tokenizer_config.json",
        "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"
    ]
)

print("Loading Orpheus model…")
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
).to(device)
# Reuse the EOS token id as the padding token id.
model.config.pad_token_id = model.config.eos_token_id
tokenizer = AutoTokenizer.from_pretrained(model_name)

# — Constants for audio tokens —
# (must match the offset used during training; 128266 here)
AUDIO_TOKEN_OFFSET = 128266

# — Helper functions —

def process_prompt(text: str, voice: str):
    """Build the model input for one utterance.

    The prompt ``"<voice>: <text>"`` is tokenized, then framed with the
    special ids noted in the original code as coming from the spec:
    start token 128259 in front, end tokens 128009 and 128260 behind.

    Returns:
        A ``(input_ids, attention_mask)`` pair on the module-level ``device``;
        the mask is all ones and has the same shape as the ids.
    """
    encoded = tokenizer(f"{voice}: {text}", return_tensors="pt")

    prefix = torch.tensor([[128259]], dtype=torch.int64)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64)

    framed = torch.cat((prefix, encoded.input_ids, suffix), dim=1).to(device)
    attention = torch.ones_like(framed).to(device)
    return framed, attention

def parse_output(generated_ids: torch.LongTensor):
    """Turn generated token ids into a flat list of audio code ids.

    Everything up to and including the last audio start token (128257) is
    dropped, the model's EOS id (used as padding) is filtered out of the first
    row, and ``AUDIO_TOKEN_OFFSET`` is subtracted from every remaining token.
    When no start token is present the whole sequence is used.

    Returns:
        A list of ints taken from row 0 of ``generated_ids``.
    """
    audio_start_id = 128257
    pad_id = model.config.eos_token_id  # 128258 according to the original note

    # Column positions of every audio start token.
    start_columns = (generated_ids == audio_start_id).nonzero(as_tuple=True)[1]

    tail = generated_ids
    if start_columns.numel() > 0:
        tail = generated_ids[:, start_columns[-1].item() + 1:]

    # Only the first row is decoded; strip padding, then remove the offset.
    first_row = tail[0]
    kept = first_row[first_row != pad_id]
    return [int(token) - AUDIO_TOKEN_OFFSET for token in kept.tolist()]

def redistribute_codes(code_list: list[int], snac_model: "SNAC"):
    """Regroup a flat code list into three SNAC layers and decode it to audio.

    The list is consumed in frames of 7 codes; a trailing partial frame is
    ignored. Slot ``k`` of a frame carries an offset of ``k * 4096`` which is
    removed here. Slot 0 goes to layer 1, slots 1 and 4 to layer 2, and slots
    2, 3, 5 and 6 to layer 3.

    Frames in which any code falls outside ``[0, 4096)`` after offset removal
    are skipped instead of being passed to the decoder, where they would be
    invalid codebook indices.

    Args:
        code_list: Flat list of audio codes (audio-token offset already removed).
        snac_model: SNAC model used for decoding; it is not touched when no
            valid frame is present.

    Returns:
        The decoded waveform as a NumPy array, squeezed and moved to the CPU.
        An empty float32 array is returned when there is no complete, valid
        frame to decode.
    """
    codebook_size = 4096

    layer1, layer2, layer3 = [], [], []
    for i in range(len(code_list) // 7):
        frame = code_list[7 * i:7 * i + 7]
        # Remove the per-slot codebook offset.
        codes = [value - slot * codebook_size for slot, value in enumerate(frame)]
        if any(code < 0 or code >= codebook_size for code in codes):
            # Not a valid audio frame (e.g. a stray non-audio token); skip it.
            continue
        layer1.append(codes[0])
        layer2.extend((codes[1], codes[4]))
        layer3.extend((codes[2], codes[3], codes[5], codes[6]))

    if not layer1:
        # Nothing to decode: avoid handing empty tensors to the decoder.
        return torch.zeros(0, dtype=torch.float32).numpy()

    dev = next(snac_model.parameters()).device
    c1 = torch.tensor(layer1, dtype=torch.long, device=dev).unsqueeze(0)
    c2 = torch.tensor(layer2, dtype=torch.long, device=dev).unsqueeze(0)
    c3 = torch.tensor(layer3, dtype=torch.long, device=dev).unsqueeze(0)

    # Pure inference: no autograd graph is needed for decoding.
    with torch.inference_mode():
        audio = snac_model.decode([c1, c2, c3])
    return audio.detach().squeeze().cpu().numpy()

# — FastAPI + WebSocket endpoint —
app = FastAPI()

# Serializes model inference: requests are handled in worker threads, and only
# one of them may use the models at a time.
_INFERENCE_LOCK = threading.Lock()


def _synthesize_pcm16(text: str, voice: str) -> bytes:
    """Run the blocking text → tokens → audio pipeline and return PCM16 bytes."""
    with _INFERENCE_LOCK:
        # 1) Prompt → token tensors
        ids, mask = process_prompt(text, voice)

        # 2) Generation
        gen_ids = model.generate(
            input_ids=ids,
            attention_mask=mask,
            max_new_tokens=200,  # kept small for debugging
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=model.config.eos_token_id,
        )

        # 3) Tokens → code list → audio
        code_list = parse_output(gen_ids)
        audio_np = redistribute_codes(code_list, snac)

    # Clip before scaling: samples outside [-1, 1] would otherwise overflow
    # the int16 range and wrap around.
    return (audio_np.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()


@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Serve text-to-speech requests over a WebSocket.

    Each incoming text message is parsed as JSON with the optional keys
    ``"text"`` (default ``""``) and ``"voice"`` (default ``"Jakob"``). The
    synthesized audio is sent back as binary PCM16 messages of 2400 samples
    (4800 bytes) each, with a 0.1 s pause after every chunk.

    Synthesis runs in a worker thread so the event loop stays responsive for
    other connections while the model is generating. On any error other than
    a client disconnect, the error is printed and the socket is closed with
    code 1011.
    """
    await ws.accept()
    chunk = 2400 * 2  # 2400 samples × 2 bytes per int16 sample
    try:
        while True:
            msg = await ws.receive_text()
            data = json.loads(msg)
            text = data.get("text", "")
            voice = data.get("voice", "Jakob")

            # Blocking model work must not run on the event loop.
            pcm16 = await asyncio.to_thread(_synthesize_pcm16, text, voice)

            # Stream in 0.1 s chunks.
            for i in range(0, len(pcm16), chunk):
                await ws.send_bytes(pcm16[i:i + chunk])
                await asyncio.sleep(0.1)

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except RuntimeError:
            # The connection is already closed; there is nothing left to do.
            pass

# Script entry point: start uvicorn on all interfaces, port 7860.
# NOTE(review): the "app:app" import string presumes this file is importable
# as module `app` (i.e. saved as app.py) — confirm against the deployment.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)