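"""Orpheus TTS WebSocket service.

Loads the German Kartoffel_Orpheus-3B model together with the SNAC codec and
streams generated speech as raw PCM16 audio over a FastAPI WebSocket endpoint.
"""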
import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, snapshot_download

# — Env & HF auth —
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# — Device —
device = "cuda" if torch.cuda.is_available() else "cpu"

# — Load SNAC —
print("Loading SNAC model...")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

# — Prepare the Orpheus model —
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"

# Download only config + weights (keeps the container image slim); allow_patterns
# already excludes everything else, and AutoTokenizer fetches its files on demand.
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
)

print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).to(device)
model.config.pad_token_id = model.config.eos_token_id

tokenizer = AutoTokenizer.from_pretrained(model_name)


# — Helper functions —

def process_prompt(text: str, voice: str):
    prompt = f"{voice}: {text}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # wrap with the Orpheus start/end special tokens
    start = torch.tensor([[128259]], device=device)
    end   = torch.tensor([[128009, 128260]], device=device)
    input_ids = torch.cat([start, inputs.input_ids, end], dim=1)
    return input_ids

def parse_output(generated_ids: torch.LongTensor):
    token_to_find   = 128257  # marks the start of the audio-token stream
    token_to_remove = 128258  # end-of-speech token, stripped from the codes

    idxs = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
    if idxs.numel() > 0:
        cropped = generated_ids[:, idxs[-1].item() + 1 :]
    else:
        cropped = generated_ids

    row = cropped[0][cropped[0] != token_to_remove]
    return row.tolist()

def redistribute_codes(code_list: list[int], snac_model: SNAC):
    """Split the flat 7-token Orpheus frames into SNAC's three codebook layers."""
    layer1, layer2, layer3 = [], [], []
    # len // 7 discards a trailing incomplete frame instead of indexing past the end
    for i in range(len(code_list) // 7):
        base = code_list[7*i : 7*i+7]
        layer1.append(base[0])
        layer2.append(base[1] -   4096)
        layer3.append(base[2] - 2*4096)
        layer3.append(base[3] - 3*4096)
        layer2.append(base[4] - 4*4096)
        layer3.append(base[5] - 5*4096)
        layer3.append(base[6] - 6*4096)

    dev = next(snac_model.parameters()).device
    codes = [
        torch.tensor(layer1, device=dev).unsqueeze(0),
        torch.tensor(layer2, device=dev).unsqueeze(0),
        torch.tensor(layer3, device=dev).unsqueeze(0),
    ]
    audio = snac_model.decode(codes)
    return audio.detach().squeeze().cpu().numpy()


# — FastAPI app —
app = FastAPI()

@app.get("/")
async def hello():
    return {"message": "Hello, Orpheus TTS is up and running!"}

@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        # Handle exactly ONE request per connection
        raw = await ws.receive_text()
        data = json.loads(raw)
        text  = data.get("text", "")
        voice = data.get("voice", "Jakob")

        # 1) Text → input_ids
        input_ids = process_prompt(text, voice)

        # 2) Generation (run in a worker thread so the event loop stays responsive)
        gen_ids = await asyncio.to_thread(
            model.generate,
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),  # avoids the pad==eos warning
            max_new_tokens=2000,    # raise this for longer utterances
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=model.config.eos_token_id,
        )

        # 3) Tokens → audio
        codes    = parse_output(gen_ids)
        audio_np = redistribute_codes(codes, snac)

        # 4) Stream PCM16 bytes in ~0.1 s chunks
        pcm16 = (audio_np.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
        chunk_size = 2400 * 2  # 2400 samples @ 24 kHz = 0.1 s; 2 bytes per sample
        for i in range(0, len(pcm16), chunk_size):
            await ws.send_bytes(pcm16[i : i+chunk_size])
            await asyncio.sleep(0.1)

        # Close cleanly; the client receives ConnectionClosedOK
        await ws.close()

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        # Log the error and close with an internal-error close code
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
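
# Minimal client sketch (assumes the `websockets` package; URL is illustrative):
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:8000/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           pcm = b""
#           try:
#               while True:
#                   pcm += await ws.recv()
#           except websockets.exceptions.ConnectionClosedOK:
#               pass  # server closed normally after streaming all audio
#       # pcm now holds raw 24 kHz mono PCM16 audio
#
#   asyncio.run(main())

if __name__ == "__main__":
    # Convenience entry point; assumes uvicorn is installed and port 8000 is free.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)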