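# AI Podcast Generator (Hugging Face Space)
# Drafts a podcast script with the Gemini API, then synthesizes each line with the
# canopylabs/orpheus-3b-0.1-ft model. Requires the HUGGINGFACE_TOKEN environment
# variable; the Gemini API key is entered in the UI.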
import gradio as gr
import google.generativeai as genai
import numpy as np
import re
import torch
import torchaudio
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download, login
import logging
import os
import spaces
import warnings
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Suppress specific warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
device = get_device()
logger.info(f"Using device: {device}")
model = None
tokenizer = None
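# Download the Orpheus weights and load the model/tokenizer onto the selected device.
# @spaces.GPU() requests a GPU for the duration of the call when running on ZeroGPU hardware.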
@spaces.GPU()
def load_model():
    global model, tokenizer
    logger.info("Loading Orpheus model...")
    model_name = "canopylabs/orpheus-3b-0.1-ft"
    hf_token = os.environ.get("HUGGINGFACE_TOKEN")
    if not hf_token:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
    try:
        login(token=hf_token)
        # Download only the config and weights; tokenizer files are fetched separately below
        snapshot_download(
            repo_id=model_name,
            token=hf_token,
            allow_patterns=[
                "config.json",
                "*.safetensors",
                "model.safetensors.index.json",
            ],
            ignore_patterns=[
                "optimizer.pt",
                "pytorch_model.bin",
                "training_args.bin",
                "scheduler.pt",
                "tokenizer.json",
                "tokenizer_config.json",
                "special_tokens_map.json",
                "vocab.json",
                "merges.txt",
                "tokenizer.*",
            ],
        )
        # float32 on CPU; bfloat16 on GPU to reduce memory use
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32 if device.type == "cpu" else torch.bfloat16,
        )
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        logger.info(f"Orpheus model and tokenizer loaded to {device}")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise
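# Ask Gemini to draft the podcast script from pasted text and/or an uploaded document.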
def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
    try:
        genai.configure(api_key=api_key)
        gemini_model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
        combined_content = content or ""
        if uploaded_file:
            # gr.File passes the uploaded file's path by default; read its text content from disk
            with open(uploaded_file, 'r', encoding='utf-8') as f:
                file_content = f.read()
            combined_content += "\n" + file_content if combined_content else file_content
        prompt = f"""
        Create a podcast script for {'one person' if num_hosts == 1 else 'two people'} discussing:
        {combined_content}
        Duration: {duration}. Include natural speech, humor, and occasional off-topic thoughts.
        Use speech fillers like um, ah. Vary emotional tone.
        Format: {'Monologue' if num_hosts == 1 else 'Alternating dialogue'} without speaker labels.
        Separate {'paragraphs' if num_hosts == 1 else 'lines'} with blank lines.
        Use emotion tags in angle brackets: <laugh>, <sigh>, <chuckle>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>.
        Example: "I can't believe I stayed up all night <yawn> only to find out the meeting was canceled <groan>."
        Ensure content flows naturally and stays on topic. Match the script length to {duration}.
        """
        response = gemini_model.generate_content(prompt)
        # Strip characters the TTS step cannot handle, keeping the <emotion> tags
        return re.sub(r'[^a-zA-Z0-9\s.,?!<>]', '', response.text)
    except Exception as e:
        logger.error(f"Error generating podcast script: {str(e)}")
        raise
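# Synthesize a single line of the script to audio with the Orpheus model.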
@spaces.GPU()
def text_to_speech(text, voice):
    global model, tokenizer
    try:
        if model is None or tokenizer is None:
            load_model()
        # Remove emotion tags before tokenization; `voice` is not yet used for conditioning
        clean_text = re.sub(r'<[^>]+>', '', text)
        inputs = tokenizer(clean_text, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=256)
        # Treat the generated token IDs as a spectrogram-like tensor
        mel = output[0].cpu().float()
        # Normalize to the [0, 1] range
        mel = (mel - mel.min()) / (mel.max() - mel.min())
        # Invert to a waveform with the Griffin-Lim algorithm
        griffin_lim = torchaudio.transforms.GriffinLim(n_fft=2048, n_iter=10)
        audio = griffin_lim(mel.unsqueeze(0))
        # Convert to a float32 numpy array clipped to [-1, 1]
        audio_np = audio.squeeze().numpy()
        audio_np = np.clip(audio_np, -1, 1)
        return (24000, audio_np.astype(np.float32))  # assuming a 24 kHz sample rate
    except Exception as e:
        logger.error(f"Error in text_to_speech: {str(e)}")
        raise
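# Render the full podcast: synthesize each script line, alternating voices for two hosts,
# and concatenate the segments into a single 24 kHz clip.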
@spaces.GPU()
def render_podcast(api_key, script, voice1, voice2, num_hosts):
    try:
        lines = [line for line in script.split('\n') if line.strip()]
        audio_segments = []
        for i, line in enumerate(lines):
            # A single host always uses voice1; two hosts alternate line by line
            voice = voice1 if num_hosts == 1 or i % 2 == 0 else voice2
            try:
                _, audio = text_to_speech(line, voice)
                audio_segments.append(audio)
            except Exception as e:
                logger.error(f"Error processing audio segment: {str(e)}")
        if not audio_segments:
            logger.warning("No valid audio segments were generated.")
            return (24000, np.zeros(24000, dtype=np.float32))
        podcast_audio = np.concatenate(audio_segments)
        # Ensure the audio is in the correct format for Gradio
        podcast_audio = np.clip(podcast_audio, -1, 1)
        podcast_audio = (podcast_audio * 32767).astype(np.int16)
        return (24000, podcast_audio)
    except Exception as e:
        logger.error(f"Error rendering podcast: {str(e)}")
        raise
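# Gradio UI: inputs for the API key, source content, duration, and hosts;
# buttons to generate the script and to render the podcast audio.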
with gr.Blocks() as demo:
    gr.Markdown("# AI Podcast Generator")
    api_key_input = gr.Textbox(label="Enter your Gemini API Key", type="password")
    with gr.Row():
        content_input = gr.Textbox(label="Paste your content (optional)")
        document_upload = gr.File(label="Upload Document (optional)")
    duration = gr.Radio(["1-5 min", "5-10 min", "10-15 min"], label="Estimated podcast duration")
    num_hosts = gr.Radio([1, 2], label="Number of podcast hosts", value=2)
    voice_options = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
    voice1_select = gr.Dropdown(label="Select Voice 1", choices=voice_options, value="tara")
    voice2_select = gr.Dropdown(label="Select Voice 2", choices=voice_options, value="leo")
    generate_btn = gr.Button("Generate Script")
    script_output = gr.Textbox(label="Generated Script", lines=10)
    render_btn = gr.Button("Render Podcast")
    audio_output = gr.Audio(label="Generated Podcast")
    generate_btn.click(generate_podcast_script,
                       inputs=[api_key_input, content_input, document_upload, duration, num_hosts],
                       outputs=script_output)
    render_btn.click(render_podcast,
                     inputs=[api_key_input, script_output, voice1_select, voice2_select, num_hosts],
                     outputs=audio_output)
    # Hide the second voice selector when only one host is chosen
    num_hosts.change(lambda x: gr.update(visible=x == 2),
                     inputs=[num_hosts],
                     outputs=[voice2_select])
if __name__ == "__main__":
    try:
        load_model()
        demo.launch()
    except Exception as e:
        logger.error(f"Error launching the application: {str(e)}")