|
import gradio as gr |
|
import google.generativeai as genai |
|
import numpy as np |
|
import re |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from huggingface_hub import snapshot_download |
|
import logging |
|
import os |
|
import spaces |
|
import warnings |
|
from snac import SNAC |
|
|
|
# App-wide logging: timestamped, level-tagged records at INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Warnings of these two categories are suppressed for the whole process.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.info(f"Using device: {device}")

# Filled in by load_model(); all three stay None until it has run.
model = tokenizer = snac_model = None
|
|
|
@spaces.GPU()
def load_model():
    """Load the SNAC codec and the Orpheus model/tokenizer into module globals.

    Populates the module-level ``snac_model``, ``model`` and ``tokenizer`` and
    moves both models to ``device``.

    The Hugging Face token is read from the ``HUGGINGFACE_TOKEN`` environment
    variable and passed to every Hub call (the prefetch *and* both
    ``from_pretrained`` calls), so a gated or private Orpheus repo loads
    consistently instead of only being downloadable by the prefetch step.

    Raises:
        Exception: any download/load failure is logged with its traceback and
            re-raised.
    """
    global model, tokenizer, snac_model

    hf_token = os.environ.get("HUGGINGFACE_TOKEN")
    try:
        logger.info("Loading SNAC model...")
        snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        snac_model = snac_model.to(device)

        logger.info("Loading Orpheus model...")
        model_name = "canopylabs/orpheus-3b-0.1-ft"

        # Prefetch only config, weights and tokenizer files; training
        # artifacts (optimizer/scheduler state, .bin duplicates) are skipped.
        snapshot_download(
            repo_id=model_name,
            token=hf_token,  # `use_auth_token` is the deprecated spelling
            allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json", "vocab.json", "merges.txt", "tokenizer.json"],
            ignore_patterns=["optimizer.pt", "pytorch_model.bin", "training_args.bin", "scheduler.pt"],
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, token=hf_token
        )
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        logger.info("Orpheus model and tokenizer loaded to %s", device)
    except Exception:
        logger.exception("Error loading model")
        raise
|
|
|
# Characters kept in the generated script: letters, digits, whitespace, basic
# sentence punctuation, apostrophes and hyphens (so contractions like "we're"
# and compounds like "well-known" are not mangled before TTS), plus the angle
# brackets used by the emotion tags.
_SCRIPT_DISALLOWED_RE = re.compile(r"[^a-zA-Z0-9\s.,?!<>'\-]")

# Typographic single quotes mapped to a plain apostrophe so they survive the
# whitelist above.
_SMART_QUOTES = str.maketrans({"\u2018": "'", "\u2019": "'"})


def _read_uploaded_text(uploaded_file):
    """Return the UTF-8 text of an upload coming from ``gr.File``.

    Accepts a file path (str / PathLike, which is what recent Gradio versions
    pass), raw bytes, an object whose ``.name`` is a path (temp-file wrapper),
    or any object exposing ``.read()``. Decoding is strict UTF-8.
    """
    if isinstance(uploaded_file, (bytes, bytearray)):
        return bytes(uploaded_file).decode('utf-8')
    if isinstance(uploaded_file, (str, os.PathLike)):
        path = uploaded_file
    elif isinstance(getattr(uploaded_file, "name", None), str):
        path = uploaded_file.name
    else:
        data = uploaded_file.read()
        return data.decode('utf-8') if isinstance(data, bytes) else data
    with open(path, encoding='utf-8') as handle:
        return handle.read()


def _clean_script(text):
    """Normalize smart apostrophes, then drop characters outside the whitelist."""
    return _SCRIPT_DISALLOWED_RE.sub('', text.translate(_SMART_QUOTES))


