import asyncio
import os

import numpy as np
import torch
import torchaudio
from transformers import AutoTokenizer, AutoModelForCausalLM
from livekit import rtc


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the Orpheus TTS model and tokenizer from the given path (Hub repository).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        path = path or "atharva27/orpheus"  # fall back to the Hub repo if no path is given
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: dict) -> list:
        # Extract the input text and the optional voice and LiveKit parameters.
        text_input = data.get("inputs") or data.get("text") or ""
        if not isinstance(text_input, str) or text_input.strip() == "":
            raise ValueError("No text input provided for TTS")
        voice = data.get("voice", "tara")  # default voice (e.g., "tara")

        # Format the prompt with the voice name (Orpheus expects prompts like "voice: text").
        prompt = f"{voice}: {text_input}"

        # Encode the prompt and generate output tokens with the TTS model.
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        generate_kwargs = {
            "max_new_tokens": 1024,         # allow sufficient tokens for audio output
            "do_sample": True,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.1,      # >= 1.1 for stable speech generation
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        output_ids = self.model.generate(input_ids, **generate_kwargs)

        # The generated sequence includes the prompt; isolate the newly generated tokens.
        generated_tokens = output_ids[0, input_ids.size(1):]
        output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)

        # Extract audio token IDs from the decoded output.
        # Placeholder: keep only numeric tokens so non-numeric output does not crash the handler.
        # Replace with the actual Orpheus audio-token parsing logic.
        audio_token_ids = [int(m) for m in output_text.split() if m.isdigit()]

        # Convert the audio token IDs to waveform data.
        waveform = self.generate_waveform_from_tokens(audio_token_ids)

        # Save the waveform as a 24 kHz audio file.
        torchaudio.save("output_audio.wav", waveform, 24000)

        # For real-time streaming, use LiveKit to stream the audio.
        lk_url = data.get("livekit_url")
        lk_token = data.get("livekit_token")
        room_name = data.get("livekit_room", "default-room")
        if lk_url and lk_token:
            asyncio.run(self.stream_audio(lk_url, lk_token, room_name, waveform))

        return [{"status": "success"}]

    def generate_waveform_from_tokens(self, audio_token_ids):
        """
        Convert audio tokens into a waveform (this part is for demonstration).
        You should implement a proper method to decode tokens to actual audio.
        """
        # Here we simulate the waveform by generating random data based on the tokens.
        # Replace this logic with actual audio generation (e.g., a codec decoder).
        num_samples = len(audio_token_ids) * 100  # estimate the sample count from the tokens
        waveform = torch.randn(1, num_samples)    # simulated random audio waveform
        return waveform

    async def stream_audio(self, lk_url, lk_token, room_name, waveform):
        room = rtc.Room()
        try:
            await room.connect(lk_url, lk_token, options=rtc.RoomOptions(auto_subscribe=True))
        except Exception as e:
            return f"Failed to connect to LiveKit: {e}"

        # Create an audio track for streaming the TTS output.
        source = rtc.AudioSource(sample_rate=24000, num_channels=1)
        track = rtc.LocalAudioTrack.create_audio_track("tts-audio", source)
        await room.local_participant.publish_track(track, rtc.TrackPublishOptions(name="TTS Audio"))

        # Stream the waveform data in chunks for real-time playback.
        frame_duration = 0.05                        # 50 ms per frame
        frame_samples = int(24000 * frame_duration)  # 50 ms of audio at a 24 kHz sample rate
        total_samples = waveform.size(1)
        for start in range(0, total_samples, frame_samples):
            end = min(start + frame_samples, total_samples)
            # Convert the float chunk to 16-bit PCM (scale to the int16 range before casting).
            chunk = (waveform[0, start:end].numpy() * 32767).clip(-32768, 32767).astype(np.int16)

            # Create an AudioFrame and send it to LiveKit.
            audio_frame = rtc.AudioFrame.create(24000, 1, len(chunk))
            np.copyto(np.frombuffer(audio_frame.data, dtype=np.int16), chunk)
            await source.capture_frame(audio_frame)

            # Sleep to maintain real-time pace (synchronize with the frame duration).
            await asyncio.sleep(frame_duration)

        # Disconnect from the room after streaming is finished.
        await room.disconnect()