|
import base64
import io
import json
import os
import struct
import uuid

import outetts
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse, JSONResponse
|
|
|
# Model configuration: let outetts derive the settings for the 1B v1.0 model
# running on the HF backend.
_model_config = outetts.ModelConfig.auto_config(
    model=outetts.Models.VERSION_1_0_SIZE_1B,
    backend=outetts.Backend.HF,
)

# Single TTS interface, created once at import time and shared by all handlers.
interface = outetts.Interface(config=_model_config)

# Default speaker profile passed to every generation request.
speaker = interface.load_default_speaker("EN-FEMALE-1-NEUTRAL")

# ASGI application object the route decorators below attach to.
app = FastAPI()
|
|
|
@app.get("/")
def greet_json():
    """Return a fixed greeting payload for GET requests to the root path."""
    greeting = {"Hello": "World!"}
    return greeting
|
|
|
# Number of raw audio bytes carried by one websocket message (before base64).
_PCM_CHUNK_SIZE = 4096


def _split_wav(wav_data: bytes):
    """Split a WAV file into ``(header, payload)`` at the 'data' chunk.

    ``header`` is everything up to and including the 8-byte 'data' chunk
    header; ``payload`` is every byte after it, so ``header + payload``
    reproduces ``wav_data`` exactly.

    The RIFF chunk list is walked chunk by chunk instead of searching for
    the byte string ``b'data'``, so those four bytes appearing inside an
    earlier chunk's payload cannot be mistaken for the audio chunk.

    Raises:
        ValueError: if the RIFF/WAVE magic is missing or no 'data' chunk
            is found.
    """
    if wav_data[:4] != b'RIFF' or wav_data[8:12] != b'WAVE':
        raise ValueError("Not a valid WAV file")

    offset = 12  # first chunk starts right after b'RIFF' + size + b'WAVE'
    while offset + 8 <= len(wav_data):
        chunk_id, chunk_len = struct.unpack_from("<4sI", wav_data, offset)
        payload_start = offset + 8
        if chunk_id == b"data":
            return wav_data[:payload_start], wav_data[payload_start:]
        # RIFF chunks are word-aligned: odd-sized payloads carry a pad byte.
        offset = payload_start + chunk_len + (chunk_len & 1)

    raise ValueError("No 'data' chunk found in WAV file")


def _audio_message(audio: bytes, finished: bool) -> str:
    """Serialize one audio frame as the JSON text sent to the client."""
    return json.dumps({
        "data": {
            "audio_bytes": base64.b64encode(audio).decode("ascii"),
            "duration": None,
            "request_finished": finished
        }
    })


@app.websocket("/ws/tts")
async def websocket_tts(websocket: WebSocket):
    """Serve text-to-speech requests over a websocket.

    For every text message received, the handler:

    1. sends a ``generation_status`` message, runs the TTS model on the
       text, then sends a second ``generation_status`` message;
    2. saves the result to a temporary WAV file and reads it back;
    3. streams the audio as base64 text frames of ``_PCM_CHUNK_SIZE`` raw
       bytes each (the first frame also carries the WAV header), every
       frame flagged ``request_finished: False``;
    4. sends a final frame with empty ``audio_bytes`` and
       ``request_finished: True``.

    The loop ends silently when the client disconnects. Any other
    exception is reported to the client as ``{"error": <message>}`` and
    ends the handler.
    """
    await websocket.accept()
    try:
        while True:
            text = await websocket.receive_text()

            await websocket.send_text(json.dumps({"generation_status": "Warming up TTS model"}))

            # NOTE: generate() is called synchronously, so this coroutine
            # does not yield to the event loop until generation returns.
            output = interface.generate(
                config=outetts.GenerationConfig(
                    text=text,
                    generation_type=outetts.GenerationType.CHUNKED,
                    speaker=speaker,
                    sampler_config=outetts.SamplerConfig(
                        temperature=0.4
                    ),
                )
            )

            await websocket.send_text(json.dumps({"generation_status": "Generating linguistic features"}))

            temp_path = f"temp_{uuid.uuid4().hex}.wav"
            try:
                # save() sits inside the guarded region so that a failed or
                # partial write still gets cleaned up below.
                output.save(temp_path)
                with open(temp_path, "rb") as f:
                    wav_data = f.read()
            finally:
                # The audio is in memory now (or saving failed); either way
                # the file is no longer needed.
                try:
                    os.remove(temp_path)
                except FileNotFoundError:
                    pass

            wav_header, pcm_data = _split_wav(wav_data)

            # First frame: WAV header plus the first slice of audio.
            await websocket.send_text(
                _audio_message(wav_header + pcm_data[:_PCM_CHUNK_SIZE], False)
            )

            # Remaining frames: audio only.
            for start in range(_PCM_CHUNK_SIZE, len(pcm_data), _PCM_CHUNK_SIZE):
                await websocket.send_text(
                    _audio_message(pcm_data[start:start + _PCM_CHUNK_SIZE], False)
                )

            # Terminator frame telling the client this request is complete.
            await websocket.send_text(_audio_message(b"", True))
    except WebSocketDisconnect:
        pass
    except Exception as e:
        # Best-effort error report: if the connection itself is what broke,
        # this send can fail as well and there is nobody left to notify.
        try:
            await websocket.send_text(json.dumps({"error": str(e)}))
        except (WebSocketDisconnect, RuntimeError):
            pass