# Intentionally not decorated with @spaces.GPU(): this only calls the Gemini
# HTTP API, so reserving a GPU for it wastes quota, and a long generation may
# outlast the GPU task's time budget.
def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
    """Ask Gemini to write a podcast script about the supplied content.

    Args:
        api_key: Gemini API key entered by the user.
        content: Pasted text; may be empty or None.
        uploaded_file: Optional upload from ``gr.File`` (path, bytes, temp-file
            wrapper or readable object); its text is appended to ``content``.
        duration: Target duration label, e.g. "1-5 min".
        num_hosts: 1 for a monologue; any other value asks for a two-person
            dialogue.

    Returns:
        The generated script, filtered to a TTS-friendly character set
        (see ``_SCRIPT_DISALLOWED_RE``).

    Raises:
        Exception: API or file-decoding errors are logged with traceback and
            re-raised so Gradio surfaces them.
    """
    try:
        genai.configure(api_key=api_key)
        # Named `gemini` so it does not shadow the module-level Orpheus `model`.
        gemini = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        combined_content = content or ""
        if uploaded_file:
            file_content = _read_uploaded_text(uploaded_file)
            combined_content = (
                f"{combined_content}\n{file_content}" if combined_content else file_content
            )

        prompt = f"""
Create a podcast script for {'one person' if num_hosts == 1 else 'two people'} discussing:
{combined_content}

Duration: {duration}. Include natural speech, humor, and occasional off-topic thoughts.
Use speech fillers like um, ah. Vary emotional tone.

Format: {'Monologue' if num_hosts == 1 else 'Alternating dialogue'} without speaker labels.
Separate {'paragraphs' if num_hosts == 1 else 'lines'} with blank lines.

Use emotion tags in angle brackets: <laugh>, <sigh>, <chuckle>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>.

Example: "I can't believe I stayed up all night <yawn> only to find out the meeting was canceled <groan>."

Ensure content flows naturally and stays on topic. Match the script length to {duration}.
"""

        response = gemini.generate_content(prompt)
        return _clean_script(response.text)
    except Exception:
        logger.exception("Error generating podcast script")
        raise
|
|
|
def process_prompt(prompt, voice, tokenizer, device):
    """Build the Orpheus input ids and attention mask for one utterance.

    The text is prefixed with ``"<voice>: "``, tokenized, then wrapped as
    ``[128259] + text_ids + [128009, 128260]``. These hard-coded ids are
    presumably Orpheus special tokens (start-of-human / end-of-text /
    end-of-human); confirm against the model card before changing them.

    Returns:
        ``(input_ids, attention_mask)``. Both are int64 tensors of shape
        ``(1, L)`` moved to ``device``. The mask is all ones.
    """
    text_ids = tokenizer(f"{voice}: {prompt}", return_tensors="pt").input_ids

    pieces = (
        torch.tensor([[128259]], dtype=torch.int64),
        text_ids,
        torch.tensor([[128009, 128260]], dtype=torch.int64),
    )
    full_ids = torch.cat(pieces, dim=1)
    mask = torch.ones_like(full_ids)

    return full_ids.to(device), mask.to(device)
