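# AI Podcast Generator (Hugging Face Space).
# Drafts a podcast script with the Gemini API, then renders it to 24 kHz audio
# with the Orpheus 3B TTS model decoded through the SNAC codec.
# Requires a HUGGINGFACE_TOKEN environment variable (see load_model below).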
import gradio as gr
import google.generativeai as genai
import numpy as np
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download, login
import logging
import os
import spaces
import warnings
from snac import SNAC
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
model = None
tokenizer = None
snac_model = None
def load_model():
    global model, tokenizer, snac_model
    try:
        logger.info("Loading SNAC model...")
        snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        snac_model = snac_model.to(device)
        logger.info("Loading Orpheus model...")
        model_name = "canopylabs/orpheus-3b-0.1-ft"
        hf_token = os.environ.get("HUGGINGFACE_TOKEN")
        if not hf_token:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
        login(token=hf_token)
        snapshot_download(
            repo_id=model_name,
            token=hf_token,  # `use_auth_token` is deprecated in huggingface_hub
            allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
            ignore_patterns=["optimizer.pt", "pytorch_model.bin", "training_args.bin", "scheduler.pt", "tokenizer.*"],
        )
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        logger.info(f"Orpheus model and tokenizer loaded to {device}")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise
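
# Ask Gemini to draft the script. The prompt requests plain alternating lines
# (no speaker labels) plus <emotion> tags that Orpheus renders as sounds.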
def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
    try:
        genai.configure(api_key=api_key)
        # Use a distinct name so the global Orpheus `model` is not shadowed.
        gemini_model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
        combined_content = content or ""
        if uploaded_file:
            # gr.File may supply a filepath string or a file-like object,
            # depending on the Gradio version; handle both.
            if isinstance(uploaded_file, str):
                with open(uploaded_file, encoding='utf-8') as f:
                    file_content = f.read()
            else:
                file_content = uploaded_file.read().decode('utf-8')
            combined_content += "\n" + file_content if combined_content else file_content
        prompt = f"""
Create a podcast script for {'one person' if num_hosts == 1 else 'two people'} discussing:
{combined_content}
Duration: {duration}. Include natural speech, humor, and occasional off-topic thoughts.
Use speech fillers like um, ah. Vary emotional tone.
Format: {'Monologue' if num_hosts == 1 else 'Alternating dialogue'} without speaker labels.
Separate {'paragraphs' if num_hosts == 1 else 'lines'} with blank lines.
Use emotion tags in angle brackets: <laugh>, <sigh>, <chuckle>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>.
Example: "I can't believe I stayed up all night <yawn> only to find out the meeting was canceled <groan>."
Ensure content flows naturally and stays on topic. Match the script length to {duration}.
"""
        response = gemini_model.generate_content(prompt)
        # Strip everything except word characters, whitespace, basic punctuation,
        # and the angle brackets used by the emotion tags.
        return re.sub(r'[^a-zA-Z0-9\s.,?!<>]', '', response.text)
    except Exception as e:
        logger.error(f"Error generating podcast script: {str(e)}")
        raise
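
# Orpheus expects prompts of the form "<voice>: <text>", bracketed by the
# special control-token ids hard-coded below (they come from the Orpheus
# fine-tune's vocabulary).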
def process_prompt(prompt, voice, tokenizer, device):
    prompt = f"{voice}: {prompt}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    start_token = torch.tensor([[128259]], dtype=torch.int64)
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
    attention_mask = torch.ones_like(modified_input_ids)
    return modified_input_ids.to(device), attention_mask.to(device)
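
# The generated sequence contains text tokens followed by audio-code tokens.
# Everything after the last start-of-audio marker (128257) is kept, the stop
# token (128258) is dropped, and ids are shifted down into SNAC code space.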
def parse_output(generated_ids):
    token_to_find = 128257
    token_to_remove = 128258
    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
    if len(token_indices[1]) > 0:
        last_occurrence_idx = token_indices[1][-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
    else:
        cropped_tensor = generated_ids
    processed_rows = []
    for row in cropped_tensor:
        masked_row = row[row != token_to_remove]
        processed_rows.append(masked_row)
    code_lists = []
    for row in processed_rows:
        row_length = row.size(0)
        new_length = (row_length // 7) * 7  # keep only whole 7-token frames
        trimmed_row = row[:new_length]
        trimmed_row = [t - 128266 for t in trimmed_row]
        code_lists.append(trimmed_row)
    return code_lists[0]
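
# Each group of 7 consecutive codes is one SNAC frame spread across the
# codec's three codebook layers (1 + 2 + 4 codes per frame); subtracting the
# multiples of 4096 undoes the per-slot offsets applied during flattening.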
def redistribute_codes(code_list, snac_model):
    device = next(snac_model.parameters()).device
    layer_1, layer_2, layer_3 = [], [], []
    for i in range((len(code_list) + 1) // 7):
        layer_1.append(code_list[7 * i])
        layer_2.append(code_list[7 * i + 1] - 4096)
        layer_3.append(code_list[7 * i + 2] - (2 * 4096))
        layer_3.append(code_list[7 * i + 3] - (3 * 4096))
        layer_2.append(code_list[7 * i + 4] - (4 * 4096))
        layer_3.append(code_list[7 * i + 5] - (5 * 4096))
        layer_3.append(code_list[7 * i + 6] - (6 * 4096))
    codes = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0),
    ]
    audio_hat = snac_model.decode(codes)
    return audio_hat.detach().squeeze().cpu().numpy()
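
# End-to-end TTS for a single line: tokenize, sample audio codes from Orpheus,
# then decode them with SNAC into a 24 kHz waveform.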
def text_to_speech(text, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200):
    global model, tokenizer, snac_model
    if model is None or tokenizer is None or snac_model is None:
        load_model()
    if not text.strip():
        return None
    try:
        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,
            )
        code_list = parse_output(generated_ids)
        audio_samples = redistribute_codes(code_list, snac_model)
        return (24000, audio_samples)
    except Exception as e:
        logger.error(f"Error in text_to_speech: {str(e)}")
        raise
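
# Render the whole script: alternate voices line by line (a single voice for a
# one-host show), synthesize each line, and concatenate the segments.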
def render_podcast(api_key, script, voice1, voice2, num_hosts):
    try:
        lines = [line for line in script.split('\n') if line.strip()]
        audio_segments = []
        for i, line in enumerate(lines):
            voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
            try:
                result = text_to_speech(line, voice)
                if result is not None:
                    sample_rate, audio = result
                    audio_segments.append(audio)
            except Exception as e:
                logger.error(f"Error processing audio segment: {str(e)}")
        if not audio_segments:
            logger.warning("No valid audio segments were generated.")
            return (24000, np.zeros(24000, dtype=np.float32))
        podcast_audio = np.concatenate(audio_segments)
        podcast_audio = np.clip(podcast_audio, -1, 1)
        podcast_audio = (podcast_audio * 32767).astype(np.int16)
        return (24000, podcast_audio)
    except Exception as e:
        logger.error(f"Error rendering podcast: {str(e)}")
        raise
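
# Gradio UI: collect the content and settings, generate the script, then
# render it to audio. Voice 2 is hidden when a single host is selected.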
with gr.Blocks() as demo:
    gr.Markdown("# AI Podcast Generator")
    api_key_input = gr.Textbox(label="Enter your Gemini API Key", type="password")
    with gr.Row():
        content_input = gr.Textbox(label="Paste your content (optional)")
        document_upload = gr.File(label="Upload Document (optional)")
    duration = gr.Radio(["1-5 min", "5-10 min", "10-15 min"], label="Estimated podcast duration")
    num_hosts = gr.Radio([1, 2], label="Number of podcast hosts", value=2)
    voice_options = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
    voice1_select = gr.Dropdown(label="Select Voice 1", choices=voice_options, value="tara")
    voice2_select = gr.Dropdown(label="Select Voice 2", choices=voice_options, value="leo")
    generate_btn = gr.Button("Generate Script")
    script_output = gr.Textbox(label="Generated Script", lines=10)
    render_btn = gr.Button("Render Podcast")
    audio_output = gr.Audio(label="Generated Podcast")
    generate_btn.click(
        generate_podcast_script,
        inputs=[api_key_input, content_input, document_upload, duration, num_hosts],
        outputs=script_output,
    )
    render_btn.click(
        render_podcast,
        inputs=[api_key_input, script_output, voice1_select, voice2_select, num_hosts],
        outputs=audio_output,
    )
    num_hosts.change(
        lambda x: gr.update(visible=x == 2),
        inputs=[num_hosts],
        outputs=[voice2_select],
    )
if __name__ == "__main__":
    try:
        load_model()
        demo.launch()
    except Exception as e:
        logger.error(f"Error launching the application: {str(e)}")