"""Gradio app: draft a podcast script with Gemini, then voice it with Orpheus TTS.

Flow, as wired at the bottom of this file:
  1. "Generate Script" calls ``generate_podcast_script`` (Gemini).
  2. "Render Podcast" calls ``render_podcast``, which voices each non-empty
     script line with ``text_to_speech`` and concatenates the audio.
"""
import logging
import os
import re
import warnings

import google.generativeai as genai
import gradio as gr
import numpy as np
import spaces
import torch
from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Sample rate reported alongside every audio array returned by this module
# (the SNAC checkpoint loaded below is "snac_24khz").
SAMPLE_RATE = 24000
# Generated audio codes are consumed in groups of 7, each slot offset by a
# multiple of 4096 (see redistribute_codes).
CODES_PER_FRAME = 7
CODEBOOK_SIZE = 4096

# Lazily populated by load_model().
model = None
tokenizer = None
snac_model = None


@spaces.GPU()
def load_model():
    """Load the SNAC decoder and the Orpheus model/tokenizer into module globals.

    Raises:
        ValueError: if the HUGGINGFACE_TOKEN environment variable is not set.
        Exception: any download/loading error is logged and re-raised.
    """
    global model, tokenizer, snac_model
    try:
        logger.info("Loading SNAC model...")
        snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        snac_model = snac_model.to(device)

        logger.info("Loading Orpheus model...")
        model_name = "canopylabs/orpheus-3b-0.1-ft"

        hf_token = os.environ.get("HUGGINGFACE_TOKEN")
        if not hf_token:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
        login(token=hf_token)

        # `token=` replaces the older `use_auth_token=` keyword, which current
        # huggingface_hub releases no longer accept.
        snapshot_download(
            repo_id=model_name,
            token=hf_token,
            allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
            ignore_patterns=[
                "optimizer.pt",
                "pytorch_model.bin",
                "training_args.bin",
                "scheduler.pt",
                "tokenizer.*",
            ],
        )

        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        logger.info(f"Orpheus model and tokenizer loaded to {device}")
    except Exception as e:
        logger.error("Error loading model: %s", e)
        raise


def _read_uploaded_text(uploaded_file):
    """Return the UTF-8 text of an upload given as a path or a file-like object."""
    # Depending on the Gradio version/configuration, gr.File may hand the
    # callback a filesystem path rather than an open file object.
    if isinstance(uploaded_file, (str, os.PathLike)):
        with open(uploaded_file, encoding="utf-8") as handle:
            return handle.read()
    data = uploaded_file.read()
    return data.decode('utf-8') if isinstance(data, bytes) else data


@spaces.GPU()
def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
    """Ask Gemini for a podcast script about the supplied content.

    Args:
        api_key: Gemini API key used to configure the client.
        content: Pasted source text; may be empty or None.
        uploaded_file: Optional upload from ``gr.File`` (path or file-like).
        duration: Duration label inserted verbatim into the prompt.
        num_hosts: 1 for a monologue; any other value asks for a dialogue.

    Returns:
        The model's text with every character removed except ASCII letters,
        digits, whitespace and ``. , ? ! < >``.

    Raises:
        Exception: any API or file-reading error is logged and re-raised.
    """
    try:
        genai.configure(api_key=api_key)
        # Named distinctly so it cannot be confused with the module-level
        # Orpheus `model`.
        gemini_model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        combined_content = content or ""
        if uploaded_file:
            file_content = _read_uploaded_text(uploaded_file)
            if combined_content:
                combined_content += "\n" + file_content
            else:
                combined_content = file_content

        solo = num_hosts == 1
        # NOTE(review): the tag names in the original prompt had been stripped
        # (only the separating commas remained). They are restored here to the
        # eight emotive tags documented for Orpheus TTS — confirm against the
        # intended prompt, including where the tags sat in the example.
        prompt = f"""
        Create a podcast script for {'one person' if solo else 'two people'} discussing:
        {combined_content}

        Duration: {duration}. Include natural speech, humor, and occasional off-topic thoughts.
        Use speech fillers like um, ah. Vary emotional tone.

        Format: {'Monologue' if solo else 'Alternating dialogue'} without speaker labels.
        Separate {'paragraphs' if solo else 'lines'} with blank lines.

        Use emotion tags in angle brackets: <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>.

        Example: "I can't believe I stayed up all night only to find out the meeting was canceled <groan>."

        Ensure content flows naturally and stays on topic. Match the script length to {duration}.
        """

        response = gemini_model.generate_content(prompt)
        # '<' and '>' are kept so emotion tags survive the clean-up.
        return re.sub(r'[^a-zA-Z0-9\s.,?!<>]', '', response.text)
    except Exception as e:
        logger.error("Error generating podcast script: %s", e)
        raise