|
|
|
def parse_output(generated_ids):
    """Extract the flat audio-code list from Orpheus ``generate`` output.

    Args:
        generated_ids: 2-D integer tensor (rows x tokens) as returned by
            ``model.generate``. Only the first row is decoded, which is
            all the caller uses.

    Returns:
        A list of plain Python ints for the first row. These are the tokens
        after the last ``128257`` found anywhere in the batch (or the whole
        row if there is none), with every ``128258`` removed, truncated to a
        multiple of 7, and shifted down by ``128266``.
    """
    token_to_find = 128257    # codes follow its last occurrence (presumably start-of-audio)
    token_to_remove = 128258  # stripped wherever it appears (also the eos id used for generation)
    code_offset = 128266      # first audio-code token id; subtracted from every kept token

    _, cols = (generated_ids == token_to_find).nonzero(as_tuple=True)
    if cols.numel() > 0:
        cropped = generated_ids[:, cols[-1].item() + 1:]
    else:
        cropped = generated_ids

    row = cropped[0]
    row = row[row != token_to_remove]
    usable = (row.size(0) // 7) * 7  # codes come in 7-token frames

    # .tolist() yields plain ints, so downstream arithmetic is ordinary Python
    # instead of one tiny tensor op per element (which is costly on GPU).
    return [token - code_offset for token in row[:usable].tolist()]
|
|
|
def redistribute_codes(code_list, snac_model):
    """Regroup a flat Orpheus code list into SNAC's three layers and decode it.

    Codes come in 7-code frames laid out as ``[L1, L2, L3, L3, L2, L3, L3]``.
    The code in slot ``j`` (0-based) carries an offset of ``j * 4096``, which
    is removed here. Only complete frames are decoded; a trailing partial
    frame is ignored rather than indexed past the end of the list.

    Args:
        code_list: Flat sequence of ints (0-d tensors are also accepted).
        snac_model: SNAC model exposing ``parameters()`` and ``decode(codes)``.

    Returns:
        The decoded waveform as a squeezed NumPy array on the CPU, or ``None``
        if there is no complete frame, a code is outside ``[0, 4096)``, or
        decoding raises. Failures are logged.
    """
    codebook_size = 4096  # implied by the per-slot offsets; assumed SNAC codebook size
    try:
        device = next(snac_model.parameters()).device

        n_frames = len(code_list) // 7
        if n_frames == 0:
            logger.warning("redistribute_codes: no complete 7-code frame to decode")
            return None

        layer_1, layer_2, layer_3 = [], [], []
        for i in range(n_frames):
            frame = [
                int(code) - slot * codebook_size
                for slot, code in enumerate(code_list[7 * i:7 * i + 7])
            ]
            layer_1.append(frame[0])
            layer_2.extend((frame[1], frame[4]))
            layer_3.extend((frame[2], frame[3], frame[5], frame[6]))

        # An out-of-range index would fail inside SNAC's decode; on CUDA that
        # typically surfaces as a device-side assert that can poison later GPU
        # calls, so reject such frames before they reach the model.
        if any(not 0 <= c < codebook_size for c in layer_1 + layer_2 + layer_3):
            logger.warning("redistribute_codes: audio code outside [0, %d); skipping", codebook_size)
            return None

        codes = [
            torch.tensor(layer, device=device).unsqueeze(0)
            for layer in (layer_1, layer_2, layer_3)
        ]

        with torch.no_grad():  # inference only; avoid building an autograd graph
            audio_hat = snac_model.decode(codes)
        return audio_hat.detach().squeeze().cpu().numpy()
    except Exception as e:
        logger.error(f"Error in redistribute_codes: {e}", exc_info=True)
        return None
|
|
|
@spaces.GPU()
def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens):
    """Synthesize one utterance with Orpheus and decode it to a waveform.

    Args:
        text: The line to speak. Empty, blank or None yields ``None``.
        voice: Voice name prefixed to the prompt (see ``process_prompt``).
        temperature, top_p, repetition_penalty, max_new_tokens: Sampling
            settings forwarded to ``model.generate``.

    Returns:
        ``(24000, samples)`` with the array from ``redistribute_codes``, or
        ``None`` in any of these cases:
        - a model is not loaded;
        - the text is blank;
        - generation raises;
        - the audio codes could not be decoded.
        A usable result therefore always carries real audio, never
        ``(24000, None)``.
    """
    if tokenizer is None or model is None or snac_model is None:
        logger.error("Model, tokenizer or SNAC codec is not initialized. Please ensure the model is properly loaded.")
        return None

    if not text or not text.strip():
        return None

    try:
        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,
            )

        code_list = parse_output(generated_ids)
        audio_samples = redistribute_codes(code_list, snac_model)
    except Exception:
        logger.exception("Error generating speech")
        return None

    if audio_samples is None:
        # Decoding failed (redistribute_codes logs why); report failure the
        # same way as every other error path.
        return None
    return (24000, audio_samples)
