# FastAPI service that streams German TTS audio over a WebSocket: the Orpheus 3B
# model generates SNAC audio tokens, which are decoded to 24 kHz PCM waveforms.
import os, json, asyncio

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login, snapshot_download
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# Authenticate with the Hugging Face Hub if an HF_TOKEN is set in the environment.
load_dotenv()
if (tok := os.getenv("HF_TOKEN")):
    login(token=tok)

device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading SNAC…")
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

# Download only the model weights and config; skip training and tokenizer artifacts.
model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
snapshot_download(
    repo_id=model_name,
    allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    ignore_patterns=[
        "optimizer.pt", "pytorch_model.bin", "training_args.bin",
        "scheduler.pt", "tokenizer.*", "vocab.json", "merges.txt",
    ],
)

# Load the Orpheus LLM in bfloat16 and reuse its EOS token for padding.
print("Loading Orpheus…")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
model = model.to(device)
model.config.pad_token_id = model.config.eos_token_id

tokenizer = AutoTokenizer.from_pretrained(model_name)


def process_prompt(text: str, voice: str):
    # Build "<voice>: <text>" and wrap it in Orpheus' special prompt tokens:
    # a start-of-human token before the text, end-of-text and end-of-human after it.
    prompt = f"{voice}: {text}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    start = torch.tensor([[128259]], device=device)
    end = torch.tensor([[128009, 128260]], device=device)
    return torch.cat([start, inputs.input_ids, end], dim=1)


def parse_output(ids: torch.LongTensor):
    # Keep only the generated audio codes: crop everything up to the last
    # start-of-audio marker (128257) and drop end-of-audio tokens (128258).
    st, rm = 128257, 128258
    idxs = (ids == st).nonzero(as_tuple=True)[1]
    cropped = ids[:, idxs[-1].item() + 1:] if idxs.numel() > 0 else ids
    row = cropped[0][cropped[0] != rm]
    return row.tolist()


def redistribute_codes(codes: list[int], snac_model: SNAC):
    # Rebuild the three SNAC codebook layers from the flat token stream and decode
    # them to a waveform. This follows the usual Orpheus layout (7 codes per frame,
    # 4096-entry codebooks, audio-token offset 128266); adjust the constants if
    # your checkpoint uses a different vocabulary layout.
    codes = [c - 128266 for c in codes]
    codes = codes[: (len(codes) // 7) * 7]  # keep only complete 7-code frames
    layer_1, layer_2, layer_3 = [], [], []
    for i in range(len(codes) // 7):
        layer_1.append(codes[7 * i])
        layer_2.append(codes[7 * i + 1] - 4096)
        layer_3.append(codes[7 * i + 2] - 2 * 4096)
        layer_3.append(codes[7 * i + 3] - 3 * 4096)
        layer_2.append(codes[7 * i + 4] - 4 * 4096)
        layer_3.append(codes[7 * i + 5] - 5 * 4096)
        layer_3.append(codes[7 * i + 6] - 6 * 4096)
    layers = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0),
    ]
    with torch.inference_mode():
        audio = snac_model.decode(layers)  # (1, 1, num_samples) float waveform at 24 kHz
    return audio.detach().squeeze().cpu().numpy()


app = FastAPI()


@app.get("/")
async def root():
    return {"status": "ok", "msg": "Hello, Orpheus TTS up!"}


@app.websocket("/ws/tts")
async def ws_tts(ws: WebSocket):
    await ws.accept()
    try:
        # The client sends a single JSON message: {"text": "...", "voice": "Jakob"}.
        msg = json.loads(await ws.receive_text())
        text, voice = msg.get("text", ""), msg.get("voice", "Jakob")
        ids = process_prompt(text, voice)
        gen = model.generate(
            input_ids=ids,
            max_new_tokens=2000,
            do_sample=True, temperature=0.7, top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=model.config.eos_token_id,
        )
        codes = parse_output(gen)
        audio_np = redistribute_codes(codes, snac)
        # Convert the float waveform to 16-bit PCM and stream it in ~0.1 s chunks
        # (2400 samples at 24 kHz, 2 bytes per sample).
        pcm16 = (audio_np * 32767).astype("int16").tobytes()
        chunk = 2400 * 2
        for i in range(0, len(pcm16), chunk):
            await ws.send_bytes(pcm16[i:i + chunk])
            await asyncio.sleep(0.1)
        await ws.close()
    except WebSocketDisconnect:
        print("Client left")
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
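

# Illustrative client sketch (never called by the server): it assumes the third-party
# `websockets` package and the default local port, sends one request, and writes the
# streamed 16-bit PCM chunks into a WAV file with the standard-library `wave` module.
# The URL, text, voice, and output path below are placeholder values.
async def example_client():
    import wave
    import websockets  # assumed extra dependency, not needed by the server itself

    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
        with wave.open("out.wav", "wb") as wav:
            wav.setnchannels(1)      # mono
            wav.setsampwidth(2)      # 16-bit samples
            wav.setframerate(24000)  # SNAC outputs 24 kHz audio
            try:
                while True:
                    wav.writeframes(await ws.recv())
            except websockets.exceptions.ConnectionClosed:
                pass  # the server closes the socket after the last chunk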


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)