|
import gradio as gr |
|
import google.generativeai as genai |
|
import numpy as np |
|
import re |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from huggingface_hub import snapshot_download, login |
|
import logging |
|
import os |
|
import spaces |
|
import warnings |
|
from snac import SNAC |
|
from dotenv import load_dotenv |
|
|
|
# Read key/value pairs from a local .env file into the environment, if present.
load_dotenv()

# Root logging configuration shared by the whole module.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Silence user-level and runtime warnings emitted by dependencies.
warnings.filterwarnings(action="ignore", category=UserWarning)
warnings.filterwarnings(action="ignore", category=RuntimeWarning)
|
|
|
def get_device():
    """Return ``"cuda"`` when torch reports an available CUDA device, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
# Compute device chosen once at import time and reused by every model call.
device = get_device()
logger.info(f"Using device: {device}")

# Lazily populated by load_model(); None means "not loaded yet".
model = tokenizer = snac_model = None
|
|
|
@spaces.GPU()
def load_model():
    """Load the SNAC codec and the Orpheus model/tokenizer into module globals.

    Assigns the module-level ``snac_model``, ``model`` and ``tokenizer`` and
    moves both models to ``device``.

    Raises:
        ValueError: If the ``HUGGINGFACE_TOKEN`` environment variable is
            unset or empty.
        Exception: Any failure while logging in, downloading or instantiating
            the Orpheus model/tokenizer is logged and re-raised unchanged.
    """
    # NOTE(review): this function mutates module globals while wrapped in
    # ``spaces.GPU()``; confirm the decorator runs it in the same process so
    # the assignments are visible to later calls.
    global model, tokenizer, snac_model

    logger.info("Loading SNAC model...")
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac_model = snac_model.to(device)

    logger.info("Loading Orpheus model...")
    model_name = "canopylabs/orpheus-3b-0.1-ft"

    hf_token = os.environ.get("HUGGINGFACE_TOKEN")
    if not hf_token:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

    try:
        login(token=hf_token)

        # Pre-fetch only the config and safetensors weights into the local cache.
        snapshot_download(
            repo_id=model_name,
            # ``token`` is the supported keyword; ``use_auth_token`` is deprecated.
            token=hf_token,
            allow_patterns=[
                "config.json",
                "*.safetensors",
                "model.safetensors.index.json",
            ],
            ignore_patterns=[
                "optimizer.pt",
                "pytorch_model.bin",
                "training_args.bin",
                "scheduler.pt",
                "tokenizer.json",
                "tokenizer_config.json",
                "special_tokens_map.json",
                "vocab.json",
                "merges.txt",
                "tokenizer.*",
            ],
        )

        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        logger.info(f"Orpheus model and tokenizer loaded to {device}")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise
|
|
|
@spaces.GPU()
def text_to_speech(text, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200):
    """Synthesize speech for ``text`` in the given ``voice``.

    Loads the models on first use, builds the prompt, samples tokens from the
    Orpheus model and decodes them to audio through the SNAC model.

    Args:
        text: Text to speak. ``None``, empty or whitespace-only input
            produces no audio.
        voice: Voice identifier forwarded to ``process_prompt``.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        ``(24000, audio_samples)`` on success, or ``None`` when there is no
        text to speak.

    Raises:
        Exception: Any error during prompt processing, generation or decoding
            is logged and re-raised unchanged.
    """
    global model, tokenizer, snac_model
    if model is None or tokenizer is None or snac_model is None:
        load_model()

    # Treat a missing value like blank text instead of failing on ``.strip()``.
    if text is None or not text.strip():
        return None

    try:
        # process_prompt / parse_output / redistribute_codes are defined
        # outside this block; their exact contracts are not visible here.
        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                # Hard-coded stop token id; presumably the model's
                # end-of-speech token -- TODO confirm against the tokenizer.
                eos_token_id=128258,
            )

        code_list = parse_output(generated_ids)
        audio_samples = redistribute_codes(code_list, snac_model)

        # 24000 is returned as the sample rate alongside the decoded samples.
        return (24000, audio_samples)
    except Exception as e:
        logger.error(f"Error in text_to_speech: {str(e)}")
        raise
|
|
|
@spaces.GPU()
def render_podcast(api_key, script, voice1, voice2, num_hosts):
    """Render a multi-line script to a single 16-bit audio track.

    Each non-blank line of ``script`` is synthesized separately. With one
    host every line uses ``voice1``; otherwise even-indexed lines use
    ``voice1`` and odd-indexed lines use ``voice2``. Lines whose synthesis
    fails are logged and skipped.

    Args:
        api_key: Accepted for interface compatibility; not used here.
        script: Newline-separated script text.
        voice1: Voice for the first host.
        voice2: Voice for the second host.
        num_hosts: Number of hosts; ``1`` disables voice alternation.

    Returns:
        ``(24000, samples)`` where ``samples`` is an ``int16`` array of the
        concatenated, clipped audio, or one second of ``float32`` zeros when
        no line could be synthesized.

    Raises:
        Exception: Errors outside per-line synthesis are logged and re-raised.
    """
    try:
        spoken_lines = [entry for entry in script.split('\n') if entry.strip()]
        clips = []

        for position, spoken in enumerate(spoken_lines):
            if num_hosts == 1 or position % 2 == 0:
                speaker = voice1
            else:
                speaker = voice2

            try:
                _rate, clip = text_to_speech(spoken, speaker)
                clips.append(clip)
            except Exception as e:
                # Best effort: one failed line must not abort the whole render.
                logger.error(f"Error processing audio segment: {str(e)}")

        if not clips:
            logger.warning("No valid audio segments were generated.")
            return (24000, np.zeros(24000, dtype=np.float32))

        # Join all clips, clamp to [-1, 1], then scale to the int16 range.
        clamped = np.clip(np.concatenate(clips), -1, 1)
        pcm16 = (clamped * 32767).astype(np.int16)

        return (24000, pcm16)
    except Exception as e:
        logger.error(f"Error rendering podcast: {str(e)}")
        raise
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load the models up front, then start the UI.
    # NOTE(review): ``demo`` is not defined in the visible part of this file;
    # it is expected to be the app object created elsewhere -- confirm.
    try:
        load_model()
        demo.launch()
    except Exception as e:
        # Keep the traceback in the log and signal failure through the exit
        # status instead of ending silently with status 0.
        logger.error(f"Error launching the application: {str(e)}", exc_info=True)
        raise SystemExit(1) from e