|
|
|
@spaces.GPU()
def render_podcast(api_key, script, voice1, voice2, num_hosts):
    """Render a script into one 16-bit PCM podcast waveform at 24 kHz.

    Each non-blank line of ``script`` is one utterance. With one host every
    line uses ``voice1``. Otherwise lines alternate ``voice1`` / ``voice2``,
    starting with ``voice1``. Lines that yield no audio are skipped.

    Args:
        api_key: Unused here; kept so the Gradio click handler's inputs still
            match this signature.
        script: Script text; None is treated as empty.
        voice1, voice2: Voice names passed to ``generate_speech``.
        num_hosts: 1 for a single voice, anything else to alternate.

    Returns:
        ``(24000, int16 array)``. If no line produced audio, the array is one
        second of silence, in the same dtype.

    Raises:
        Exception: unexpected errors are logged with traceback and re-raised.
    """
    try:
        lines = [line for line in (script or "").split('\n') if line.strip()]
        audio_segments = []

        for i, line in enumerate(lines):
            voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
            result = generate_speech(line, voice, temperature=0.6, top_p=0.95,
                                     repetition_penalty=1.1, max_new_tokens=1200)
            # Treat both a None result and a (rate, None) pair as "no audio";
            # a None waveform would otherwise crash np.concatenate below.
            if result is None or result[1] is None:
                logger.warning("Skipping line %d: no audio generated", i)
                continue
            audio_segments.append(np.ravel(result[1]))

        if not audio_segments:
            logger.warning("No valid audio segments were generated.")
            return (24000, np.zeros(24000, dtype=np.int16))

        podcast_audio = np.clip(np.concatenate(audio_segments), -1, 1)
        return (24000, (podcast_audio * 32767).astype(np.int16))
    except Exception:
        logger.exception("Error rendering podcast")
        raise
|
|
|
with gr.Blocks() as demo:
    # Components are created top-to-bottom in the order they appear on the page.
    gr.Markdown(value="# AI Podcast Generator")

    api_key_input = gr.Textbox(type="password", label="Enter your Gemini API Key")

    # Source material: pasted text and/or an uploaded document, side by side.
    with gr.Row():
        content_input = gr.Textbox(lines=4, label="Paste your content (optional)")
        document_upload = gr.File(label="Upload Document (optional)")

    duration = gr.Radio(
        choices=["1-5 min", "5-10 min", "10-15 min"],
        value="1-5 min",
        label="Estimated podcast duration",
    )
    num_hosts = gr.Radio(choices=[1, 2], value=2, label="Number of podcast hosts")

    voice_options = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
    voice1_select = gr.Dropdown(choices=voice_options, value="tara", label="Select Voice 1")
    voice2_select = gr.Dropdown(choices=voice_options, value="leo", label="Select Voice 2")

    generate_btn = gr.Button(value="Generate Script")
    script_output = gr.Textbox(lines=10, label="Generated Script")

    render_btn = gr.Button(value="Render Podcast")
    audio_output = gr.Audio(label="Generated Podcast")

    # Step 1: content -> script text.
    generate_btn.click(
        fn=generate_podcast_script,
        inputs=[api_key_input, content_input, document_upload, duration, num_hosts],
        outputs=script_output,
    )

    # Step 2: (possibly edited) script -> audio.
    render_btn.click(
        fn=render_podcast,
        inputs=[api_key_input, script_output, voice1_select, voice2_select, num_hosts],
        outputs=audio_output,
    )

    # The second voice picker is only shown for two-host podcasts.
    num_hosts.change(
        fn=lambda selected: gr.update(visible=selected == 2),
        inputs=[num_hosts],
        outputs=[voice2_select],
    )
|
|
|
if __name__ == "__main__":
    try:
        load_model()
        demo.queue().launch()
    except Exception:
        # Log the full traceback, then re-raise so the process exits non-zero
        # instead of looking like a clean shutdown.
        logger.exception("Error launching the application")
        raise