# Audio-RedTeaming-Demo / smallest_ai.py
# jena-shreyas's picture
# Upload folder using huggingface_hub
# ce2ed27 verified
# raw
# history blame
# 3.72 kB
import os
import json
import asyncio
import aiofiles
from time import time
import json
from pprint import pprint
from smallestai.waves import WavesClient, AsyncWavesClient
class SmallestAITTS:
    """Text-to-speech wrapper around the ``smallestai.waves`` clients.

    Depending on ``is_async`` the instance holds either a ``WavesClient``
    or an ``AsyncWavesClient`` and exposes a single blocking entry point,
    :meth:`synthesize`, that writes the generated audio to a file.
    """

    def __init__(
        self,
        model_name: str,
        api_key: str,
        provider: str,
        endpoint_url: str,
        voice_id: str = None,
        sample_rate: int = 24000,
        speed: float = 1.0,
        is_async: bool = False,
    ):
        """Create the underlying client and remember the synthesis settings.

        Args:
            model_name: Model identifier forwarded to the client on every
                synthesis call.
            api_key: API key handed to the Waves client.
            provider: Stored on the instance; not used by this class itself.
            endpoint_url: Stored on the instance; not used by this class
                itself.
            voice_id: Voice to synthesize with. May be ``None`` here and set
                later through :meth:`load_voice`.
            sample_rate: Sample rate forwarded to the client.
            speed: Speech speed forwarded to the client.
            is_async: When true, use ``AsyncWavesClient`` and the coroutine
                implementation; otherwise use the synchronous client.
        """
        if is_async:
            self.client = AsyncWavesClient(api_key=api_key)
        else:
            self.client = WavesClient(api_key=api_key)
        self.model_name = model_name
        self.api_key = api_key
        self.provider = provider
        self.endpoint_url = endpoint_url
        # If passed as None, must be set later via `load_voice()`.
        self.voice_id = voice_id
        self.sample_rate = sample_rate
        self.speed = speed
        # Direct handle on the implementation matching the client type; in
        # async mode this is a coroutine function that must be awaited.
        self.tts = self._async_tts if is_async else self._tts
        self.is_async = is_async

    def load_voice(self, voice_id: str):
        """Set (or replace) the voice used for subsequent synthesis calls."""
        self.voice_id = voice_id

    def synthesize(self, text: str, output_filepath: str):
        """Synthesize ``text`` and save the audio to ``output_filepath``.

        Unified blocking interface: dispatches to the sync or async
        implementation chosen at construction time.

        Args:
            text: The text to synthesize.
            output_filepath: Path to save the audio file.

        Returns:
            Whatever the selected implementation returns (both currently
            return ``None``).

        Raises:
            AssertionError: If no voice has been set.
            RuntimeError: If the instance is async and this method is called
                from inside a running event loop, where blocking is not
                possible. Await ``self.tts(...)`` there instead.
        """
        if not self.is_async:
            return self._tts(text, output_filepath)

        # Fail fast, before any event-loop handling, when no voice is set.
        self._require_voice()

        # Probe for a running loop *before* creating the coroutine, so no
        # un-awaited coroutine is left behind and a RuntimeError raised by
        # the synthesis itself is never mistaken for "no event loop".
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: safe to drive one ourselves.
            return asyncio.run(self._async_tts(text, output_filepath))

        raise RuntimeError(
            "synthesize() cannot block inside a running event loop; "
            "await `tts(text, output_filepath)` instead."
        )

    def _require_voice(self):
        """Raise if no voice has been configured yet.

        Raised explicitly (rather than via an ``assert`` statement) so the
        check still runs under ``python -O``; the exception type and message
        are the same as an assert would produce.
        """
        if self.voice_id is None:
            raise AssertionError("Please set a voice style.")

    def _tts(self, text: str, output_filepath: str):
        """Synchronous synthesis; the client saves the audio itself."""
        self._require_voice()
        self.client.synthesize(
            text,
            save_as=output_filepath,
            model=self.model_name,
            voice_id=self.voice_id,
            speed=self.speed,
            sample_rate=self.sample_rate,
        )

    async def _async_tts(self, text: str, output_filepath: str):
        """Asynchronous synthesis; writes the awaited result to disk."""
        self._require_voice()
        async with self.client:
            audio_bytes = await self.client.synthesize(
                text,
                model=self.model_name,
                voice_id=self.voice_id,
                speed=self.speed,
                sample_rate=self.sample_rate,
            )
        # The result is written verbatim in binary mode.
        async with aiofiles.open(output_filepath, "wb") as f:
            await f.write(audio_bytes)

    # Wrappers for the SmallestAI client's default functions

    def get_languages(self):
        """Return the client's ``get_languages()`` result unchanged."""
        return self.client.get_languages()

    def get_voices(self, model="lightning", voiceId=None, **kwargs) -> list:
        """List the voices for ``model``, optionally filtered.

        The client's ``get_voices(model)`` response is parsed as JSON and its
        ``"voices"`` list is filtered.

        Args:
            model: Model name passed to the client.
            voiceId: If given, keep only voices whose ``"voiceId"`` equals
                it; tag filters in ``kwargs`` are then ignored.
            **kwargs: Tag filters. A voice is kept only if, for every
                ``key=value`` pair, ``voice["tags"][key] == value``. Voices
                without a ``"tags"`` mapping or without the requested tag
                are treated as non-matching.

        Returns:
            The list of matching voice dictionaries.
        """
        voices = json.loads(self.client.get_voices(model))["voices"]

        if voiceId is not None:
            return [voice for voice in voices if voice.get("voiceId") == voiceId]
        if not kwargs:
            return voices

        # Sentinel so that a filter value of None never matches an absent tag.
        missing = object()

        def matches(voice) -> bool:
            tags = voice.get("tags")
            if not isinstance(tags, dict):
                return False
            return all(
                tags.get(key, missing) == wanted for key, wanted in kwargs.items()
            )

        return [voice for voice in voices if matches(voice)]

    def get_models(self):
        """Return the client's ``get_models()`` result unchanged."""
        return self.client.get_models()