def process_prompt(prompt, voice, tokenizer, device):
    """Tokenize ``"<voice>: <prompt>"`` and wrap it in the fixed control tokens.

    Returns:
        ``(input_ids, attention_mask)`` moved to ``device``; the mask is all ones.
    """
    prompt = f"{voice}: {prompt}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Fixed token ids placed before and after the text — presumably the
    # model's prompt framing; values kept exactly as in the original code.
    start_token = torch.tensor([[128259]], dtype=torch.int64)
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)

    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
    attention_mask = torch.ones_like(modified_input_ids)
    return modified_input_ids.to(device), attention_mask.to(device)


def parse_output(generated_ids):
    """Extract the audio-code sequence from the first generated row.

    Keeps everything after the last occurrence of token 128257 (or the whole
    sequence when it is absent), drops every 128258 token, truncates to a
    multiple of 7 and subtracts the 128266 offset.

    Returns:
        A list of Python ints whose length is a multiple of 7 (possibly empty).
    """
    token_to_find = 128257
    token_to_remove = 128258

    marker_columns = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]
    if marker_columns.numel() > 0:
        cropped = generated_ids[:, marker_columns[-1].item() + 1:]
    else:
        cropped = generated_ids

    # Only the first row is ever returned, so only it is processed.
    row = cropped[0]
    row = row[row != token_to_remove]
    usable = (row.size(0) // CODES_PER_FRAME) * CODES_PER_FRAME

    # One vectorised subtraction + tolist() yields plain ints instead of a
    # list of 0-d tensors.
    return (row[:usable] - 128266).tolist()


def redistribute_codes(code_list, snac_model):
    """Split 7-code frames into the three SNAC layers and decode them to audio.

    Per frame, slot 0 goes to layer 1, slots 1 and 4 to layer 2, and slots
    2, 3, 5 and 6 to layer 3; slot ``k`` has ``k * 4096`` subtracted.

    Returns:
        The decoded audio as a squeezed NumPy array on the CPU.
    """
    device = next(snac_model.parameters()).device
    layer_1, layer_2, layer_3 = [], [], []

    # Only complete frames are consumed; the previous `(len + 1) // 7` bound
    # indexed past the end for a length of 7k + 6.
    for frame in range(len(code_list) // CODES_PER_FRAME):
        base = CODES_PER_FRAME * frame
        layer_1.append(code_list[base])
        layer_2.append(code_list[base + 1] - CODEBOOK_SIZE)
        layer_3.append(code_list[base + 2] - (2 * CODEBOOK_SIZE))
        layer_3.append(code_list[base + 3] - (3 * CODEBOOK_SIZE))
        layer_2.append(code_list[base + 4] - (4 * CODEBOOK_SIZE))
        layer_3.append(code_list[base + 5] - (5 * CODEBOOK_SIZE))
        layer_3.append(code_list[base + 6] - (6 * CODEBOOK_SIZE))

    codes = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0),
    ]
    audio_hat = snac_model.decode(codes)
    return audio_hat.detach().squeeze().cpu().numpy()


