import io
import base64
from typing import Dict

import torch
from transformers import BarkModel, BarkProcessor


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler for the Bark text-to-speech model.

        Args:
            path (str, optional): Path to the model directory. Defaults to "".
        """
        self.path = path
        self.model = None
        self.processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.initialized = False

    def setup(self, **kwargs):
        """
        Load the model and processor.

        Args:
            **kwargs: Additional arguments.
        """
        # Load the model from the local directory and move it to the target device.
        self.model = BarkModel.from_pretrained(self.path)
        self.model.to(self.device)

        # Load the matching processor.
        self.processor = BarkProcessor.from_pretrained(self.path)

        self.initialized = True
        print(f"Bark model loaded on {self.device}")

    def preprocess(self, request: Dict) -> Dict:
        """
        Process the input request before inference.

        Args:
            request (Dict): The request data containing text to convert to speech.

        Returns:
            Dict: Processed inputs for the model.
        """
        if not self.initialized:
            self.setup()

        inputs = {}

        # Get the text from the request.
        if "inputs" in request:
            if isinstance(request["inputs"], str):
                # Single text input.
                inputs["text"] = request["inputs"]
            elif isinstance(request["inputs"], list) and request["inputs"]:
                # List of text inputs: take the first one.
                inputs["text"] = request["inputs"][0]

        # Optional generation parameters.
        params = request.get("parameters", {})

        # Speaker ID / voice preset. Bark selects voices through the processor's
        # voice_preset argument, so a speaker_id is treated as a preset name here.
        if "speaker_id" in params:
            inputs["voice_preset"] = params["speaker_id"]
        elif "voice_preset" in params:
            inputs["voice_preset"] = params["voice_preset"]

        if "temperature" in params:
            inputs["temperature"] = params["temperature"]

        return inputs

    def inference(self, inputs: Dict) -> Dict:
        """
        Run model inference on the processed inputs.

        Args:
            inputs (Dict): Processed inputs for the model.

        Returns:
            Dict: Model outputs.
        """
        text = inputs.get("text", "")
        if not text:
            return {"error": "No text provided for speech generation"}

        # Extract optional parameters.
        voice_preset = inputs.get("voice_preset", None)
        temperature = inputs.get("temperature", 0.7)

        # Prepare model inputs. The processor resolves the voice preset into a
        # history prompt, so the preset belongs here rather than in generate().
        model_inputs = self.processor(text, voice_preset=voice_preset)
        model_inputs = model_inputs.to(self.device)

        # Generate speech; the unprefixed temperature is forwarded to Bark's
        # sub-model generate calls.
        with torch.no_grad():
            speech_output = self.model.generate(
                **model_inputs,
                temperature=temperature,
            )

        # Move to CPU and flatten to a 1-D numpy array.
        audio_array = speech_output.cpu().numpy().squeeze()

        return {
            "audio_array": audio_array,
            "sample_rate": self.model.generation_config.sample_rate,
        }

    def postprocess(self, inference_output: Dict) -> Dict:
        """
        Process the model outputs after inference.

        Args:
            inference_output (Dict): Model outputs.

        Returns:
            Dict: Processed outputs ready for the response.
""" if "error" in inference_output: return {"error": inference_output["error"]} audio_array = inference_output.get("audio_array") sample_rate = inference_output.get("sample_rate", 24000) # Convert audio array to WAV format try: import scipy.io.wavfile as wav audio_buffer = io.BytesIO() wav.write(audio_buffer, sample_rate, audio_array) audio_buffer.seek(0) audio_data = audio_buffer.read() # Encode audio data to base64 audio_base64 = base64.b64encode(audio_data).decode("utf-8") return { "audio": audio_base64, "sample_rate": sample_rate, "format": "wav" } except Exception as e: return {"error": f"Error converting audio: {str(e)}"} def __call__(self, data: Dict) -> Dict: """ Main entry point for the handler. Args: data (Dict): Request data. Returns: Dict: Response data. """ # Ensure the model is initialized if not self.initialized: self.setup() # Process the request try: inputs = self.preprocess(data) outputs = self.inference(inputs) response = self.postprocess(outputs) return response except Exception as e: return {"error": f"Error processing request: {str(e)}"}