@spaces.GPU()
def text_to_speech(text, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200):
    """Synthesize one piece of text in the given voice.

    Returns:
        ``(24000, samples)``, or ``None`` when ``text`` is blank or the model
        produced no complete audio frame.

    Raises:
        Exception: generation/decoding errors are logged and re-raised.
    """
    # Checked first so blank input never triggers a model load.
    if not text.strip():
        return None

    if model is None or tokenizer is None or snac_model is None:
        load_model()

    try:
        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,
            )

        code_list = parse_output(generated_ids)
        if not code_list:
            # Nothing to decode; callers already treat None as "skip".
            logger.warning("No audio codes were generated for the given text.")
            return None

        audio_samples = redistribute_codes(code_list, snac_model)
        return (SAMPLE_RATE, audio_samples)
    except Exception as e:
        logger.error("Error in text_to_speech: %s", e)
        raise


@spaces.GPU()
def render_podcast(api_key, script, voice1, voice2, num_hosts):
    """Voice every non-empty script line and join the results into one clip.

    With one host every line uses ``voice1``; otherwise even-indexed lines use
    ``voice1`` and odd-indexed lines use ``voice2``. A line that fails to
    synthesize is logged and skipped.

    Args:
        api_key: Unused here; present because the UI passes it as an input.

    Returns:
        ``(24000, int16 samples)``, or one second of float32 silence when no
        line produced audio.
    """
    try:
        lines = [line for line in script.split('\n') if line.strip()]
        audio_segments = []

        for i, line in enumerate(lines):
            voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
            try:
                result = text_to_speech(line, voice)
                if result is not None:
                    _, audio = result
                    audio_segments.append(audio)
            except Exception as e:
                # Best effort: one bad line must not abort the whole render.
                logger.error("Error processing audio segment: %s", e)

        if not audio_segments:
            logger.warning("No valid audio segments were generated.")
            return (SAMPLE_RATE, np.zeros(SAMPLE_RATE, dtype=np.float32))

        podcast_audio = np.concatenate(audio_segments)
        podcast_audio = np.clip(podcast_audio, -1, 1)
        podcast_audio = (podcast_audio * 32767).astype(np.int16)
        return (SAMPLE_RATE, podcast_audio)
    except Exception as e:
        logger.error("Error rendering podcast: %s", e)
        raise


with gr.Blocks() as demo:
    gr.Markdown("# AI Podcast Generator")

    api_key_input = gr.Textbox(label="Enter your Gemini API Key", type="password")

    with gr.Row():
        content_input = gr.Textbox(
            label="Paste your content (optional)",
            lines=8
        )
        document_upload = gr.File(label="Upload Document (optional)")

    duration = gr.Radio(
        ["1-5 min", "5-10 min", "10-15 min"],
        label="Estimated podcast duration",
        value="1-5 min",
    )
    num_hosts = gr.Radio([1, 2], label="Number of podcast hosts", value=2)

    voice_options = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
    voice1_select = gr.Dropdown(label="Select Voice 1", choices=voice_options, value="tara")
    voice2_select = gr.Dropdown(label="Select Voice 2", choices=voice_options, value="leo")

    generate_btn = gr.Button("Generate Script")
    script_output = gr.Textbox(label="Generated Script", lines=10)

    render_btn = gr.Button("Render Podcast")
    audio_output = gr.Audio(label="Generated Podcast")

    generate_btn.click(
        generate_podcast_script,
        inputs=[api_key_input, content_input, document_upload, duration, num_hosts],
        outputs=script_output,
    )
    render_btn.click(
        render_podcast,
        inputs=[api_key_input, script_output, voice1_select, voice2_select, num_hosts],
        outputs=audio_output,
    )

    # The second voice selector is only relevant when there are two hosts.
    num_hosts.change(
        lambda x: gr.update(visible=x == 2),
        inputs=[num_hosts],
        outputs=[voice2_select],
    )

if __name__ == "__main__":
    try:
        load_model()
        demo.launch()
    except Exception as e:
        logger.error("Error launching the application: %s", e)